]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
fixed read_train_data_params(): always read cat_count & cat_map in case of classifica...
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 static const float ord_var_epsilon = FLT_EPSILON*2;
44 static const float ord_nan = FLT_MAX*0.5f;
45 static const int min_block_size = 1 << 16;
46 static const int block_size_delta = 1 << 10;
47
48 CvDTreeTrainData::CvDTreeTrainData()
49 {
50     var_idx = var_type = cat_count = cat_ofs = cat_map =
51         priors = counts = buf = direction = split_buf = 0;
52     tree_storage = temp_storage = 0;
53
54     clear();
55 }
56
57
58 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
59                       const CvMat* _responses, const CvMat* _var_idx,
60                       const CvMat* _sample_idx, const CvMat* _var_type,
61                       const CvMat* _missing_mask,
62                       CvDTreeParams _params, bool _shared )
63 {
64     var_idx = var_type = cat_count = cat_ofs = cat_map =
65         priors = counts = buf = direction = split_buf = 0;
66     tree_storage = temp_storage = 0;
67     
68     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
69               _var_type, _missing_mask, _params, _shared );
70 }
71
72
73 CvDTreeTrainData::~CvDTreeTrainData()
74 {
75     clear();
76 }
77
78
79 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
80 {
81     bool ok = false;
82     
83     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
84
85     __BEGIN__;
86
87     // set parameters
88     params = _params;
89
90     if( params.max_categories < 2 )
91         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
92     params.max_categories = MIN( params.max_categories, 15 );
93
94     if( params.max_depth < 0 )
95         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
96     params.max_depth = MIN( params.max_depth, 25 );
97
98     params.min_sample_count = MAX(params.min_sample_count,1);
99
100     if( params.cv_folds < 0 )
101         CV_ERROR( CV_StsOutOfRange,
102         "params.cv_folds should be =0 (the tree is not pruned) "
103         "or n>0 (tree is pruned using n-fold cross-validation)" );
104
105     if( params.cv_folds == 1 )
106         params.cv_folds = 0;
107
108     if( params.regression_accuracy < 0 )
109         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
110
111     ok = true;
112
113     __END__;
114
115     return ok;
116 }
117
118
119 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
120 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
121 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
122
123 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
124 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
125
126 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
127     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
128     const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
129     bool _shared )
130 {
131     CvMat* sample_idx = 0;
132     CvMat* var_type0 = 0;
133     CvMat* tmp_map = 0;
134     int** int_ptr = 0;
135
136     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
137
138     __BEGIN__;
139
140     int sample_all = 0, r_type = 0, cv_n;
141     int total_c_count = 0;
142     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
144     int vi, i;
145     char err[100];
146     const int *sidx = 0, *vidx = 0;
147
148     clear();
149
150     var_all = 0;
151     rng = cvRNG(-1);
152
153     CV_CALL( set_params( _params ));
154
155     // check parameter types and sizes
156     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
157     if( _tflag == CV_ROW_SAMPLE )
158     {
159         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
160         dv_step = 1;
161         if( _missing_mask )
162             ms_step = _missing_mask->step, mv_step = 1;
163     }
164     else
165     {
166         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
167         ds_step = 1;
168         if( _missing_mask )
169             mv_step = _missing_mask->step, ms_step = 1;
170     }
171
172     sample_count = sample_all;
173     var_count = var_all;
174
175     if( _sample_idx )
176     {
177         CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
178         sidx = sample_idx->data.i;
179         sample_count = sample_idx->rows + sample_idx->cols - 1;
180     }
181
182     if( _var_idx )
183     {
184         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
185         vidx = var_idx->data.i;
186         var_count = var_idx->rows + var_idx->cols - 1;
187     }
188
189     if( !CV_IS_MAT(_responses) ||
190         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
191          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
192         _responses->rows != 1 && _responses->cols != 1 ||
193         _responses->rows + _responses->cols - 1 != sample_all )
194         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
195                   "floating-point vector containing as many elements as "
196                   "the total number of samples in the training data matrix" );
197
198     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
199     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
200
201     cat_var_count = 0;
202     ord_var_count = -1;
203
204     is_classifier = r_type == CV_VAR_CATEGORICAL;
205
206     // step 0. calc the number of categorical vars
207     for( vi = 0; vi < var_count; vi++ )
208     {
209         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
210             cat_var_count++ : ord_var_count--;
211     }
212
213     ord_var_count = ~ord_var_count;
214     cv_n = params.cv_folds;
215     // set the two last elements of var_type array to be able
216     // to locate responses and cross-validation labels using
217     // the corresponding get_* functions.
218     var_type->data.i[var_count] = cat_var_count;
219     var_type->data.i[var_count+1] = cat_var_count+1;
220
221     // in case of single ordered predictor we need dummy cv_labels
222     // for safe split_node_data() operation
223     have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
224
225     buf_size = (ord_var_count*2 + cat_var_count + 1 +
226         (have_cv_labels ? 1 : 0))*sample_count + 2;
227     shared = _shared;
228     buf_count = shared ? 3 : 2;
229     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
230     CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
231     CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
232     CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
233
234     // now calculate the maximum size of split,
235     // create memory storage that will keep nodes and splits of the decision tree
236     // allocate root node and the buffer for the whole training data
237     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
238         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
239     tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
240     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
241     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
242     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
243
244     temp_block_size = nv_size = var_count*sizeof(int);
245     if( cv_n )
246     {
247         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
248             CV_ERROR( CV_StsOutOfRange,
249                 "The many folds in cross-validation for such a small dataset" );
250
251         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
252         temp_block_size = MAX(temp_block_size, cv_size);
253     }
254
255     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
256     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
257     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
258     if( cv_size )
259         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
260
261     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
262     CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
263
264     max_c_count = 1;
265
266     // transform the training data to convenient representation
267     for( vi = 0; vi <= var_count; vi++ )
268     {
269         int ci;
270         const uchar* mask = 0;
271         int m_step = 0, step;
272         const int* idata = 0;
273         const float* fdata = 0;
274         int num_valid = 0;
275
276         if( vi < var_count ) // analyze i-th input variable
277         {
278             int vi0 = vidx ? vidx[vi] : vi;
279             ci = get_var_type(vi);
280             step = ds_step; m_step = ms_step;
281             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
282                 idata = _train_data->data.i + vi0*dv_step;
283             else
284                 fdata = _train_data->data.fl + vi0*dv_step;
285             if( _missing_mask )
286                 mask = _missing_mask->data.ptr + vi0*mv_step;
287         }
288         else // analyze _responses
289         {
290             ci = cat_var_count;
291             step = CV_IS_MAT_CONT(_responses->type) ?
292                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
293             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
294                 idata = _responses->data.i;
295             else
296                 fdata = _responses->data.fl;
297         }
298
299         if( vi < var_count && ci >= 0 ||
300             vi == var_count && is_classifier ) // process categorical variable or response
301         {
302             int c_count, prev_label, prev_i;
303             int* c_map, *dst = get_cat_var_data( data_root, vi );
304
305             // copy data
306             for( i = 0; i < sample_count; i++ )
307             {
308                 int val = INT_MAX, si = sidx ? sidx[i] : i;
309                 if( !mask || !mask[si*m_step] )
310                 {
311                     if( idata )
312                         val = idata[si*step];
313                     else
314                     {
315                         float t = fdata[si*step];
316                         val = cvRound(t);
317                         if( val != t )
318                         {
319                             sprintf( err, "%d-th value of %d-th (categorical) "
320                                 "variable is not an integer", i, vi );
321                             CV_ERROR( CV_StsBadArg, err );
322                         }
323                     }
324
325                     if( val == INT_MAX )
326                     {
327                         sprintf( err, "%d-th value of %d-th (categorical) "
328                             "variable is too large", i, vi );
329                         CV_ERROR( CV_StsBadArg, err );
330                     }
331                     num_valid++;
332                 }
333                 dst[i] = val;
334                 int_ptr[i] = dst + i;
335             }
336
337             // sort all the values, including the missing measurements
338             // that should all move to the end
339             icvSortIntPtr( int_ptr, sample_count, 0 );
340             //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
341
342             c_count = num_valid > 0;
343
344             // count the categories
345             for( i = 1; i < num_valid; i++ )
346                 c_count += *int_ptr[i] != *int_ptr[i-1];
347
348             if( vi > 0 )
349                 max_c_count = MAX( max_c_count, c_count );
350             cat_count->data.i[ci] = c_count;
351             cat_ofs->data.i[ci] = total_c_count;
352
353             // resize cat_map, if need
354             if( cat_map->cols < total_c_count + c_count )
355             {
356                 tmp_map = cat_map;
357                 CV_CALL( cat_map = cvCreateMat( 1,
358                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
359                 for( i = 0; i < total_c_count; i++ )
360                     cat_map->data.i[i] = tmp_map->data.i[i];
361                 cvReleaseMat( &tmp_map );
362             }
363
364             c_map = cat_map->data.i + total_c_count;
365             total_c_count += c_count;
366
367             // compact the class indices and build the map
368             prev_label = ~*int_ptr[0];
369             c_count = -1;
370
371             for( i = 0, prev_i = -1; i < num_valid; i++ )
372             {
373                 int cur_label = *int_ptr[i];
374                 if( cur_label != prev_label )
375                 {
376                     c_map[++c_count] = prev_label = cur_label;
377                     prev_i = i;
378                 }
379                 *int_ptr[i] = c_count;
380             }
381
382             // replace labels for missing values with -1
383             for( ; i < sample_count; i++ )
384                 *int_ptr[i] = -1;
385         }
386         else if( ci < 0 ) // process ordered variable
387         {
388             CvPair32s32f* dst = get_ord_var_data( data_root, vi );
389
390             for( i = 0; i < sample_count; i++ )
391             {
392                 float val = ord_nan;
393                 int si = sidx ? sidx[i] : i;
394                 if( !mask || !mask[si*m_step] )
395                 {
396                     if( idata )
397                         val = (float)idata[si*step];
398                     else
399                         val = fdata[si*step];
400
401                     if( fabs(val) >= ord_nan )
402                     {
403                         sprintf( err, "%d-th value of %d-th (ordered) "
404                             "variable (=%g) is too large", i, vi, val );
405                         CV_ERROR( CV_StsBadArg, err );
406                     }
407                     num_valid++;
408                 }
409                 dst[i].i = i;
410                 dst[i].val = val;
411             }
412
413             icvSortPairs( dst, sample_count, 0 );
414         }
415         else // special case: process ordered response,
416              // it will be stored similarly to categorical vars (i.e. no pairs)
417         {
418             float* dst = get_ord_responses( data_root );
419
420             for( i = 0; i < sample_count; i++ )
421             {
422                 float val = ord_nan;
423                 int si = sidx ? sidx[i] : i;
424                 if( idata )
425                     val = (float)idata[si*step];
426                 else
427                     val = fdata[si*step];
428
429                 if( fabs(val) >= ord_nan )
430                 {
431                     sprintf( err, "%d-th value of %d-th (ordered) "
432                         "variable (=%g) is out of range", i, vi, val );
433                     CV_ERROR( CV_StsBadArg, err );
434                 }
435                 dst[i] = val;
436             }
437
438             cat_count->data.i[cat_var_count] = 0;
439             cat_ofs->data.i[cat_var_count] = total_c_count;
440             num_valid = sample_count;
441         }
442
443         if( vi < var_count )
444             data_root->num_valid[vi] = num_valid;
445     }
446
447     if( cv_n )
448     {
449         int* dst = get_cv_labels(data_root);
450         CvRNG* r = &rng;
451
452         for( i = vi = 0; i < sample_count; i++ )
453         {
454             dst[i] = vi++;
455             vi &= vi < cv_n ? -1 : 0;
456         }
457
458         for( i = 0; i < sample_count; i++ )
459         {
460             int a = cvRandInt(r) % sample_count;
461             int b = cvRandInt(r) % sample_count;
462             CV_SWAP( dst[a], dst[b], vi );
463         }
464     }
465
466     cat_map->cols = MAX( total_c_count, 1 );
467
468     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
469         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
470     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
471
472     have_priors = is_classifier && params.priors;
473     if( is_classifier )
474     {
475         int m = get_num_classes(), rows = 4;
476         double sum = 0;
477         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
478         for( i = 0; i < m; i++ )
479         {
480             double val = have_priors ? params.priors[i] : 1.;
481             if( val <= 0 )
482                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
483             priors->data.db[i] = val;
484             sum += val;
485         }
486         // normalize weights
487         if( have_priors )
488             cvScale( priors, priors, 1./sum );
489
490         if( cat_var_count > 0 || params.cv_folds > 0 )
491         {
492             // need storage for cjk (see find_split_cat_gini) and risks/errors
493             rows += MAX( max_c_count, params.cv_folds ) + 1;
494             // add buffer for k-means clustering
495             if( m > 2 && max_c_count > params.max_categories )
496                 rows += params.max_categories + (max_c_count+m-1)/m;
497         }
498
499         CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
500     }
501
502     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
503     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
504
505     __END__;
506
507     cvFree( &int_ptr );
508     cvReleaseMat( &sample_idx );
509     cvReleaseMat( &var_type0 );
510     cvReleaseMat( &tmp_map );
511 }
512
513
514 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
515 {
516     CvDTreeNode* root = 0;
517     CvMat* isubsample_idx = 0;
518     CvMat* subsample_co = 0;
519     
520     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
521
522     __BEGIN__;
523
524     if( !data_root )
525         CV_ERROR( CV_StsError, "No training data has been set" );
526     
527     if( _subsample_idx )
528         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
529
530     if( !isubsample_idx )
531     {
532         // make a copy of the root node
533         CvDTreeNode temp;
534         int i;
535         root = new_node( 0, 1, 0, 0 );
536         temp = *root;
537         *root = *data_root;
538         root->num_valid = temp.num_valid;
539         for( i = 0; i < var_count; i++ )
540             root->num_valid[i] = data_root->num_valid[i];
541         root->cv_Tn = temp.cv_Tn;
542         root->cv_node_risk = temp.cv_node_risk;
543         root->cv_node_error = temp.cv_node_error;
544     }
545     else
546     {
547         int* sidx = isubsample_idx->data.i;
548         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
549         int* co, cur_ofs = 0;
550         int vi, i, total = data_root->sample_count;
551         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
552         root = new_node( 0, count, 1, 0 );
553
554         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
555         cvZero( subsample_co );
556         co = subsample_co->data.i;
557         for( i = 0; i < count; i++ )
558             co[sidx[i]*2]++;
559         for( i = 0; i < total; i++ )
560         {
561             if( co[i*2] )
562             {
563                 co[i*2+1] = cur_ofs;
564                 cur_ofs += co[i*2];
565             }
566             else
567                 co[i*2+1] = -1;
568         }
569
570         for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
571         {
572             int ci = get_var_type(vi);
573
574             if( ci >= 0 || vi >= var_count )
575             {
576                 const int* src = get_cat_var_data( data_root, vi );
577                 int* dst = get_cat_var_data( root, vi );
578                 int num_valid = 0;
579
580                 for( i = 0; i < count; i++ )
581                 {
582                     int val = src[sidx[i]];
583                     dst[i] = val;
584                     num_valid += val >= 0;
585                 }
586
587                 if( vi < var_count )
588                     root->num_valid[vi] = num_valid;
589             }
590             else
591             {
592                 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
593                 CvPair32s32f* dst = get_ord_var_data( root, vi );
594                 int j = 0, num_valid = data_root->num_valid[vi], idx, count_i;
595
596                 for( i = 0; i < num_valid; i++ )
597                 {
598                     idx = src[i].i;
599                     count_i = co[idx*2];
600                     if( count_i )
601                     {
602                         float val = src[i].val;
603                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
604                         {
605                             dst[j].val = val;
606                             dst[j].i = cur_ofs;
607                         }
608                     }
609                 }
610
611                 root->num_valid[vi] = j;
612
613                 for( ; i < total; i++ )
614                 {
615                     idx = src[i].i;
616                     count_i = co[idx*2];
617                     if( count_i )
618                     {
619                         float val = src[i].val;
620                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
621                         {
622                             dst[j].val = val;
623                             dst[j].i = cur_ofs;
624                         }
625                     }
626                 }
627             }
628         }
629     }
630
631     __END__;
632
633     cvReleaseMat( &isubsample_idx );
634     cvReleaseMat( &subsample_co );
635
636     return root;
637 }
638
639
640 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
641                                     float* values, uchar* missing,
642                                     float* responses, bool get_class_idx )
643 {
644     CvMat* subsample_idx = 0;
645     CvMat* subsample_co = 0;
646     
647     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
648
649     __BEGIN__;
650
651     int i, vi, total = sample_count, count = total, cur_ofs = 0;
652     int* sidx = 0;
653     int* co = 0;
654
655     if( _subsample_idx )
656     {
657         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
658         sidx = subsample_idx->data.i;
659         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
660         co = subsample_co->data.i;
661         cvZero( subsample_co );
662         count = subsample_idx->cols + subsample_idx->rows - 1;
663         for( i = 0; i < count; i++ )
664             co[sidx[i]*2]++;
665         for( i = 0; i < total; i++ )
666         {
667             int count_i = co[i*2];
668             if( count_i )
669             {
670                 co[i*2+1] = cur_ofs*var_count;
671                 cur_ofs += count_i;
672             }
673         }
674     }
675
676     memset( missing, 1, count*var_count );
677
678     for( vi = 0; vi < var_count; vi++ )
679     {
680         int ci = get_var_type(vi);
681         if( ci >= 0 ) // categorical
682         {
683             float* dst = values + vi;
684             uchar* m = missing + vi;
685             const int* src = get_cat_var_data(data_root, vi);
686
687             for( i = 0; i < count; i++, dst += var_count, m += var_count )
688             {
689                 int idx = sidx ? sidx[i] : i;
690                 int val = src[idx];
691                 *dst = (float)val;
692                 *m = val < 0;
693             }
694         }
695         else // ordered
696         {
697             float* dst = values + vi;
698             uchar* m = missing + vi;
699             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
700             int count1 = data_root->num_valid[vi];
701
702             for( i = 0; i < count1; i++ )
703             {
704                 int idx = src[i].i;
705                 int count_i = 1;
706                 if( co )
707                 {
708                     count_i = co[idx*2];
709                     cur_ofs = co[idx*2+1];
710                 }
711                 else
712                     cur_ofs = idx*var_count;
713                 if( count_i )
714                 {
715                     float val = src[i].val;
716                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
717                     {
718                         dst[cur_ofs] = val;
719                         m[cur_ofs] = 0;
720                     }
721                 }
722             }
723         }
724     }
725
726     // copy responses
727     if( is_classifier )
728     {
729         const int* src = get_class_labels(data_root);
730         for( i = 0; i < count; i++ )
731         {
732             int idx = sidx ? sidx[i] : i;
733             int val = get_class_idx ? src[idx] :
734                 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
735             responses[i] = (float)val;
736         }
737     }
738     else
739     {
740         const float* src = get_ord_responses(data_root);
741         for( i = 0; i < count; i++ )
742         {
743             int idx = sidx ? sidx[i] : i;
744             responses[i] = src[idx];
745         }
746     }
747
748     __END__;
749
750     cvReleaseMat( &subsample_idx );
751     cvReleaseMat( &subsample_co );
752 }
753
754
755 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
756                                          int storage_idx, int offset )
757 {
758     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
759
760     node->sample_count = count;
761     node->depth = parent ? parent->depth + 1 : 0;
762     node->parent = parent;
763     node->left = node->right = 0;
764     node->split = 0;
765     node->value = 0;
766     node->class_idx = -1;
767     node->maxlr = 0.;
768
769     node->buf_idx = storage_idx;
770     node->offset = offset;
771     if( nv_heap )
772         node->num_valid = (int*)cvSetNew( nv_heap );
773     else
774         node->num_valid = 0;
775     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
776     node->Tn = INT_MAX;
777     node->complexity = 0;
778
779     if( params.cv_folds > 0 && cv_heap )
780     {
781         int cv_n = params.cv_folds;
782         node->cv_Tn = (int*)cvSetNew( cv_heap );
783         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
784         node->cv_node_error = node->cv_node_risk + cv_n;
785     }
786     else
787     {
788         node->cv_Tn = 0;
789         node->cv_node_risk = 0;
790         node->cv_node_error = 0;
791     }
792
793     return node;
794 }
795
796
797 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
798                 int split_point, int inversed, float quality )
799 {
800     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
801     split->var_idx = vi;
802     split->ord.c = cmp_val;
803     split->ord.split_point = split_point;
804     split->inversed = inversed;
805     split->quality = quality;
806     split->next = 0;
807
808     return split;
809 }
810
811
812 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
813 {
814     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
815     int i, n = (max_c_count + 31)/32;
816
817     split->var_idx = vi;
818     split->inversed = 0;
819     split->quality = quality;
820     for( i = 0; i < n; i++ )
821         split->subset[i] = 0;
822     split->next = 0;
823
824     return split;
825 }
826
827
828 void CvDTreeTrainData::free_node( CvDTreeNode* node )
829 {
830     CvDTreeSplit* split = node->split;
831     free_node_data( node );
832     while( split )
833     {
834         CvDTreeSplit* next = split->next;
835         cvSetRemoveByPtr( split_heap, split );
836         split = next;
837     }
838     node->split = 0;
839     cvSetRemoveByPtr( node_heap, node );
840 }
841
842
843 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
844 {
845     if( node->num_valid )
846     {
847         cvSetRemoveByPtr( nv_heap, node->num_valid );
848         node->num_valid = 0;
849     }
850     // do not free cv_* fields, as all the cross-validation related data is released at once.
851 }
852
853
854 void CvDTreeTrainData::free_train_data()
855 {
856     cvReleaseMat( &counts );
857     cvReleaseMat( &buf );
858     cvReleaseMat( &direction );
859     cvReleaseMat( &split_buf );
860     cvReleaseMemStorage( &temp_storage );
861     cv_heap = nv_heap = 0;
862 }
863
864
865 void CvDTreeTrainData::clear()
866 {
867     free_train_data();
868
869     cvReleaseMemStorage( &tree_storage );
870
871     cvReleaseMat( &var_idx );
872     cvReleaseMat( &var_type );
873     cvReleaseMat( &cat_count );
874     cvReleaseMat( &cat_ofs );
875     cvReleaseMat( &cat_map );
876     cvReleaseMat( &priors );
877
878     node_heap = split_heap = 0;
879
880     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
881     have_cv_labels = have_priors = is_classifier = false;
882
883     buf_count = buf_size = 0;
884     shared = false;
885
886     data_root = 0;
887
888     rng = cvRNG(-1);
889 }
890
891
892 int CvDTreeTrainData::get_num_classes() const
893 {
894     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
895 }
896
897
898 int CvDTreeTrainData::get_var_type(int vi) const
899 {
900     return var_type->data.i[vi];
901 }
902
903
904 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
905 {
906     int oi = ~get_var_type(vi);
907     assert( 0 <= oi && oi < ord_var_count );
908     return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
909                            n->offset + oi*n->sample_count*2);
910 }
911
912
913 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
914 {
915     return get_cat_var_data( n, var_count );
916 }
917
918
919 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
920 {
921     return (float*)get_cat_var_data( n, var_count );
922 }
923
924
925 int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
926 {
927     return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
928 }
929
930
931 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
932 {
933     int ci = get_var_type(vi);
934     assert( 0 <= ci && ci <= cat_var_count + 1 );
935     return buf->data.i + n->buf_idx*buf->cols + n->offset +
936            (ord_var_count*2 + ci)*n->sample_count;
937 }
938
939
940 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
941 {
942     int idx = n->buf_idx + 1;
943     if( idx >= buf_count )
944         idx = shared ? 1 : 0;
945     return idx;
946 }
947
948 /////////////////////// Decision Tree /////////////////////////
949
950 CvDTree::CvDTree()
951 {
952     data = 0;
953     var_importance = 0;
954     clear();
955 }
956
957
958 void CvDTree::clear()
959 {
960     cvReleaseMat( &var_importance );
961     if( data )
962     {
963         if( !data->shared )
964             delete data;
965         else
966             free_tree();
967         data = 0;
968     }
969     root = 0;
970     pruned_tree_idx = -1;
971 }
972
973
974 CvDTree::~CvDTree()
975 {
976     clear();
977 }
978
979
980 bool CvDTree::train( const CvMat* _train_data, int _tflag,
981                      const CvMat* _responses, const CvMat* _var_idx,
982                      const CvMat* _sample_idx, const CvMat* _var_type,
983                      const CvMat* _missing_mask, CvDTreeParams _params )
984 {
985     bool result = false;
986
987     CV_FUNCNAME( "CvDTree::train" );
988
989     __BEGIN__;
990
991     clear();
992     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
993                                  _var_idx, _sample_idx, _var_type,
994                                  _missing_mask, _params, false );
995     CV_CALL( result = do_train(0));
996
997     __END__;
998
999     return result;
1000 }
1001
1002
1003 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1004 {
1005     bool result = false;
1006
1007     CV_FUNCNAME( "CvDTree::train" );
1008
1009     __BEGIN__;
1010
1011     clear();
1012     data = _data;
1013     data->shared = true;
1014     CV_CALL( result = do_train(_subsample_idx));
1015
1016     __END__;
1017
1018     return result;
1019 }
1020
1021
1022 bool CvDTree::do_train( const CvMat* _subsample_idx )
1023 {
1024     bool result = false;
1025
1026     CV_FUNCNAME( "CvDTree::train" );
1027
1028     __BEGIN__;
1029
1030     root = data->subsample_data( _subsample_idx );
1031
1032     try_split_node(root);
1033     
1034     if( data->params.cv_folds > 0 )
1035         prune_cv();
1036
1037     if( !data->shared )
1038         data->free_train_data();
1039
1040     result = true;
1041
1042     __END__;
1043
1044     return result;
1045 }
1046
1047
1048 #define DTREE_CAT_DIR(idx,subset) \
1049     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
1050
1051 void CvDTree::try_split_node( CvDTreeNode* node )
1052 {
1053     CvDTreeSplit* best_split = 0;
1054     int i, n = node->sample_count, vi;
1055     bool can_split = true;
1056     double quality_scale;
1057
1058     calc_node_value( node );
1059
1060     if( node->sample_count <= data->params.min_sample_count ||
1061         node->depth >= data->params.max_depth )
1062         can_split = false;
1063
1064     if( can_split && data->is_classifier )
1065     {
1066         // check if we have a "pure" node,
1067         // we assume that cls_count is filled by calc_node_value()
1068         int* cls_count = data->counts->data.i;
1069         int nz = 0, m = data->get_num_classes();
1070         for( i = 0; i < m; i++ )
1071             nz += cls_count[i] != 0;
1072         if( nz == 1 ) // there is only one class
1073             can_split = false;
1074     }
1075     else if( can_split )
1076     {
1077         const float* responses = data->get_ord_responses( node );
1078         float diff = responses[n-1] - responses[0];
1079         if( diff < data->params.regression_accuracy )
1080             can_split = false;
1081     }
1082
1083     if( can_split )
1084     {
1085         best_split = find_best_split(node);
1086         // TODO: check the split quality ...
1087         node->split = best_split;
1088     }
1089
1090     if( !can_split || !best_split )
1091     {
1092         data->free_node_data(node);
1093         return;
1094     }
1095
1096     quality_scale = calc_node_dir( node );
1097
1098     if( data->params.use_surrogates )
1099     {
1100         // find all the surrogate splits
1101         // and sort them by their similarity to the primary one
1102         for( vi = 0; vi < data->var_count; vi++ )
1103         {
1104             CvDTreeSplit* split;
1105             int ci = data->get_var_type(vi);
1106
1107             if( vi == best_split->var_idx )
1108                 continue;
1109
1110             if( ci >= 0 )
1111                 split = find_surrogate_split_cat( node, vi );
1112             else
1113                 split = find_surrogate_split_ord( node, vi );
1114
1115             if( split )
1116             {
1117                 // insert the split
1118                 CvDTreeSplit* prev_split = node->split;
1119                 split->quality = (float)(split->quality*quality_scale);
1120
1121                 while( prev_split->next &&
1122                        prev_split->next->quality > split->quality )
1123                     prev_split = prev_split->next;
1124                 split->next = prev_split->next;
1125                 prev_split->next = split;
1126             }
1127         }
1128     }
1129
1130     split_node_data( node );
1131     try_split_node( node->left );
1132     try_split_node( node->right );
1133 }
1134
1135
1136 // calculate direction (left(-1),right(1),missing(0))
1137 // for each sample using the best split
1138 // the function returns scale coefficients for surrogate split quality factors.
1139 // the scale is applied to normalize surrogate split quality relatively to the
1140 // best (primary) split quality. That is, if a surrogate split is absolutely
1141 // identical to the primary split, its quality will be set to the maximum value = 
1142 // quality of the primary split; otherwise, it will be lower.
1143 // besides, the function compute node->maxlr,
1144 // minimum possible quality (w/o considering the above mentioned scale)
1145 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1146 // are not discarded.
1147 double CvDTree::calc_node_dir( CvDTreeNode* node )
1148 {
1149     char* dir = (char*)data->direction->data.ptr;
1150     int i, n = node->sample_count, vi = node->split->var_idx;
1151     double L, R;
1152
1153     assert( !node->split->inversed );
1154
1155     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1156     {
1157         const int* labels = data->get_cat_var_data(node,vi);
1158         const int* subset = node->split->subset;
1159
1160         if( !data->have_priors )
1161         {
1162             int sum = 0, sum_abs = 0;
1163
1164             for( i = 0; i < n; i++ )
1165             {
1166                 int idx = labels[i];
1167                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1168                 sum += d; sum_abs += d & 1;
1169                 dir[i] = (char)d;
1170             }
1171
1172             R = (sum_abs + sum) >> 1;
1173             L = (sum_abs - sum) >> 1;
1174         }
1175         else
1176         {
1177             const int* responses = data->get_class_labels(node);
1178             const double* priors = data->priors->data.db;
1179             double sum = 0, sum_abs = 0;
1180
1181             for( i = 0; i < n; i++ )
1182             {
1183                 int idx = labels[i];
1184                 double w = priors[responses[i]];
1185                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1186                 sum += d*w; sum_abs += (d & 1)*w;
1187                 dir[i] = (char)d;
1188             }
1189
1190             R = (sum_abs + sum) * 0.5;
1191             L = (sum_abs - sum) * 0.5;
1192         }
1193     }
1194     else // split on ordered var
1195     {
1196         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1197         int split_point = node->split->ord.split_point;
1198         int n1 = node->num_valid[vi];
1199
1200         assert( 0 <= split_point && split_point < n1-1 );
1201
1202         if( !data->have_priors )
1203         {
1204             for( i = 0; i <= split_point; i++ )
1205                 dir[sorted[i].i] = (char)-1;
1206             for( ; i < n1; i++ )
1207                 dir[sorted[i].i] = (char)1;
1208             for( ; i < n; i++ )
1209                 dir[sorted[i].i] = (char)0;
1210
1211             L = split_point-1;
1212             R = n1 - split_point + 1;
1213         }
1214         else
1215         {
1216             const int* responses = data->get_class_labels(node);
1217             const double* priors = data->priors->data.db;
1218             L = R = 0;
1219
1220             for( i = 0; i <= split_point; i++ )
1221             {
1222                 int idx = sorted[i].i;
1223                 double w = priors[responses[idx]];
1224                 dir[idx] = (char)-1;
1225                 L += w;
1226             }
1227
1228             for( ; i < n1; i++ )
1229             {
1230                 int idx = sorted[i].i;
1231                 double w = priors[responses[idx]];
1232                 dir[idx] = (char)1;
1233                 R += w;
1234             }
1235
1236             for( ; i < n; i++ )
1237                 dir[sorted[i].i] = (char)0;
1238         }
1239     }
1240
1241     node->maxlr = MAX( L, R );
1242     return node->split->quality/(L + R);
1243 }
1244
1245
1246 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1247 {
1248     int vi;
1249     CvDTreeSplit *best_split = 0, *split = 0, *t;
1250
1251     for( vi = 0; vi < data->var_count; vi++ )
1252     {
1253         int ci = data->get_var_type(vi);
1254         if( node->num_valid[vi] <= 1 )
1255             continue;
1256
1257         if( data->is_classifier )
1258         {
1259             if( ci >= 0 )
1260                 split = find_split_cat_gini( node, vi );
1261             else
1262                 split = find_split_ord_gini( node, vi );
1263         }
1264         else
1265         {
1266             if( ci >= 0 )
1267                 split = find_split_cat_reg( node, vi );
1268             else
1269                 split = find_split_ord_reg( node, vi );
1270         }
1271
1272         if( split )
1273         {
1274             if( !best_split || best_split->quality < split->quality )
1275                 CV_SWAP( best_split, split, t );
1276             if( split )
1277                 cvSetRemoveByPtr( data->split_heap, split );
1278         }
1279     }
1280
1281     return best_split;
1282 }
1283
1284
1285 CvDTreeSplit* CvDTree::find_split_ord_gini( CvDTreeNode* node, int vi )
1286 {
1287     const float epsilon = FLT_EPSILON*2;
1288     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1289     const int* responses = data->get_class_labels(node);
1290     int n = node->sample_count;
1291     int n1 = node->num_valid[vi];
1292     int m = data->get_num_classes();
1293     const int* rc0 = data->counts->data.i;
1294     int* lc = (int*)(rc0 + m);
1295     int* rc = lc + m;
1296     int i, best_i = -1;
1297     double lsum2 = 0, rsum2 = 0, best_val = 0;
1298     const double* priors = data->have_priors ? data->priors->data.db : 0;
1299
1300     // init arrays of class instance counters on both sides of the split
1301     for( i = 0; i < m; i++ )
1302     {
1303         lc[i] = 0;
1304         rc[i] = rc0[i];
1305     }
1306
1307     // compensate for missing values
1308     for( i = n1; i < n; i++ )
1309         rc[responses[sorted[i].i]]--;
1310
1311     if( !priors )
1312     {
1313         int L = 0, R = n1;
1314
1315         for( i = 0; i < m; i++ )
1316             rsum2 += (double)rc[i]*rc[i];
1317
1318         for( i = 0; i < n1 - 1; i++ )
1319         {
1320             int idx = responses[sorted[i].i];
1321             int lv, rv;
1322             L++; R--;
1323             lv = lc[idx]; rv = rc[idx];
1324             lsum2 += lv*2 + 1;
1325             rsum2 -= rv*2 - 1;
1326             lc[idx] = lv + 1; rc[idx] = rv - 1;
1327
1328             if( sorted[i].val + epsilon < sorted[i+1].val )
1329             {
1330                 double val = lsum2/L + rsum2/R;
1331                 if( best_val < val )
1332                 {
1333                     best_val = val;
1334                     best_i = i;
1335                 }
1336             }
1337         }
1338     }
1339     else
1340     {
1341         double L = 0, R = 0;
1342         for( i = 0; i < m; i++ )
1343         {
1344             double wv = rc[i]*priors[i];
1345             R += wv;
1346             rsum2 += wv*wv;
1347         }
1348
1349         for( i = 0; i < n1 - 1; i++ )
1350         {
1351             int idx = responses[sorted[i].i];
1352             int lv, rv;
1353             double p = priors[idx], p2 = p*p;
1354             L += p; R -= p;
1355             lv = lc[idx]; rv = rc[idx];
1356             lsum2 += p2*(lv*2 + 1);
1357             rsum2 -= p2*(rv*2 - 1);
1358             lc[idx] = lv + 1; rc[idx] = rv - 1;
1359
1360             if( sorted[i].val + epsilon < sorted[i+1].val )
1361             {
1362                 double val = lsum2/L + rsum2/R;
1363                 if( best_val < val )
1364                 {
1365                     best_val = val;
1366                     best_i = i;
1367                 }
1368             }
1369         }
1370     }
1371
1372     return best_i >= 0 ? data->new_split_ord( vi,
1373         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1374         0, (float)best_val ) : 0;
1375 }
1376
1377
1378 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1379                                 int* csums, int k, int* labels )
1380 {
1381     // TODO: consider adding priors (class weights) to the algorithm
1382     int iters = 0, max_iters = 100;
1383     int i, j, idx;
1384     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1385     double *v_weights = buf, *c_weights = buf + k;
1386     bool modified = true;
1387     CvRNG* r = &data->rng;
1388
1389     // assign labels randomly
1390     for( i = idx = 0; i < n; i++ )
1391     {
1392         int sum = 0;
1393         const int* v = vectors + i*m;
1394         labels[i] = idx++;
1395         idx &= idx < k ? -1 : 0;
1396
1397         // compute weight of each vector
1398         for( j = 0; j < m; j++ )
1399             sum += v[j];
1400         v_weights[i] = sum ? 1./sum : 0.;
1401     }
1402
1403     for( i = 0; i < n; i++ )
1404     {
1405         int i1 = cvRandInt(r) % n;
1406         int i2 = cvRandInt(r) % n;
1407         CV_SWAP( labels[i1], labels[i2], j );
1408     }
1409
1410     for( iters = 0; iters <= max_iters; iters++ )
1411     {
1412         // calculate csums
1413         for( i = 0; i < k; i++ )
1414         {
1415             for( j = 0; j < m; j++ )
1416                 csums[i*m + j] = 0;
1417         }
1418
1419         for( i = 0; i < n; i++ )
1420         {
1421             const int* v = vectors + i*m;
1422             int* s = csums + labels[i]*m;
1423             for( j = 0; j < m; j++ )
1424                 s[j] += v[j];
1425         }
1426
1427         // exit the loop here, when we have up-to-date csums
1428         if( iters == max_iters || !modified )
1429             break;
1430
1431         modified = false;
1432
1433         // calculate weight of each cluster
1434         for( i = 0; i < k; i++ )
1435         {
1436             const int* s = csums + i*m;
1437             int sum = 0;
1438             for( j = 0; j < m; j++ )
1439                 sum += s[j];
1440             c_weights[i] = sum ? 1./sum : 0;
1441         }
1442
1443         // now for each vector determine the closest cluster
1444         for( i = 0; i < n; i++ )
1445         {
1446             const int* v = vectors + i*m;
1447             double alpha = v_weights[i];
1448             double min_dist2 = DBL_MAX;
1449             int min_idx = -1;
1450
1451             for( idx = 0; idx < k; idx++ )
1452             {
1453                 const int* s = csums + idx*m;
1454                 double dist2 = 0., beta = c_weights[idx];
1455                 for( j = 0; j < m; j++ )
1456                 {
1457                     double t = v[j]*alpha - s[j]*beta;
1458                     dist2 += t*t;
1459                 }
1460                 if( min_dist2 > dist2 )
1461                 {
1462                     min_dist2 = dist2;
1463                     min_idx = idx;
1464                 }
1465             }
1466
1467             if( min_idx != labels[i] )
1468                 modified = true;
1469             labels[i] = min_idx;
1470         }
1471     }
1472 }
1473
1474
1475 CvDTreeSplit* CvDTree::find_split_cat_gini( CvDTreeNode* node, int vi )
1476 {
1477     CvDTreeSplit* split;
1478     const int* labels = data->get_cat_var_data(node, vi);
1479     const int* responses = data->get_class_labels(node);
1480     int ci = data->get_var_type(vi);
1481     int n = node->sample_count;
1482     int m = data->get_num_classes();
1483     int _mi = data->cat_count->data.i[ci], mi = _mi;
1484     const int* rc0 = data->counts->data.i;
1485     int* lc = (int*)(rc0 + m);
1486     int* rc = lc + m;
1487     int* _cjk = rc + m*2, *cjk = _cjk;
1488     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1489     int* cluster_labels = 0;
1490     int** int_ptr = 0;
1491     int i, j, k, idx;
1492     double L = 0, R = 0;
1493     double best_val = 0;
1494     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1495     const double* priors = data->priors->data.db;
1496
1497     // init array of counters:
1498     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1499     for( j = -1; j < mi; j++ )
1500         for( k = 0; k < m; k++ )
1501             cjk[j*m + k] = 0;
1502
1503     for( i = 0; i < n; i++ )
1504     {
1505         int j = labels[i];
1506         int k = responses[i];
1507         cjk[j*m + k]++;
1508     }
1509
1510     if( m > 2 )
1511     {
1512         if( mi > data->params.max_categories )
1513         {
1514             mi = MIN(data->params.max_categories, n);
1515             cjk += _mi*m;
1516             cluster_labels = cjk + mi*m;
1517             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1518         }
1519         subset_i = 1;
1520         subset_n = 1 << mi;
1521     }
1522     else
1523     {
1524         assert( m == 2 );
1525         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1526         for( j = 0; j < mi; j++ )
1527             int_ptr[j] = cjk + j*2 + 1;
1528         icvSortIntPtr( int_ptr, mi, 0 );
1529         subset_i = 0;
1530         subset_n = mi;
1531     }
1532
1533     for( k = 0; k < m; k++ )
1534     {
1535         int sum = 0;
1536         for( j = 0; j < mi; j++ )
1537             sum += cjk[j*m + k];
1538         rc[k] = sum;
1539         lc[k] = 0;
1540     }
1541
1542     for( j = 0; j < mi; j++ )
1543     {
1544         double sum = 0;
1545         for( k = 0; k < m; k++ )
1546             sum += cjk[j*m + k]*priors[k];
1547         c_weights[j] = sum;
1548         R += c_weights[j];
1549     }
1550
1551     for( ; subset_i < subset_n; subset_i++ )
1552     {
1553         double weight;
1554         int* crow;
1555         double lsum2 = 0, rsum2 = 0;
1556
1557         if( m == 2 )
1558             idx = (int_ptr[subset_i] - cjk)/2;
1559         else
1560         {
1561             int graycode = (subset_i>>1)^subset_i;
1562             int diff = graycode ^ prevcode;
1563
1564             // determine index of the changed bit.
1565             Cv32suf u;
1566             idx = diff >= (1 << 16) ? 16 : 0;
1567             u.f = (float)(((diff >> 16) | diff) & 65535);
1568             idx += (u.i >> 23) - 127;
1569             subtract = graycode < prevcode;
1570             prevcode = graycode;
1571         }
1572
1573         crow = cjk + idx*m;
1574         weight = c_weights[idx];
1575         if( weight < FLT_EPSILON )
1576             continue;
1577
1578         if( !subtract )
1579         {
1580             for( k = 0; k < m; k++ )
1581             {
1582                 int t = crow[k];
1583                 int lval = lc[k] + t;
1584                 int rval = rc[k] - t;
1585                 double p = priors[k], p2 = p*p;
1586                 lsum2 += p2*lval*lval;
1587                 rsum2 += p2*rval*rval;
1588                 lc[k] = lval; rc[k] = rval;
1589             }
1590             L += weight;
1591             R -= weight;
1592         }
1593         else
1594         {
1595             for( k = 0; k < m; k++ )
1596             {
1597                 int t = crow[k];
1598                 int lval = lc[k] - t;
1599                 int rval = rc[k] + t;
1600                 double p = priors[k], p2 = p*p;
1601                 lsum2 += p2*lval*lval;
1602                 rsum2 += p2*rval*rval;
1603                 lc[k] = lval; rc[k] = rval;
1604             }
1605             L -= weight;
1606             R += weight;
1607         }
1608
1609         if( L > FLT_EPSILON && R > FLT_EPSILON )
1610         {
1611             double val = lsum2/L + rsum2/R;
1612             if( best_val < val )
1613             {
1614                 best_val = val;
1615                 best_subset = subset_i;
1616             }
1617         }
1618     }
1619
1620     if( best_subset < 0 )
1621         return 0;
1622
1623     split = data->new_split_cat( vi, (float)best_val );
1624
1625     if( m == 2 )
1626     {
1627         for( i = 0; i <= best_subset; i++ )
1628         {
1629             idx = (int_ptr[i] - cjk) >> 1;
1630             split->subset[idx >> 5] |= 1 << (idx & 31);
1631         }
1632     }
1633     else
1634     {
1635         for( i = 0; i < _mi; i++ )
1636         {
1637             idx = cluster_labels ? cluster_labels[i] : i;
1638             if( best_subset & (1 << idx) )
1639                 split->subset[i >> 5] |= 1 << (i & 31);
1640         }
1641     }
1642
1643     return split;
1644 }
1645
1646
1647 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1648 {
1649     const float epsilon = FLT_EPSILON*2;
1650     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1651     const float* responses = data->get_ord_responses(node);
1652     int n = node->sample_count;
1653     int n1 = node->num_valid[vi];
1654     int i, best_i = -1, L = 0, R = n1;
1655     double lsum = 0, rsum, best_val = 0;
1656
1657     rsum = node->value*n;
1658     // compensate for missing values
1659     for( i = n1; i < n; i++ )
1660         rsum -= responses[sorted[i].i];
1661
1662     // find the optimal split
1663     for( i = 0; i < n1 - 1; i++ )
1664     {
1665         float val = responses[sorted[i].i];
1666         L++; R--;
1667         lsum += val;
1668         rsum -= val;
1669
1670         if( sorted[i].val + epsilon < sorted[i+1].val )
1671         {
1672             double val = lsum*lsum/L + rsum*rsum/R;
1673             if( best_val < val )
1674             {
1675                 best_val = val;
1676                 best_i = i;
1677             }
1678         }
1679     }
1680
1681     return best_i >= 0 ? data->new_split_ord( vi,
1682         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1683         0, (float)best_val ) : 0;
1684 }
1685
1686
1687 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1688 {
1689     CvDTreeSplit* split;
1690     const int* labels = data->get_cat_var_data(node, vi);
1691     const float* responses = data->get_ord_responses(node);
1692     int ci = data->get_var_type(vi);
1693     int n = node->sample_count;
1694     int mi = data->cat_count->data.i[ci];
1695     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1696     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1697     double** sum_ptr = 0;
1698     int i, L = 0, R = 0;
1699     double best_val = 0, lsum = 0, rsum = 0;
1700     int best_subset = -1, subset_i;
1701
1702     for( i = -1; i < mi; i++ )
1703         sum[i] = counts[i] = 0;
1704
1705     // calculate sum response and weight of each category of the input var
1706     for( i = 0; i < n; i++ )
1707     {
1708         int idx = labels[i];
1709         double s = sum[idx] + responses[i];
1710         int nc = counts[idx] + 1;
1711         sum[idx] = s;
1712         counts[idx] = nc;
1713     }
1714
1715     // calculate average response in each category
1716     for( i = 0; i < mi; i++ )
1717     {
1718         R += counts[i];
1719         rsum += sum[i];
1720         sum[i] /= MAX(counts[i],1);
1721         sum_ptr[i] = sum + i;
1722     }
1723
1724     icvSortDblPtr( sum_ptr, mi, 0 );
1725
1726     // revert back to unnormalized sum
1727     // (there should be a very little loss of accuracy)
1728     for( i = 0; i < mi; i++ )
1729         sum[i] *= counts[i];
1730
1731     for( subset_i = 0; subset_i < mi-1; subset_i++ )
1732     {
1733         int idx = sum_ptr[subset_i] - sum;
1734         int ni = counts[idx];
1735
1736         if( ni )
1737         {
1738             double s = sum[idx];
1739             lsum += s; L += ni;
1740             rsum -= s; R -= ni;
1741             
1742             if( L && R )
1743             {
1744                 double val = lsum*lsum/L + rsum*rsum/R;
1745                 if( best_val < val )
1746                 {
1747                     best_val = val;
1748                     best_subset = subset_i;
1749                 }
1750             }
1751         }
1752     }
1753
1754     if( best_subset < 0 )
1755         return 0;
1756
1757     split = data->new_split_cat( vi, (float)best_val );
1758     for( i = 0; i <= best_subset; i++ )
1759     {
1760         int idx = sum_ptr[i] - sum;
1761         split->subset[idx >> 5] |= 1 << (idx & 31);
1762     }
1763
1764     return split;
1765 }
1766
1767
1768 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
1769 {
1770     const float epsilon = FLT_EPSILON*2;
1771     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1772     const char* dir = (char*)data->direction->data.ptr;
1773     int n1 = node->num_valid[vi];
1774     // LL - number of samples that both the primary and the surrogate splits send to the left
1775     // LR - ... primary split sends to the left and the surrogate split sends to the right
1776     // RL - ... primary split sends to the right and the surrogate split sends to the left
1777     // RR - ... both send to the right
1778     int i, best_i = -1, best_inversed = 0;
1779     double best_val; 
1780
1781     if( !data->have_priors )
1782     {
1783         int LL = 0, RL = 0, LR, RR;
1784         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
1785         int sum = 0, sum_abs = 0;
1786         
1787         for( i = 0; i < n1; i++ )
1788         {
1789             int d = dir[sorted[i].i];
1790             sum += d; sum_abs += d & 1;
1791         }
1792
1793         // sum_abs = R + L; sum = R - L
1794         RR = (sum_abs + sum) >> 1;
1795         LR = (sum_abs - sum) >> 1;
1796
1797         // initially all the samples are sent to the right by the surrogate split,
1798         // LR of them are sent to the left by primary split, and RR - to the right.
1799         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1800         for( i = 0; i < n1 - 1; i++ )
1801         {
1802             int d = dir[sorted[i].i];
1803
1804             if( d < 0 )
1805             {
1806                 LL++; LR--;
1807                 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1808                 {
1809                     best_val = LL + RR;
1810                     best_i = i; best_inversed = 0;
1811                 }
1812             }
1813             else if( d > 0 )
1814             {
1815                 RL++; RR--;
1816                 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1817                 {
1818                     best_val = RL + LR;
1819                     best_i = i; best_inversed = 1;
1820                 }
1821             }
1822         }
1823         best_val = _best_val;
1824     }
1825     else
1826     {
1827         double LL = 0, RL = 0, LR, RR;
1828         double worst_val = node->maxlr;
1829         double sum = 0, sum_abs = 0;
1830         const double* priors = data->priors->data.db;
1831         const int* responses = data->get_class_labels(node);
1832         best_val = worst_val;
1833         
1834         for( i = 0; i < n1; i++ )
1835         {
1836             int idx = sorted[i].i;
1837             double w = priors[responses[idx]];
1838             int d = dir[idx];
1839             sum += d*w; sum_abs += (d & 1)*w;
1840         }
1841
1842         // sum_abs = R + L; sum = R - L
1843         RR = (sum_abs + sum)*0.5;
1844         LR = (sum_abs - sum)*0.5;
1845
1846         // initially all the samples are sent to the right by the surrogate split,
1847         // LR of them are sent to the left by primary split, and RR - to the right.
1848         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1849         for( i = 0; i < n1 - 1; i++ )
1850         {
1851             int idx = sorted[i].i;
1852             double w = priors[responses[idx]];
1853             int d = dir[idx];
1854
1855             if( d < 0 )
1856             {
1857                 LL += w; LR -= w;
1858                 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1859                 {
1860                     best_val = LL + RR;
1861                     best_i = i; best_inversed = 0;
1862                 }
1863             }
1864             else if( d > 0 )
1865             {
1866                 RL += w; RR -= w;
1867                 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1868                 {
1869                     best_val = RL + LR;
1870                     best_i = i; best_inversed = 1;
1871                 }
1872             }
1873         }
1874     }
1875
1876     return best_i >= 0 ? data->new_split_ord( vi,
1877         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1878         best_inversed, (float)best_val ) : 0;
1879 }
1880
1881
1882 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
1883 {
1884     const int* labels = data->get_cat_var_data(node, vi);
1885     const char* dir = (char*)data->direction->data.ptr;
1886     int n = node->sample_count;
1887     // LL - number of samples that both the primary and the surrogate splits send to the left
1888     // LR - ... primary split sends to the left and the surrogate split sends to the right
1889     // RL - ... primary split sends to the right and the surrogate split sends to the left
1890     // RR - ... both send to the right
1891     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
1892     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
1893     double best_val = 0;
1894     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
1895     double* rc = lc + mi + 1;
1896     
1897     for( i = -1; i < mi; i++ )
1898         lc[i] = rc[i] = 0;
1899
1900     // for each category calculate the weight of samples
1901     // sent to the left (lc) and to the right (rc) by the primary split
1902     if( !data->have_priors )
1903     {
1904         int* _lc = data->counts->data.i + 1;
1905         int* _rc = _lc + mi + 1;
1906
1907         for( i = -1; i < mi; i++ )
1908             _lc[i] = _rc[i] = 0;
1909
1910         for( i = 0; i < n; i++ )
1911         {
1912             int idx = labels[i];
1913             int d = dir[i];
1914             int sum = _lc[idx] + d;
1915             int sum_abs = _rc[idx] + (d & 1);
1916             _lc[idx] = sum; _rc[idx] = sum_abs;
1917         }
1918
1919         for( i = 0; i < mi; i++ )
1920         {
1921             int sum = _lc[i];
1922             int sum_abs = _rc[i];
1923             lc[i] = (sum_abs - sum) >> 1;
1924             rc[i] = (sum_abs + sum) >> 1;
1925         }
1926     }
1927     else
1928     {
1929         const double* priors = data->priors->data.db;
1930         const int* responses = data->get_class_labels(node);
1931
1932         for( i = 0; i < n; i++ )
1933         {
1934             int idx = labels[i];
1935             double w = priors[responses[i]];
1936             int d = dir[i];
1937             double sum = lc[idx] + d*w;
1938             double sum_abs = rc[idx] + (d & 1)*w;
1939             lc[idx] = sum; rc[idx] = sum_abs;
1940         }
1941
1942         for( i = 0; i < mi; i++ )
1943         {
1944             double sum = lc[i];
1945             double sum_abs = rc[i];
1946             lc[i] = (sum_abs - sum) * 0.5;
1947             rc[i] = (sum_abs + sum) * 0.5;
1948         }
1949     }
1950
1951     // 2. now form the split.
1952     // in each category send all the samples to the same direction as majority
1953     for( i = 0; i < mi; i++ )
1954     {
1955         double lval = lc[i], rval = rc[i];
1956         if( lval > rval )
1957         {
1958             split->subset[i >> 5] |= 1 << (i & 31);
1959             best_val += lval;
1960         }
1961         else
1962             best_val += rval;
1963     }
1964
1965     split->quality = (float)best_val;
1966     if( split->quality <= node->maxlr )
1967         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
1968
1969     return split;
1970 }
1971
1972
1973 void CvDTree::calc_node_value( CvDTreeNode* node )
1974 {
1975     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
1976     const int* cv_labels = data->get_cv_labels(node);
1977
1978     if( data->is_classifier )
1979     {
1980         // in case of classification tree:
1981         //  * node value is the label of the class that has the largest weight in the node.
1982         //  * node risk is the weighted number of misclassified samples,
1983         //  * j-th cross-validation fold value and risk are calculated as above,
1984         //    but using the samples with cv_labels(*)!=j.
1985         //  * j-th cross-validation fold error is calculated as the weighted number of
1986         //    misclassified samples with cv_labels(*)==j.
1987
1988         // compute the number of instances of each class
1989         int* cls_count = data->counts->data.i;
1990         const int* responses = data->get_class_labels(node);
1991         int m = data->get_num_classes();
1992         int* cv_cls_count = cls_count + m;
1993         double max_val = -1, total_weight = 0;
1994         int max_k = -1;
1995         double* priors = data->priors->data.db;
1996
1997         for( k = 0; k < m; k++ )
1998             cls_count[k] = 0;
1999
2000         if( cv_n == 0 )
2001         {
2002             for( i = 0; i < n; i++ )
2003                 cls_count[responses[i]]++;
2004         }
2005         else
2006         {
2007             for( j = 0; j < cv_n; j++ )
2008                 for( k = 0; k < m; k++ )
2009                     cv_cls_count[j*m + k] = 0;
2010
2011             for( i = 0; i < n; i++ )
2012             {
2013                 j = cv_labels[i]; k = responses[i];
2014                 cv_cls_count[j*m + k]++;
2015             }
2016
2017             for( j = 0; j < cv_n; j++ )
2018                 for( k = 0; k < m; k++ )
2019                     cls_count[k] += cv_cls_count[j*m + k];
2020         }
2021
2022         for( k = 0; k < m; k++ )
2023         {
2024             double val = cls_count[k]*priors[k];
2025             total_weight += val;
2026             if( max_val < val )
2027             {
2028                 max_val = val;
2029                 max_k = k;
2030             }
2031         }
2032
2033         node->class_idx = max_k;
2034         node->value = data->cat_map->data.i[
2035             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2036         node->node_risk = total_weight - max_val;
2037
2038         for( j = 0; j < cv_n; j++ )
2039         {
2040             double sum_k = 0, sum = 0, max_val_k = 0;
2041             max_val = -1; max_k = -1;
2042
2043             for( k = 0; k < m; k++ )
2044             {
2045                 double w = priors[k];
2046                 double val_k = cv_cls_count[j*m + k]*w;
2047                 double val = cls_count[k]*w - val_k;
2048                 sum_k += val_k;
2049                 sum += val;
2050                 if( max_val < val )
2051                 {
2052                     max_val = val;
2053                     max_val_k = val_k;
2054                     max_k = k;
2055                 }
2056             }
2057
2058             node->cv_Tn[j] = INT_MAX;
2059             node->cv_node_risk[j] = sum - max_val;
2060             node->cv_node_error[j] = sum_k - max_val_k;
2061         }
2062     }
2063     else
2064     {
2065         // in case of regression tree:
2066         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2067         //    n is the number of samples in the node.
2068         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2069         //  * j-th cross-validation fold value and risk are calculated as above,
2070         //    but using the samples with cv_labels(*)!=j.
2071         //  * j-th cross-validation fold error is calculated
2072         //    using samples with cv_labels(*)==j as the test subset:
2073         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2074         //    where node_value_j is the node value calculated
2075         //    as described in the previous bullet, and summation is done
2076         //    over the samples with cv_labels(*)==j.
2077
2078         double sum = 0, sum2 = 0;
2079         const float* values = data->get_ord_responses(node);
2080         double *cv_sum = 0, *cv_sum2 = 0;
2081         int* cv_count = 0;
2082         
2083         if( cv_n == 0 )
2084         {
2085             // if cross-validation is not used, we even do not compute node_risk
2086             // (so the tree sequence T1>...>{root} may not be built).
2087             for( i = 0; i < n; i++ )
2088                 sum += values[i];
2089         }
2090         else
2091         {
2092             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2093             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2094             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2095
2096             for( j = 0; j < cv_n; j++ )
2097             {
2098                 cv_sum[j] = cv_sum2[j] = 0.;
2099                 cv_count[j] = 0;
2100             }
2101
2102             for( i = 0; i < n; i++ )
2103             {
2104                 j = cv_labels[i];
2105                 double t = values[i];
2106                 double s = cv_sum[j] + t;
2107                 double s2 = cv_sum2[j] + t*t;
2108                 int nc = cv_count[j] + 1;
2109                 cv_sum[j] = s;
2110                 cv_sum2[j] = s2;
2111                 cv_count[j] = nc;
2112             }
2113
2114             for( j = 0; j < cv_n; j++ )
2115             {
2116                 sum += cv_sum[j];
2117                 sum2 += cv_sum2[j];
2118             }
2119
2120             node->node_risk = sum2 - (sum/n)*sum;
2121         }
2122
2123         node->value = sum/n;
2124
2125         for( j = 0; j < cv_n; j++ )
2126         {
2127             double s = cv_sum[j], si = sum - s;
2128             double s2 = cv_sum2[j], s2i = sum2 - s2;
2129             int c = cv_count[j], ci = n - c;
2130             double r = si/MAX(ci,1);
2131             node->cv_node_risk[j] = s2i - r*r*ci;
2132             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2133             node->cv_Tn[j] = INT_MAX;
2134         }
2135     }
2136 }
2137
2138
2139 void CvDTree::split_node_data( CvDTreeNode* node )
2140 {
2141     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2142     int nz = n - node->num_valid[node->split->var_idx];
2143     char* dir = (char*)data->direction->data.ptr;
2144     CvDTreeNode *left = 0, *right = 0;
2145     int* new_idx = data->split_buf->data.i;
2146     int new_buf_idx = data->get_child_buf_idx( node );
2147
2148     // try to complete direction using surrogate splits
2149     if( nz && data->params.use_surrogates )
2150     {
2151         CvDTreeSplit* split = node->split->next;
2152         for( ; split != 0 && nz; split = split->next )
2153         {
2154             int inversed_mask = split->inversed ? -1 : 0;
2155             vi = split->var_idx;
2156
2157             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2158             {
2159                 const int* labels = data->get_cat_var_data(node, vi);
2160                 const int* subset = split->subset;
2161
2162                 for( i = 0; i < n; i++ )
2163                 {
2164                     int idx;
2165                     if( !dir[i] && (idx = labels[i]) >= 0 )
2166                     {
2167                         int d = DTREE_CAT_DIR(idx,subset);
2168                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2169                         if( --nz )
2170                             break;
2171                     }
2172                 }
2173             }
2174             else // split on ordered var
2175             {
2176                 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2177                 int split_point = split->ord.split_point;
2178                 int n1 = node->num_valid[vi];
2179
2180                 assert( 0 <= split_point && split_point < n-1 );
2181
2182                 for( i = 0; i < n1; i++ )
2183                 {
2184                     int idx = sorted[i].i;
2185                     if( !dir[idx] )
2186                     {
2187                         int d = i <= split_point ? -1 : 1;
2188                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2189                         if( --nz )
2190                             break;
2191                     }
2192                 }
2193             }
2194         }
2195     }
2196
2197     // find the default direction for the rest
2198     if( nz )
2199     {
2200         for( i = nr = 0; i < n; i++ )
2201             nr += dir[i] > 0;
2202         nl = n - nr - nz;
2203         d0 = nl > nr ? -1 : nr > nl;
2204     }
2205
2206     // make sure that every sample is directed either to the left or to the right
2207     for( i = nl = nr = 0; i < n; i++ )
2208     {
2209         int d = dir[i];
2210         if( !d )
2211         {
2212             d = d0;
2213             if( !d )
2214                 d = d1, d1 = -d1;
2215         }
2216         d = d > 0;
2217         dir[i] = (char)d; // remap (-1,1) to (0,1)
2218         
2219         // initialize new indices for splitting ordered variables
2220         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2221         nr += d;
2222         nl += d^1;
2223     }
2224
2225     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2226     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2227         (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
2228
2229     // split ordered variables, keep both halves sorted.
2230     for( vi = 0; vi < data->var_count; vi++ )
2231     {
2232         int ci = data->get_var_type(vi);
2233         int n1 = node->num_valid[vi];
2234         CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2235         CvPair32s32f tl, tr;
2236
2237         if( ci >= 0 )
2238             continue;
2239
2240         src = data->get_ord_var_data(node, vi);
2241         ldst0 = ldst = data->get_ord_var_data(left, vi);
2242         rdst0 = rdst = data->get_ord_var_data(right, vi);
2243         tl = ldst0[nl]; tr = rdst0[nr];
2244
2245         // split sorted
2246         for( i = 0; i < n1; i++ )
2247         {
2248             int idx = src[i].i;
2249             float val = src[i].val;
2250             int d = dir[idx];
2251             idx = new_idx[idx];
2252             ldst->i = rdst->i = idx;
2253             ldst->val = rdst->val = val;
2254             ldst += d^1;
2255             rdst += d;
2256         }
2257
2258         left->num_valid[vi] = ldst - ldst0;
2259         right->num_valid[vi] = rdst - rdst0;
2260
2261         // split missing
2262         for( ; i < n; i++ )
2263         {
2264             int idx = src[i].i;
2265             int d = dir[idx];
2266             idx = new_idx[idx];
2267             ldst->i = rdst->i = idx;
2268             ldst->val = rdst->val = ord_nan;
2269             ldst += d^1;
2270             rdst += d;
2271         }
2272
2273         ldst0[nl] = tl; rdst0[nr] = tr;
2274     }
2275
2276     // split categorical vars, responses and cv_labels using new_idx relocation table
2277     for( vi = 0; vi <= data->var_count + data->have_cv_labels; vi++ )
2278     {
2279         int ci = data->get_var_type(vi);
2280         int n1 = node->num_valid[vi], nr1 = 0;
2281         int *src, *ldst0, *rdst0, *ldst, *rdst;
2282         int tl, tr;
2283
2284         if( ci < 0 )
2285             continue;
2286
2287         src = data->get_cat_var_data(node, vi);
2288         ldst0 = ldst = data->get_cat_var_data(left, vi);
2289         rdst0 = rdst = data->get_cat_var_data(right, vi);
2290         tl = ldst0[nl]; tr = rdst0[nr];
2291
2292         for( i = 0; i < n; i++ )
2293         {
2294             int d = dir[i];
2295             int val = src[i];
2296             *ldst = *rdst = val;
2297             ldst += d^1;
2298             rdst += d;
2299             nr1 += (val >= 0)&d;
2300         }
2301
2302         if( vi < data->var_count )
2303         {
2304             left->num_valid[vi] = n1 - nr1;
2305             right->num_valid[vi] = nr1;
2306         }
2307
2308         ldst0[nl] = tl; rdst0[nr] = tr;
2309     }
2310
2311     // deallocate the parent node data that is not needed anymore
2312     data->free_node_data(node);
2313 }
2314
2315
2316 void CvDTree::prune_cv()
2317 {
2318     CvMat* ab = 0;
2319     CvMat* temp = 0;
2320     CvMat* err_jk = 0;
2321     
2322     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2323     // 2. choose the best tree index (if need, apply 1SE rule).
2324     // 3. store the best index and cut the branches.
2325
2326     CV_FUNCNAME( "CvDTree::prune_cv" );
2327
2328     __BEGIN__;
2329
2330     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2331     // currently, 1SE for regression is not implemented
2332     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2333     double* err;
2334     double min_err = 0, min_err_se = 0;
2335     int min_idx = -1;
2336     
2337     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2338
2339     // build the main tree sequence, calculate alpha's
2340     for(;;tree_count++)
2341     {
2342         double min_alpha = update_tree_rnc(tree_count, -1);
2343         if( cut_tree(tree_count, -1, min_alpha) )
2344             break;
2345
2346         if( ab->cols <= tree_count )
2347         {
2348             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2349             for( ti = 0; ti < ab->cols; ti++ )
2350                 temp->data.db[ti] = ab->data.db[ti];
2351             cvReleaseMat( &ab );
2352             ab = temp;
2353             temp = 0;
2354         }
2355
2356         ab->data.db[tree_count] = min_alpha;
2357     }
2358
2359     ab->data.db[0] = 0.;
2360     for( ti = 1; ti < tree_count-1; ti++ )
2361         ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2362     ab->data.db[tree_count-1] = DBL_MAX*0.5;
2363
2364     CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2365     err = err_jk->data.db;
2366
2367     for( j = 0; j < cv_n; j++ )
2368     {
2369         int tj = 0, tk = 0;
2370         for( ; tk < tree_count; tj++ )
2371         {
2372             double min_alpha = update_tree_rnc(tj, j);
2373             if( cut_tree(tj, j, min_alpha) )
2374                 min_alpha = DBL_MAX;
2375
2376             for( ; tk < tree_count; tk++ )
2377             {
2378                 if( ab->data.db[tk] > min_alpha )
2379                     break;
2380                 err[j*tree_count + tk] = root->tree_error;
2381             }
2382         }
2383     }
2384
2385     for( ti = 0; ti < tree_count; ti++ )
2386     {
2387         double sum_err = 0;
2388         for( j = 0; j < cv_n; j++ )
2389             sum_err += err[j*tree_count + ti];
2390         if( ti == 0 || sum_err < min_err )
2391         {
2392             min_err = sum_err;
2393             min_idx = ti;
2394             if( use_1se )
2395                 min_err_se = sqrt( sum_err*(n - sum_err) );
2396         }
2397         else if( sum_err < min_err + min_err_se )
2398             min_idx = ti;
2399     }
2400
2401     pruned_tree_idx = min_idx;
2402     free_prune_data(data->params.truncate_pruned_tree != 0);
2403
2404     __END__;
2405
2406     cvReleaseMat( &err_jk );
2407     cvReleaseMat( &ab );
2408     cvReleaseMat( &temp );
2409 }
2410
2411
2412 double CvDTree::update_tree_rnc( int T, int fold )
2413 {
2414     CvDTreeNode* node = root;
2415     double min_alpha = DBL_MAX;
2416     
2417     for(;;)
2418     {
2419         CvDTreeNode* parent;
2420         for(;;)
2421         {
2422             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2423             if( t <= T || !node->left )
2424             {
2425                 node->complexity = 1;
2426                 node->tree_risk = node->node_risk;
2427                 node->tree_error = 0.;
2428                 if( fold >= 0 )
2429                 {
2430                     node->tree_risk = node->cv_node_risk[fold];
2431                     node->tree_error = node->cv_node_error[fold];
2432                 }
2433                 break;
2434             }
2435             node = node->left;
2436         }
2437         
2438         for( parent = node->parent; parent && parent->right == node;
2439             node = parent, parent = parent->parent )
2440         {
2441             parent->complexity += node->complexity;
2442             parent->tree_risk += node->tree_risk;
2443             parent->tree_error += node->tree_error;
2444
2445             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2446                 - parent->tree_risk)/(parent->complexity - 1);
2447             min_alpha = MIN( min_alpha, parent->alpha );
2448         }
2449
2450         if( !parent )
2451             break;
2452
2453         parent->complexity = node->complexity;
2454         parent->tree_risk = node->tree_risk;
2455         parent->tree_error = node->tree_error;
2456         node = parent->right;
2457     }
2458
2459     return min_alpha;
2460 }
2461
2462
2463 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2464 {
2465     CvDTreeNode* node = root;
2466     if( !node->left )
2467         return 1;
2468
2469     for(;;)
2470     {
2471         CvDTreeNode* parent;
2472         for(;;)
2473         {
2474             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2475             if( t <= T || !node->left )
2476                 break;
2477             if( node->alpha <= min_alpha + FLT_EPSILON )
2478             {
2479                 if( fold >= 0 )
2480                     node->cv_Tn[fold] = T;
2481                 else
2482                     node->Tn = T;
2483                 if( node == root )
2484                     return 1;
2485                 break;
2486             }
2487             node = node->left;
2488         }
2489         
2490         for( parent = node->parent; parent && parent->right == node;
2491             node = parent, parent = parent->parent )
2492             ;
2493
2494         if( !parent )
2495             break;
2496
2497         node = parent->right;
2498     }
2499
2500     return 0;
2501 }
2502
2503
2504 void CvDTree::free_prune_data(bool cut_tree)
2505 {
2506     CvDTreeNode* node = root;
2507     
2508     for(;;)
2509     {
2510         CvDTreeNode* parent;
2511         for(;;)
2512         {
2513             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2514             // as we will clear the whole cross-validation heap at the end
2515             node->cv_Tn = 0;
2516             node->cv_node_error = node->cv_node_risk = 0;
2517             if( !node->left )
2518                 break;
2519             node = node->left;
2520         }
2521         
2522         for( parent = node->parent; parent && parent->right == node;
2523             node = parent, parent = parent->parent )
2524         {
2525             if( cut_tree && parent->Tn <= pruned_tree_idx )
2526             {
2527                 data->free_node( parent->left );
2528                 data->free_node( parent->right );
2529                 parent->left = parent->right = 0;
2530             }
2531         }
2532
2533         if( !parent )
2534             break;
2535
2536         node = parent->right;
2537     }
2538
2539     if( data->cv_heap )
2540         cvClearSet( data->cv_heap );
2541 }
2542
2543
2544 void CvDTree::free_tree()
2545 {
2546     if( root && data && data->shared )
2547     {
2548         pruned_tree_idx = INT_MIN;
2549         free_prune_data(true);
2550         data->free_node(root);
2551         root = 0;
2552     }
2553 }
2554
2555
2556 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2557     const CvMat* _missing, bool preprocessed_input ) const
2558 {
2559     CvDTreeNode* result = 0;
2560     int* catbuf = 0;
2561
2562     CV_FUNCNAME( "CvDTree::predict" );
2563
2564     __BEGIN__;
2565
2566     int i, step, mstep = 0;
2567     const float* sample;
2568     const uchar* m = 0;
2569     CvDTreeNode* node = root;
2570     const int* vtype;
2571     const int* vidx;
2572     const int* cmap;
2573     const int* cofs;
2574
2575     if( !node )
2576         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2577
2578     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2579         _sample->cols != 1 && _sample->rows != 1 ||
2580         _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2581         _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2582             CV_ERROR( CV_StsBadArg,
2583         "the input sample must be 1d floating-point vector with the same "
2584         "number of elements as the total number of variables used for training" );
2585
2586     sample = _sample->data.fl;
2587     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
2588
2589     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2590     {
2591         int n = data->cat_count->cols;
2592         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2593         for( i = 0; i < n; i++ )
2594             catbuf[i] = -1;
2595     }
2596
2597     if( _missing )
2598     {
2599         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2600         !CV_ARE_SIZES_EQ(_missing, _sample) )
2601             CV_ERROR( CV_StsBadArg,
2602         "the missing data mask must be 8-bit vector of the same size as input sample" );
2603         m = _missing->data.ptr;
2604         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2605     }
2606
2607     vtype = data->var_type->data.i;
2608     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2609     cmap = data->cat_map->data.i;
2610     cofs = data->cat_ofs->data.i;
2611
2612     while( node->Tn > pruned_tree_idx && node->left )
2613     {
2614         CvDTreeSplit* split = node->split;
2615         int dir = 0;
2616         for( ; !dir && split != 0; split = split->next )
2617         {
2618             int vi = split->var_idx;
2619             int ci = vtype[vi];
2620             int i = vidx ? vidx[vi] : vi;
2621             float val = sample[i*step];
2622             if( m && m[i*mstep] )
2623                 continue;
2624             if( ci < 0 ) // ordered
2625                 dir = val <= split->ord.c ? -1 : 1;
2626             else // categorical
2627             {
2628                 int c;
2629                 if( preprocessed_input )
2630                     c = cvRound(val);
2631                 else
2632                 {
2633                     c = catbuf[ci];
2634                     if( c < 0 )
2635                     {
2636                         int a = c = cofs[ci];
2637                         int b = cofs[ci+1];
2638                         int ival = cvRound(val);
2639                         if( ival != val )
2640                             CV_ERROR( CV_StsBadArg,
2641                             "one of input categorical variable is not an integer" );
2642
2643                         while( a < b )
2644                         {
2645                             c = (a + b) >> 1;
2646                             if( ival < cmap[c] )
2647                                 b = c;
2648                             else if( ival > cmap[c] )
2649                                 a = c+1;
2650                             else
2651                                 break;
2652                         }
2653
2654                         if( c < 0 || ival != cmap[c] )
2655                             continue;
2656
2657                         catbuf[ci] = c -= cofs[ci];
2658                     }
2659                 }
2660                 dir = DTREE_CAT_DIR(c, split->subset);
2661             }
2662
2663             if( split->inversed )
2664                 dir = -dir;
2665         }
2666
2667         if( !dir )
2668         {
2669             double diff = node->right->sample_count - node->left->sample_count;
2670             dir = diff < 0 ? -1 : 1;
2671         }
2672         node = dir < 0 ? node->left : node->right;
2673     }
2674
2675     result = node;
2676
2677     __END__;
2678
2679     return result;
2680 }
2681
2682
2683 const CvMat* CvDTree::get_var_importance()
2684 {
2685     if( !var_importance )
2686     {
2687         CvDTreeNode* node = root;
2688         double* importance;
2689         if( !node )
2690             return 0;
2691         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
2692         cvZero( var_importance );
2693         importance = var_importance->data.db;
2694
2695         for(;;)
2696         {
2697             CvDTreeNode* parent;
2698             for( ;; node = node->left )
2699             {
2700                 CvDTreeSplit* split = node->split;
2701                 
2702                 if( !node->left || node->Tn <= pruned_tree_idx )
2703                     break;
2704                 
2705                 for( ; split != 0; split = split->next )
2706                     importance[split->var_idx] += split->quality;
2707             }
2708
2709             for( parent = node->parent; parent && parent->right == node;
2710                 node = parent, parent = parent->parent )
2711                 ;
2712
2713             if( !parent )
2714                 break;
2715
2716             node = parent->right;
2717         }
2718     }
2719
2720     return var_importance;
2721 }
2722
2723
2724 void CvDTree::save( const char* filename, const char* name )
2725 {
2726     CvFileStorage* fs = 0;
2727     
2728     CV_FUNCNAME( "CvDTree::save" );
2729
2730     __BEGIN__;
2731
2732     CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_WRITE ));
2733     if( !fs )
2734         CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2735
2736     write( fs, name ? name : "my_dtree" );
2737
2738     __END__;
2739
2740     cvReleaseFileStorage( &fs );
2741 }
2742
2743
2744 void CvDTree::write_train_data_params( CvFileStorage* fs )
2745 {
2746     CV_FUNCNAME( "CvDTree::write_train_data_params" );
2747
2748     __BEGIN__;
2749
2750     int vi, vcount = data->var_count;
2751
2752     cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
2753     cvWriteInt( fs, "var_all", data->var_all );
2754     cvWriteInt( fs, "var_count", data->var_count );
2755     cvWriteInt( fs, "ord_var_count", data->ord_var_count );
2756     cvWriteInt( fs, "cat_var_count", data->cat_var_count );
2757
2758     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
2759     cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
2760
2761     if( data->is_classifier )
2762     {
2763         cvWriteInt( fs, "max_categories", data->params.max_categories );
2764     }
2765     else
2766     {
2767         cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
2768     }
2769
2770     cvWriteInt( fs, "max_depth", data->params.max_depth );
2771     cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
2772     cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
2773     
2774     if( data->params.cv_folds > 1 )
2775     {
2776         cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
2777         cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
2778     }
2779
2780     if( data->priors )
2781         cvWrite( fs, "priors", data->priors );
2782
2783     cvEndWriteStruct( fs );
2784
2785     if( data->var_idx )
2786         cvWrite( fs, "var_idx", data->var_idx );
2787     
2788     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
2789
2790     for( vi = 0; vi < vcount; vi++ )
2791         cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
2792
2793     cvEndWriteStruct( fs );
2794
2795     if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
2796     {
2797         CV_ASSERT( data->cat_count != 0 );
2798         cvWrite( fs, "cat_count", data->cat_count );
2799         cvWrite( fs, "cat_map", data->cat_map );
2800     }
2801
2802     __END__;
2803 }
2804
2805
2806 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
2807 {
2808     int ci;
2809     
2810     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
2811     cvWriteInt( fs, "var", split->var_idx );
2812     cvWriteReal( fs, "quality", split->quality );
2813
2814     ci = data->get_var_type(split->var_idx);
2815     if( ci >= 0 ) // split on a categorical var
2816     {
2817         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
2818         for( i = 0; i < n; i++ )
2819             to_right += DTREE_CAT_DIR(i,split->subset) > 0;
2820
2821         // ad-hoc rule when to use inverse categorical split notation
2822         // to achieve more compact and clear representation
2823         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
2824         
2825         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
2826                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
2827         for( i = 0; i < n; i++ )
2828         {
2829             int dir = DTREE_CAT_DIR(i,split->subset);
2830             if( dir*default_dir < 0 )
2831                 cvWriteInt( fs, 0, i );
2832         }
2833         cvEndWriteStruct( fs );
2834     }
2835     else
2836         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
2837
2838     cvEndWriteStruct( fs );
2839 }
2840
2841
2842 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
2843 {
2844     CvDTreeSplit* split;
2845     
2846     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2847
2848     cvWriteInt( fs, "depth", node->depth );
2849     cvWriteInt( fs, "sample_count", node->sample_count );
2850     cvWriteReal( fs, "value", node->value );
2851     
2852     if( data->is_classifier )
2853         cvWriteInt( fs, "norm_class_idx", node->class_idx );
2854
2855     cvWriteInt( fs, "Tn", node->Tn );
2856     cvWriteInt( fs, "complexity", node->complexity );
2857     cvWriteReal( fs, "alpha", node->alpha );
2858     cvWriteReal( fs, "node_risk", node->node_risk );
2859     cvWriteReal( fs, "tree_risk", node->tree_risk );
2860     cvWriteReal( fs, "tree_error", node->tree_error );
2861
2862     if( node->left )
2863     {
2864         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
2865
2866         for( split = node->split; split != 0; split = split->next )
2867             write_split( fs, split );
2868
2869         cvEndWriteStruct( fs );
2870     }
2871
2872     cvEndWriteStruct( fs );
2873 }
2874
2875
2876 void CvDTree::write_tree_nodes( CvFileStorage* fs )
2877 {
2878     CV_FUNCNAME( "CvDTree::write_tree_nodes" );
2879
2880     __BEGIN__;
2881
2882     CvDTreeNode* node = root;
2883
2884     // traverse the tree and save all the nodes in depth-first order
2885     for(;;)
2886     {
2887         CvDTreeNode* parent;
2888         for(;;)
2889         {
2890             write_node( fs, node );
2891             if( !node->left )
2892                 break;
2893             node = node->left;
2894         }
2895         
2896         for( parent = node->parent; parent && parent->right == node;
2897             node = parent, parent = parent->parent )
2898             ;
2899
2900         if( !parent )
2901             break;
2902
2903         node = parent->right;
2904     }
2905
2906     __END__;
2907 }
2908
2909
2910 void CvDTree::write( CvFileStorage* fs, const char* name )
2911 {
2912     CV_FUNCNAME( "CvDTree::write" );
2913
2914     __BEGIN__;
2915
2916     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
2917
2918     write_train_data_params( fs );
2919
2920     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
2921     get_var_importance();
2922     cvWrite( fs, "var_importance", var_importance );
2923
2924     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
2925     write_tree_nodes( fs );
2926     cvEndWriteStruct( fs );
2927
2928     cvEndWriteStruct( fs );
2929
2930     __END__;
2931 }
2932
2933
2934 void CvDTree::load( const char* filename, const char* name )
2935 {
2936     CvFileStorage* fs = 0;
2937     
2938     CV_FUNCNAME( "CvDTree::load" );
2939
2940     __BEGIN__;
2941
2942     CvFileNode* tree = 0;
2943
2944     CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_READ ));
2945     if( !fs )
2946         CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2947
2948     if( name )
2949         tree = cvGetFileNodeByName( fs, 0, name );
2950     else
2951     {
2952         CvFileNode* root = cvGetRootFileNode( fs );
2953         if( root->data.seq->total > 0 )
2954             tree = (CvFileNode*)cvGetSeqElem( root->data.seq, 0 );
2955     }
2956
2957     read( fs, tree );
2958
2959     __END__;
2960
2961     cvReleaseFileStorage( &fs );
2962 }
2963
2964
2965 void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
2966 {
2967     CV_FUNCNAME( "CvDTree::read_train_data_params" );
2968
2969     __BEGIN__;
2970     
2971     CvDTreeParams params;
2972     CvFileNode *tparams_node, *vartype_node;
2973     CvSeqReader reader;
2974     int is_classifier, vi, cat_var_count, ord_var_count;
2975     int max_split_size, tree_block_size;
2976
2977     data = new CvDTreeTrainData;
2978
2979     data->is_classifier = is_classifier = cvReadIntByName( fs, node, "is_classifier" ) != 0;
2980     data->var_all = cvReadIntByName( fs, node, "var_all" );
2981     data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
2982     data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
2983     data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
2984
2985     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
2986
2987     if( tparams_node ) // training parameters are not necessary
2988     {
2989         data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
2990
2991         if( is_classifier )
2992         {
2993             data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
2994         }
2995         else
2996         {
2997             data->params.regression_accuracy =
2998                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
2999         }
3000
3001         data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
3002         data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
3003         data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
3004     
3005         if( data->params.cv_folds > 1 )
3006         {
3007             data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
3008             data->params.truncate_pruned_tree =
3009                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
3010         }
3011
3012         data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
3013         if( data->priors && !CV_IS_MAT(data->priors) )
3014             CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
3015     }
3016
3017     CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
3018     if( data->var_idx )
3019     {
3020         if( !CV_IS_MAT(data->var_idx) ||
3021             data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
3022             data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
3023             CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
3024             CV_ERROR( CV_StsParseError,
3025                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
3026
3027         for( vi = 0; vi < data->var_count; vi++ )
3028             if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
3029                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
3030     }
3031     
3032     ////// read var type
3033     CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
3034
3035     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
3036     if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
3037         vartype_node->data.seq->total != data->var_count )
3038         CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3039
3040     cvStartReadSeq( vartype_node->data.seq, &reader );
3041
3042     cat_var_count = 0;
3043     ord_var_count = -1;
3044     for( vi = 0; vi < data->var_count; vi++ )
3045     {
3046         CvFileNode* n = (CvFileNode*)reader.ptr;
3047         if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
3048             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3049         data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
3050         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3051     }
3052     
3053     ord_var_count = ~ord_var_count;
3054     if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
3055         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
3056     //////
3057
3058     if( data->cat_var_count > 0 || is_classifier )
3059     {
3060         int ccount, max_c_count = 0, total_c_count = 0;
3061         CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
3062         CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
3063
3064         if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
3065             data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
3066             CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
3067             data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
3068             data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
3069             CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
3070             CV_ERROR( CV_StsParseError,
3071             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
3072
3073         ccount = cat_var_count + is_classifier;
3074
3075         CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
3076         data->cat_ofs->data.i[0] = 0;
3077
3078         for( vi = 0; vi < ccount; vi++ )
3079         {
3080             int val = data->cat_count->data.i[vi];
3081             if( val <= 0 )
3082                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
3083             max_c_count = MAX( max_c_count, val );
3084             data->cat_ofs->data.i[vi+1] = total_c_count += val;
3085         }
3086
3087         if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
3088             CV_ERROR( CV_StsBadSize,
3089             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
3090         
3091         data->max_c_count = max_c_count;
3092     }
3093
3094     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
3095         (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
3096
3097     tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
3098     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
3099     CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
3100     CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
3101             sizeof(CvDTreeNode), data->tree_storage ));
3102     CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
3103             max_split_size, data->tree_storage ));
3104
3105     __END__;
3106 }
3107
3108
3109 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3110 {
3111     CvDTreeSplit* split = 0;
3112     
3113     CV_FUNCNAME( "CvDTree::read_split" );
3114
3115     __BEGIN__;
3116
3117     int vi, ci;
3118     
3119     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3120         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3121
3122     vi = cvReadIntByName( fs, fnode, "var", -1 );
3123     if( (unsigned)vi >= (unsigned)data->var_count )
3124         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3125
3126     ci = data->get_var_type(vi);
3127     if( ci >= 0 ) // split on categorical var
3128     {
3129         int i, n = data->cat_count->data.i[ci], inversed = 0;
3130         CvSeqReader reader;
3131         CvFileNode* inseq;
3132         split = data->new_split_cat( vi, 0 );
3133         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3134         if( !inseq )
3135         {
3136             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3137             inversed = 1;
3138         }
3139         if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
3140             CV_ERROR( CV_StsParseError,
3141             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3142
3143         cvStartReadSeq( inseq->data.seq, &reader );
3144
3145         for( i = 0; i < reader.seq->total; i++ )
3146         {
3147             CvFileNode* inode = (CvFileNode*)reader.ptr;
3148             int val = inode->data.i;
3149             if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3150                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3151
3152             split->subset[val >> 5] |= 1 << (val & 31);
3153             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3154         }
3155
3156         // for categorical splits we do not use inversed splits,
3157         // instead we inverse the variable set in the split
3158         if( inversed )
3159             for( i = 0; i < (n + 31) >> 5; i++ )
3160                 split->subset[i] ^= -1;
3161     }
3162     else
3163     {
3164         CvFileNode* cmp_node;
3165         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3166
3167         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3168         if( !cmp_node )
3169         {
3170             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3171             split->inversed = 1;
3172         }
3173
3174         split->ord.c = (float)cvReadReal( cmp_node );
3175     }
3176         
3177     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3178
3179     __END__;
3180     
3181     return split;
3182 }
3183
3184
3185 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3186 {
3187     CvDTreeNode* node = 0;
3188     
3189     CV_FUNCNAME( "CvDTree::read_node" );
3190
3191     __BEGIN__;
3192
3193     CvFileNode* splits;
3194     int i, depth;
3195
3196     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3197         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3198
3199     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3200     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3201     if( depth != node->depth )
3202         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3203
3204     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3205     node->value = cvReadRealByName( fs, fnode, "value" );
3206     if( data->is_classifier )
3207         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3208
3209     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3210     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3211     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3212     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3213     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3214     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3215
3216     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3217     if( splits )
3218     {
3219         CvSeqReader reader;
3220         CvDTreeSplit* last_split = 0;
3221
3222         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3223             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3224
3225         cvStartReadSeq( splits->data.seq, &reader );
3226         for( i = 0; i < reader.seq->total; i++ )
3227         {
3228             CvDTreeSplit* split;
3229             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3230             if( !last_split )
3231                 node->split = last_split = split;
3232             else
3233                 last_split = last_split->next = split;
3234
3235             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3236         }
3237     }
3238
3239     __END__;
3240     
3241     return node;
3242 }
3243
3244
3245 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3246 {
3247     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3248
3249     __BEGIN__;
3250
3251     CvSeqReader reader;
3252     CvDTreeNode _root;
3253     CvDTreeNode* parent = &_root;
3254     int i;
3255     parent->left = parent->right = parent->parent = 0;
3256
3257     cvStartReadSeq( fnode->data.seq, &reader );
3258
3259     for( i = 0; i < reader.seq->total; i++ )
3260     {
3261         CvDTreeNode* node;
3262         
3263         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3264         if( !parent->left )
3265             parent->left = node;
3266         else
3267             parent->right = node;
3268         if( node->split )
3269             parent = node;
3270         else
3271         {
3272             while( parent && parent->right )
3273                 parent = parent->parent;
3274         }
3275
3276         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3277     }
3278
3279     root = _root.left;
3280
3281     __END__;
3282 }
3283
3284
3285 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3286 {
3287     CV_FUNCNAME( "CvDTree::read" );
3288
3289     __BEGIN__;
3290
3291     CvFileNode* tree_nodes;
3292
3293     clear();
3294     read_train_data_params( fs, fnode );
3295
3296     tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
3297     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3298         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3299
3300     pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
3301
3302     read_tree_nodes( fs, tree_nodes );
3303     get_var_importance(); // recompute variable importance
3304
3305     __END__;
3306 }
3307
3308 /* End of file. */