]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
a lot of minor changes: API cleaned up, compiler errors/warinings fixed
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 static const float ord_var_epsilon = FLT_EPSILON*2;
44 static const float ord_nan = FLT_MAX*0.5f;
45 static const int min_block_size = 1 << 16;
46 static const int block_size_delta = 1 << 10;
47
48 CvDTreeTrainData::CvDTreeTrainData()
49 {
50     var_idx = var_type = cat_count = cat_ofs = cat_map =
51         priors = counts = buf = direction = split_buf = 0;
52     tree_storage = temp_storage = 0;
53
54     clear();
55 }
56
57
58 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
59                       const CvMat* _responses, const CvMat* _var_idx,
60                       const CvMat* _sample_idx, const CvMat* _var_type,
61                       const CvMat* _missing_mask,
62                       CvDTreeParams _params, bool _shared, bool _add_weights )
63 {
64     var_idx = var_type = cat_count = cat_ofs = cat_map =
65         priors = counts = buf = direction = split_buf = 0;
66     tree_storage = temp_storage = 0;
67     
68     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
69               _var_type, _missing_mask, _params, _shared, _add_weights );
70 }
71
72
73 CvDTreeTrainData::~CvDTreeTrainData()
74 {
75     clear();
76 }
77
78
79 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
80 {
81     bool ok = false;
82     
83     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
84
85     __BEGIN__;
86
87     // set parameters
88     params = _params;
89
90     if( params.max_categories < 2 )
91         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
92     params.max_categories = MIN( params.max_categories, 15 );
93
94     if( params.max_depth < 0 )
95         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
96     params.max_depth = MIN( params.max_depth, 25 );
97
98     params.min_sample_count = MAX(params.min_sample_count,1);
99
100     if( params.cv_folds < 0 )
101         CV_ERROR( CV_StsOutOfRange,
102         "params.cv_folds should be =0 (the tree is not pruned) "
103         "or n>0 (tree is pruned using n-fold cross-validation)" );
104
105     if( params.cv_folds == 1 )
106         params.cv_folds = 0;
107
108     if( params.regression_accuracy < 0 )
109         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
110
111     ok = true;
112
113     __END__;
114
115     return ok;
116 }
117
118
119 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
120 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
121 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
122
123 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
124 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
125
126 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
127     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
128     const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
129     bool _shared, bool _add_weights )
130 {
131     CvMat* sample_idx = 0;
132     CvMat* var_type0 = 0;
133     CvMat* tmp_map = 0;
134     int** int_ptr = 0;
135
136     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
137
138     __BEGIN__;
139
140     int sample_all = 0, r_type = 0, cv_n;
141     int total_c_count = 0;
142     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
144     int vi, i;
145     char err[100];
146     const int *sidx = 0, *vidx = 0;
147
148     clear();
149
150     var_all = 0;
151     rng = cvRNG(-1);
152
153     CV_CALL( set_params( _params ));
154
155     // check parameter types and sizes
156     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
157     if( _tflag == CV_ROW_SAMPLE )
158     {
159         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
160         dv_step = 1;
161         if( _missing_mask )
162             ms_step = _missing_mask->step, mv_step = 1;
163     }
164     else
165     {
166         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
167         ds_step = 1;
168         if( _missing_mask )
169             mv_step = _missing_mask->step, ms_step = 1;
170     }
171
172     sample_count = sample_all;
173     var_count = var_all;
174
175     if( _sample_idx )
176     {
177         CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
178         sidx = sample_idx->data.i;
179         sample_count = sample_idx->rows + sample_idx->cols - 1;
180     }
181
182     if( _var_idx )
183     {
184         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
185         vidx = var_idx->data.i;
186         var_count = var_idx->rows + var_idx->cols - 1;
187     }
188
189     if( !CV_IS_MAT(_responses) ||
190         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
191          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
192         _responses->rows != 1 && _responses->cols != 1 ||
193         _responses->rows + _responses->cols - 1 != sample_all )
194         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
195                   "floating-point vector containing as many elements as "
196                   "the total number of samples in the training data matrix" );
197
198     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
199     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
200
201     cat_var_count = 0;
202     ord_var_count = -1;
203
204     is_classifier = r_type == CV_VAR_CATEGORICAL;
205
206     // step 0. calc the number of categorical vars
207     for( vi = 0; vi < var_count; vi++ )
208     {
209         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
210             cat_var_count++ : ord_var_count--;
211     }
212
213     ord_var_count = ~ord_var_count;
214     cv_n = params.cv_folds;
215     // set the two last elements of var_type array to be able
216     // to locate responses and cross-validation labels using
217     // the corresponding get_* functions.
218     var_type->data.i[var_count] = cat_var_count;
219     var_type->data.i[var_count+1] = cat_var_count+1;
220
221     // in case of single ordered predictor we need dummy cv_labels
222     // for safe split_node_data() operation
223     have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
224     have_weights = _add_weights;
225
226     buf_size = (ord_var_count*2 + cat_var_count + 1 +
227         (have_cv_labels ? 1 : 0) + (have_weights ? 1 : 0))*sample_count + 2;
228     shared = _shared;
229     buf_count = shared ? 3 : 2;
230     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
231     CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
232     CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
233     CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
234
235     // now calculate the maximum size of split,
236     // create memory storage that will keep nodes and splits of the decision tree
237     // allocate root node and the buffer for the whole training data
238     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
239         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
240     tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
241     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
242     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
243     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
244
245     temp_block_size = nv_size = var_count*sizeof(int);
246     if( cv_n )
247     {
248         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
249             CV_ERROR( CV_StsOutOfRange,
250                 "The many folds in cross-validation for such a small dataset" );
251
252         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
253         temp_block_size = MAX(temp_block_size, cv_size);
254     }
255
256     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
257     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
258     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
259     if( cv_size )
260         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
261
262     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
263     CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
264
265     max_c_count = 1;
266
267     // transform the training data to convenient representation
268     for( vi = 0; vi <= var_count; vi++ )
269     {
270         int ci;
271         const uchar* mask = 0;
272         int m_step = 0, step;
273         const int* idata = 0;
274         const float* fdata = 0;
275         int num_valid = 0;
276
277         if( vi < var_count ) // analyze i-th input variable
278         {
279             int vi0 = vidx ? vidx[vi] : vi;
280             ci = get_var_type(vi);
281             step = ds_step; m_step = ms_step;
282             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
283                 idata = _train_data->data.i + vi0*dv_step;
284             else
285                 fdata = _train_data->data.fl + vi0*dv_step;
286             if( _missing_mask )
287                 mask = _missing_mask->data.ptr + vi0*mv_step;
288         }
289         else // analyze _responses
290         {
291             ci = cat_var_count;
292             step = CV_IS_MAT_CONT(_responses->type) ?
293                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
294             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
295                 idata = _responses->data.i;
296             else
297                 fdata = _responses->data.fl;
298         }
299
300         if( vi < var_count && ci >= 0 ||
301             vi == var_count && is_classifier ) // process categorical variable or response
302         {
303             int c_count, prev_label, prev_i;
304             int* c_map, *dst = get_cat_var_data( data_root, vi );
305
306             // copy data
307             for( i = 0; i < sample_count; i++ )
308             {
309                 int val = INT_MAX, si = sidx ? sidx[i] : i;
310                 if( !mask || !mask[si*m_step] )
311                 {
312                     if( idata )
313                         val = idata[si*step];
314                     else
315                     {
316                         float t = fdata[si*step];
317                         val = cvRound(t);
318                         if( val != t )
319                         {
320                             sprintf( err, "%d-th value of %d-th (categorical) "
321                                 "variable is not an integer", i, vi );
322                             CV_ERROR( CV_StsBadArg, err );
323                         }
324                     }
325
326                     if( val == INT_MAX )
327                     {
328                         sprintf( err, "%d-th value of %d-th (categorical) "
329                             "variable is too large", i, vi );
330                         CV_ERROR( CV_StsBadArg, err );
331                     }
332                     num_valid++;
333                 }
334                 dst[i] = val;
335                 int_ptr[i] = dst + i;
336             }
337
338             // sort all the values, including the missing measurements
339             // that should all move to the end
340             icvSortIntPtr( int_ptr, sample_count, 0 );
341             //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
342
343             c_count = num_valid > 0;
344
345             // count the categories
346             for( i = 1; i < num_valid; i++ )
347                 c_count += *int_ptr[i] != *int_ptr[i-1];
348
349             if( vi > 0 )
350                 max_c_count = MAX( max_c_count, c_count );
351             cat_count->data.i[ci] = c_count;
352             cat_ofs->data.i[ci] = total_c_count;
353
354             // resize cat_map, if need
355             if( cat_map->cols < total_c_count + c_count )
356             {
357                 tmp_map = cat_map;
358                 CV_CALL( cat_map = cvCreateMat( 1,
359                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
360                 for( i = 0; i < total_c_count; i++ )
361                     cat_map->data.i[i] = tmp_map->data.i[i];
362                 cvReleaseMat( &tmp_map );
363             }
364
365             c_map = cat_map->data.i + total_c_count;
366             total_c_count += c_count;
367
368             // compact the class indices and build the map
369             prev_label = ~*int_ptr[0];
370             c_count = -1;
371
372             for( i = 0, prev_i = -1; i < num_valid; i++ )
373             {
374                 int cur_label = *int_ptr[i];
375                 if( cur_label != prev_label )
376                 {
377                     c_map[++c_count] = prev_label = cur_label;
378                     prev_i = i;
379                 }
380                 *int_ptr[i] = c_count;
381             }
382
383             // replace labels for missing values with -1
384             for( ; i < sample_count; i++ )
385                 *int_ptr[i] = -1;
386         }
387         else if( ci < 0 ) // process ordered variable
388         {
389             CvPair32s32f* dst = get_ord_var_data( data_root, vi );
390
391             for( i = 0; i < sample_count; i++ )
392             {
393                 float val = ord_nan;
394                 int si = sidx ? sidx[i] : i;
395                 if( !mask || !mask[si*m_step] )
396                 {
397                     if( idata )
398                         val = (float)idata[si*step];
399                     else
400                         val = fdata[si*step];
401
402                     if( fabs(val) >= ord_nan )
403                     {
404                         sprintf( err, "%d-th value of %d-th (ordered) "
405                             "variable (=%g) is too large", i, vi, val );
406                         CV_ERROR( CV_StsBadArg, err );
407                     }
408                     num_valid++;
409                 }
410                 dst[i].i = i;
411                 dst[i].val = val;
412             }
413
414             icvSortPairs( dst, sample_count, 0 );
415         }
416         else // special case: process ordered response,
417              // it will be stored similarly to categorical vars (i.e. no pairs)
418         {
419             float* dst = get_ord_responses( data_root );
420
421             for( i = 0; i < sample_count; i++ )
422             {
423                 float val = ord_nan;
424                 int si = sidx ? sidx[i] : i;
425                 if( idata )
426                     val = (float)idata[si*step];
427                 else
428                     val = fdata[si*step];
429
430                 if( fabs(val) >= ord_nan )
431                 {
432                     sprintf( err, "%d-th value of %d-th (ordered) "
433                         "variable (=%g) is out of range", i, vi, val );
434                     CV_ERROR( CV_StsBadArg, err );
435                 }
436                 dst[i] = val;
437             }
438
439             cat_count->data.i[cat_var_count] = 0;
440             cat_ofs->data.i[cat_var_count] = total_c_count;
441             num_valid = sample_count;
442         }
443
444         if( vi < var_count )
445             data_root->set_num_valid(vi, num_valid);
446     }
447
448     if( cv_n )
449     {
450         int* dst = get_cv_labels(data_root);
451         CvRNG* r = &rng;
452
453         for( i = vi = 0; i < sample_count; i++ )
454         {
455             dst[i] = vi++;
456             vi &= vi < cv_n ? -1 : 0;
457         }
458
459         for( i = 0; i < sample_count; i++ )
460         {
461             int a = cvRandInt(r) % sample_count;
462             int b = cvRandInt(r) % sample_count;
463             CV_SWAP( dst[a], dst[b], vi );
464         }
465     }
466
467     cat_map->cols = MAX( total_c_count, 1 );
468
469     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
470         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
471     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
472
473     have_priors = is_classifier && params.priors;
474     if( is_classifier )
475     {
476         int m = get_num_classes(), rows = 4;
477         double sum = 0;
478         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
479         for( i = 0; i < m; i++ )
480         {
481             double val = have_priors ? params.priors[i] : 1.;
482             if( val <= 0 )
483                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
484             priors->data.db[i] = val;
485             sum += val;
486         }
487         // normalize weights
488         if( have_priors )
489             cvScale( priors, priors, 1./sum );
490
491         if( cat_var_count > 0 || params.cv_folds > 0 )
492         {
493             // need storage for cjk (see find_split_cat_gini) and risks/errors
494             rows += MAX( max_c_count, params.cv_folds ) + 1;
495             // add buffer for k-means clustering
496             if( m > 2 && max_c_count > params.max_categories )
497                 rows += params.max_categories + (max_c_count+m-1)/m;
498         }
499
500         CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
501     }
502
503     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
504     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
505
506     __END__;
507
508     cvFree( &int_ptr );
509     cvReleaseMat( &sample_idx );
510     cvReleaseMat( &var_type0 );
511     cvReleaseMat( &tmp_map );
512 }
513
514
515 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
516 {
517     CvDTreeNode* root = 0;
518     CvMat* isubsample_idx = 0;
519     CvMat* subsample_co = 0;
520     
521     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
522
523     __BEGIN__;
524
525     if( !data_root )
526         CV_ERROR( CV_StsError, "No training data has been set" );
527     
528     if( _subsample_idx )
529         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
530
531     if( !isubsample_idx )
532     {
533         // make a copy of the root node
534         CvDTreeNode temp;
535         int i;
536         root = new_node( 0, 1, 0, 0 );
537         temp = *root;
538         *root = *data_root;
539         root->num_valid = temp.num_valid;
540         if( root->num_valid )
541         {
542             for( i = 0; i < var_count; i++ )
543                 root->num_valid[i] = data_root->num_valid[i];
544         }
545         root->cv_Tn = temp.cv_Tn;
546         root->cv_node_risk = temp.cv_node_risk;
547         root->cv_node_error = temp.cv_node_error;
548     }
549     else
550     {
551         int* sidx = isubsample_idx->data.i;
552         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
553         int* co, cur_ofs = 0;
554         int vi, i, total = data_root->sample_count;
555         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
556         root = new_node( 0, count, 1, 0 );
557
558         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
559         cvZero( subsample_co );
560         co = subsample_co->data.i;
561         for( i = 0; i < count; i++ )
562             co[sidx[i]*2]++;
563         for( i = 0; i < total; i++ )
564         {
565             if( co[i*2] )
566             {
567                 co[i*2+1] = cur_ofs;
568                 cur_ofs += co[i*2];
569             }
570             else
571                 co[i*2+1] = -1;
572         }
573
574         for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
575         {
576             int ci = get_var_type(vi);
577
578             if( ci >= 0 || vi >= var_count )
579             {
580                 const int* src = get_cat_var_data( data_root, vi );
581                 int* dst = get_cat_var_data( root, vi );
582                 int num_valid = 0;
583
584                 for( i = 0; i < count; i++ )
585                 {
586                     int val = src[sidx[i]];
587                     dst[i] = val;
588                     num_valid += val >= 0;
589                 }
590
591                 if( vi < var_count )
592                     root->set_num_valid(vi, num_valid);
593             }
594             else
595             {
596                 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
597                 CvPair32s32f* dst = get_ord_var_data( root, vi );
598                 int j = 0, idx, count_i;
599                 int num_valid = data_root->get_num_valid(vi);
600
601                 for( i = 0; i < num_valid; i++ )
602                 {
603                     idx = src[i].i;
604                     count_i = co[idx*2];
605                     if( count_i )
606                     {
607                         float val = src[i].val;
608                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
609                         {
610                             dst[j].val = val;
611                             dst[j].i = cur_ofs;
612                         }
613                     }
614                 }
615
616                 root->set_num_valid(vi, j);
617
618                 for( ; i < total; i++ )
619                 {
620                     idx = src[i].i;
621                     count_i = co[idx*2];
622                     if( count_i )
623                     {
624                         float val = src[i].val;
625                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
626                         {
627                             dst[j].val = val;
628                             dst[j].i = cur_ofs;
629                         }
630                     }
631                 }
632             }
633         }
634     }
635
636     __END__;
637
638     cvReleaseMat( &isubsample_idx );
639     cvReleaseMat( &subsample_co );
640
641     return root;
642 }
643
644
645 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
646                                     float* values, uchar* missing,
647                                     float* responses, bool get_class_idx )
648 {
649     CvMat* subsample_idx = 0;
650     CvMat* subsample_co = 0;
651     
652     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
653
654     __BEGIN__;
655
656     int i, vi, total = sample_count, count = total, cur_ofs = 0;
657     int* sidx = 0;
658     int* co = 0;
659
660     if( _subsample_idx )
661     {
662         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
663         sidx = subsample_idx->data.i;
664         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
665         co = subsample_co->data.i;
666         cvZero( subsample_co );
667         count = subsample_idx->cols + subsample_idx->rows - 1;
668         for( i = 0; i < count; i++ )
669             co[sidx[i]*2]++;
670         for( i = 0; i < total; i++ )
671         {
672             int count_i = co[i*2];
673             if( count_i )
674             {
675                 co[i*2+1] = cur_ofs*var_count;
676                 cur_ofs += count_i;
677             }
678         }
679     }
680
681     memset( missing, 1, count*var_count );
682
683     for( vi = 0; vi < var_count; vi++ )
684     {
685         int ci = get_var_type(vi);
686         if( ci >= 0 ) // categorical
687         {
688             float* dst = values + vi;
689             uchar* m = missing + vi;
690             const int* src = get_cat_var_data(data_root, vi);
691
692             for( i = 0; i < count; i++, dst += var_count, m += var_count )
693             {
694                 int idx = sidx ? sidx[i] : i;
695                 int val = src[idx];
696                 *dst = (float)val;
697                 *m = val < 0;
698             }
699         }
700         else // ordered
701         {
702             float* dst = values + vi;
703             uchar* m = missing + vi;
704             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
705             int count1 = data_root->get_num_valid(vi);
706
707             for( i = 0; i < count1; i++ )
708             {
709                 int idx = src[i].i;
710                 int count_i = 1;
711                 if( co )
712                 {
713                     count_i = co[idx*2];
714                     cur_ofs = co[idx*2+1];
715                 }
716                 else
717                     cur_ofs = idx*var_count;
718                 if( count_i )
719                 {
720                     float val = src[i].val;
721                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
722                     {
723                         dst[cur_ofs] = val;
724                         m[cur_ofs] = 0;
725                     }
726                 }
727             }
728         }
729     }
730
731     // copy responses
732     if( is_classifier )
733     {
734         const int* src = get_class_labels(data_root);
735         for( i = 0; i < count; i++ )
736         {
737             int idx = sidx ? sidx[i] : i;
738             int val = get_class_idx ? src[idx] :
739                 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
740             responses[i] = (float)val;
741         }
742     }
743     else
744     {
745         const float* src = get_ord_responses(data_root);
746         for( i = 0; i < count; i++ )
747         {
748             int idx = sidx ? sidx[i] : i;
749             responses[i] = src[idx];
750         }
751     }
752
753     __END__;
754
755     cvReleaseMat( &subsample_idx );
756     cvReleaseMat( &subsample_co );
757 }
758
759
760 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
761                                          int storage_idx, int offset )
762 {
763     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
764
765     node->sample_count = count;
766     node->depth = parent ? parent->depth + 1 : 0;
767     node->parent = parent;
768     node->left = node->right = 0;
769     node->split = 0;
770     node->value = 0;
771     node->class_idx = 0;
772     node->maxlr = 0.;
773
774     node->buf_idx = storage_idx;
775     node->offset = offset;
776     if( nv_heap )
777         node->num_valid = (int*)cvSetNew( nv_heap );
778     else
779         node->num_valid = 0;
780     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
781     node->complexity = 0;
782
783     if( params.cv_folds > 0 && cv_heap )
784     {
785         int cv_n = params.cv_folds;
786         node->Tn = INT_MAX;
787         node->cv_Tn = (int*)cvSetNew( cv_heap );
788         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
789         node->cv_node_error = node->cv_node_risk + cv_n;
790     }
791     else
792     {
793         node->Tn = 0;
794         node->cv_Tn = 0;
795         node->cv_node_risk = 0;
796         node->cv_node_error = 0;
797     }
798
799     return node;
800 }
801
802
803 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
804                 int split_point, int inversed, float quality )
805 {
806     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
807     split->var_idx = vi;
808     split->ord.c = cmp_val;
809     split->ord.split_point = split_point;
810     split->inversed = inversed;
811     split->quality = quality;
812     split->next = 0;
813
814     return split;
815 }
816
817
818 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
819 {
820     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
821     int i, n = (max_c_count + 31)/32;
822
823     split->var_idx = vi;
824     split->inversed = 0;
825     split->quality = quality;
826     for( i = 0; i < n; i++ )
827         split->subset[i] = 0;
828     split->next = 0;
829
830     return split;
831 }
832
833
834 void CvDTreeTrainData::free_node( CvDTreeNode* node )
835 {
836     CvDTreeSplit* split = node->split;
837     free_node_data( node );
838     while( split )
839     {
840         CvDTreeSplit* next = split->next;
841         cvSetRemoveByPtr( split_heap, split );
842         split = next;
843     }
844     node->split = 0;
845     cvSetRemoveByPtr( node_heap, node );
846 }
847
848
849 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
850 {
851     if( node->num_valid )
852     {
853         cvSetRemoveByPtr( nv_heap, node->num_valid );
854         node->num_valid = 0;
855     }
856     // do not free cv_* fields, as all the cross-validation related data is released at once.
857 }
858
859
860 void CvDTreeTrainData::free_train_data()
861 {
862     cvReleaseMat( &counts );
863     cvReleaseMat( &buf );
864     cvReleaseMat( &direction );
865     cvReleaseMat( &split_buf );
866     cvReleaseMemStorage( &temp_storage );
867     cv_heap = nv_heap = 0;
868 }
869
870
871 void CvDTreeTrainData::clear()
872 {
873     free_train_data();
874
875     cvReleaseMemStorage( &tree_storage );
876
877     cvReleaseMat( &var_idx );
878     cvReleaseMat( &var_type );
879     cvReleaseMat( &cat_count );
880     cvReleaseMat( &cat_ofs );
881     cvReleaseMat( &cat_map );
882     cvReleaseMat( &priors );
883
884     node_heap = split_heap = 0;
885
886     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
887     have_cv_labels = have_priors = is_classifier = false;
888
889     buf_count = buf_size = 0;
890     shared = false;
891
892     data_root = 0;
893
894     rng = cvRNG(-1);
895 }
896
897
898 int CvDTreeTrainData::get_num_classes() const
899 {
900     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
901 }
902
903
904 int CvDTreeTrainData::get_var_type(int vi) const
905 {
906     return var_type->data.i[vi];
907 }
908
909
910 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
911 {
912     int oi = ~get_var_type(vi);
913     assert( 0 <= oi && oi < ord_var_count );
914     return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
915                            n->offset + oi*n->sample_count*2);
916 }
917
918
919 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
920 {
921     return get_cat_var_data( n, var_count );
922 }
923
924
925 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
926 {
927     return (float*)get_cat_var_data( n, var_count );
928 }
929
930
931 int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
932 {
933     return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
934 }
935
936
937 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
938 {
939     int ci = get_var_type(vi);
940     assert( 0 <= ci && ci <= cat_var_count + 1 );
941     return buf->data.i + n->buf_idx*buf->cols + n->offset +
942            (ord_var_count*2 + ci)*n->sample_count;
943 }
944
945
946 float* CvDTreeTrainData::get_weights( CvDTreeNode* n )
947 {
948     return have_weights ?
949         (float*)get_cat_var_data( n, var_count + 1 + (params.cv_folds > 0) ) : 0;
950 }
951
952
953 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
954 {
955     int idx = n->buf_idx + 1;
956     if( idx >= buf_count )
957         idx = shared ? 1 : 0;
958     return idx;
959 }
960
961
962 /////////////////////// Decision Tree /////////////////////////
963
964 CvDTree::CvDTree()
965 {
966     data = 0;
967     var_importance = 0;
968     clear();
969 }
970
971
972 void CvDTree::clear()
973 {
974     cvReleaseMat( &var_importance );
975     if( data )
976     {
977         if( !data->shared )
978             delete data;
979         else
980             free_tree();
981         data = 0;
982     }
983     root = 0;
984     pruned_tree_idx = -1;
985 }
986
987
988 CvDTree::~CvDTree()
989 {
990     clear();
991 }
992
993
994 bool CvDTree::train( const CvMat* _train_data, int _tflag,
995                      const CvMat* _responses, const CvMat* _var_idx,
996                      const CvMat* _sample_idx, const CvMat* _var_type,
997                      const CvMat* _missing_mask, CvDTreeParams _params )
998 {
999     bool result = false;
1000
1001     CV_FUNCNAME( "CvDTree::train" );
1002
1003     __BEGIN__;
1004
1005     clear();
1006     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1007                                  _var_idx, _sample_idx, _var_type,
1008                                  _missing_mask, _params, false );
1009     CV_CALL( result = do_train(0));
1010
1011     __END__;
1012
1013     return result;
1014 }
1015
1016
1017 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1018 {
1019     bool result = false;
1020
1021     CV_FUNCNAME( "CvDTree::train" );
1022
1023     __BEGIN__;
1024
1025     clear();
1026     data = _data;
1027     data->shared = true;
1028     CV_CALL( result = do_train(_subsample_idx));
1029
1030     __END__;
1031
1032     return result;
1033 }
1034
1035
1036 bool CvDTree::do_train( const CvMat* _subsample_idx )
1037 {
1038     bool result = false;
1039
1040     CV_FUNCNAME( "CvDTree::train" );
1041
1042     __BEGIN__;
1043
1044     root = data->subsample_data( _subsample_idx );
1045
1046     try_split_node(root);
1047     
1048     if( data->params.cv_folds > 0 )
1049         prune_cv();
1050
1051     if( !data->shared )
1052         data->free_train_data();
1053
1054     result = true;
1055
1056     __END__;
1057
1058     return result;
1059 }
1060
1061
1062 #define DTREE_CAT_DIR(idx,subset) \
1063     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
1064
1065 void CvDTree::try_split_node( CvDTreeNode* node )
1066 {
1067     CvDTreeSplit* best_split = 0;
1068     int i, n = node->sample_count, vi;
1069     bool can_split = true;
1070     double quality_scale;
1071
1072     calc_node_value( node );
1073
1074     if( node->sample_count <= data->params.min_sample_count ||
1075         node->depth >= data->params.max_depth )
1076         can_split = false;
1077
1078     if( can_split && data->is_classifier )
1079     {
1080         // check if we have a "pure" node,
1081         // we assume that cls_count is filled by calc_node_value()
1082         int* cls_count = data->counts->data.i;
1083         int nz = 0, m = data->get_num_classes();
1084         for( i = 0; i < m; i++ )
1085             nz += cls_count[i] != 0;
1086         if( nz == 1 ) // there is only one class
1087             can_split = false;
1088     }
1089     else if( can_split )
1090     {
1091         const float* responses = data->get_ord_responses( node );
1092         float diff = responses[n-1] - responses[0];
1093         if( diff < data->params.regression_accuracy )
1094             can_split = false;
1095     }
1096
1097     if( can_split )
1098     {
1099         best_split = find_best_split(node);
1100         // TODO: check the split quality ...
1101         node->split = best_split;
1102     }
1103
1104     if( !can_split || !best_split )
1105     {
1106         data->free_node_data(node);
1107         return;
1108     }
1109
1110     quality_scale = calc_node_dir( node );
1111
1112     if( data->params.use_surrogates )
1113     {
1114         // find all the surrogate splits
1115         // and sort them by their similarity to the primary one
1116         for( vi = 0; vi < data->var_count; vi++ )
1117         {
1118             CvDTreeSplit* split;
1119             int ci = data->get_var_type(vi);
1120
1121             if( vi == best_split->var_idx )
1122                 continue;
1123
1124             if( ci >= 0 )
1125                 split = find_surrogate_split_cat( node, vi );
1126             else
1127                 split = find_surrogate_split_ord( node, vi );
1128
1129             if( split )
1130             {
1131                 // insert the split
1132                 CvDTreeSplit* prev_split = node->split;
1133                 split->quality = (float)(split->quality*quality_scale);
1134
1135                 while( prev_split->next &&
1136                        prev_split->next->quality > split->quality )
1137                     prev_split = prev_split->next;
1138                 split->next = prev_split->next;
1139                 prev_split->next = split;
1140             }
1141         }
1142     }
1143
1144     split_node_data( node );
1145     try_split_node( node->left );
1146     try_split_node( node->right );
1147 }
1148
1149
1150 // calculate direction (left(-1),right(1),missing(0))
1151 // for each sample using the best split
1152 // the function returns scale coefficients for surrogate split quality factors.
1153 // the scale is applied to normalize surrogate split quality relatively to the
1154 // best (primary) split quality. That is, if a surrogate split is absolutely
1155 // identical to the primary split, its quality will be set to the maximum value = 
1156 // quality of the primary split; otherwise, it will be lower.
1157 // besides, the function compute node->maxlr,
1158 // minimum possible quality (w/o considering the above mentioned scale)
1159 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1160 // are not discarded.
1161 double CvDTree::calc_node_dir( CvDTreeNode* node )
1162 {
1163     char* dir = (char*)data->direction->data.ptr;
1164     int i, n = node->sample_count, vi = node->split->var_idx;
1165     double L, R;
1166
1167     assert( !node->split->inversed );
1168
1169     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1170     {
1171         const int* labels = data->get_cat_var_data(node,vi);
1172         const int* subset = node->split->subset;
1173
1174         if( !data->have_priors )
1175         {
1176             int sum = 0, sum_abs = 0;
1177
1178             for( i = 0; i < n; i++ )
1179             {
1180                 int idx = labels[i];
1181                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1182                 sum += d; sum_abs += d & 1;
1183                 dir[i] = (char)d;
1184             }
1185
1186             R = (sum_abs + sum) >> 1;
1187             L = (sum_abs - sum) >> 1;
1188         }
1189         else
1190         {
1191             const int* responses = data->get_class_labels(node);
1192             const double* priors = data->priors->data.db;
1193             double sum = 0, sum_abs = 0;
1194
1195             for( i = 0; i < n; i++ )
1196             {
1197                 int idx = labels[i];
1198                 double w = priors[responses[i]];
1199                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1200                 sum += d*w; sum_abs += (d & 1)*w;
1201                 dir[i] = (char)d;
1202             }
1203
1204             R = (sum_abs + sum) * 0.5;
1205             L = (sum_abs - sum) * 0.5;
1206         }
1207     }
1208     else // split on ordered var
1209     {
1210         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1211         int split_point = node->split->ord.split_point;
1212         int n1 = node->get_num_valid(vi);
1213
1214         assert( 0 <= split_point && split_point < n1-1 );
1215
1216         if( !data->have_priors )
1217         {
1218             for( i = 0; i <= split_point; i++ )
1219                 dir[sorted[i].i] = (char)-1;
1220             for( ; i < n1; i++ )
1221                 dir[sorted[i].i] = (char)1;
1222             for( ; i < n; i++ )
1223                 dir[sorted[i].i] = (char)0;
1224
1225             L = split_point-1;
1226             R = n1 - split_point + 1;
1227         }
1228         else
1229         {
1230             const int* responses = data->get_class_labels(node);
1231             const double* priors = data->priors->data.db;
1232             L = R = 0;
1233
1234             for( i = 0; i <= split_point; i++ )
1235             {
1236                 int idx = sorted[i].i;
1237                 double w = priors[responses[idx]];
1238                 dir[idx] = (char)-1;
1239                 L += w;
1240             }
1241
1242             for( ; i < n1; i++ )
1243             {
1244                 int idx = sorted[i].i;
1245                 double w = priors[responses[idx]];
1246                 dir[idx] = (char)1;
1247                 R += w;
1248             }
1249
1250             for( ; i < n; i++ )
1251                 dir[sorted[i].i] = (char)0;
1252         }
1253     }
1254
1255     node->maxlr = MAX( L, R );
1256     return node->split->quality/(L + R);
1257 }
1258
1259
1260 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1261 {
1262     int vi;
1263     CvDTreeSplit *best_split = 0, *split = 0, *t;
1264
1265     for( vi = 0; vi < data->var_count; vi++ )
1266     {
1267         int ci = data->get_var_type(vi);
1268         if( node->get_num_valid(vi) <= 1 )
1269             continue;
1270
1271         if( data->is_classifier )
1272         {
1273             if( ci >= 0 )
1274                 split = find_split_cat_class( node, vi );
1275             else
1276                 split = find_split_ord_class( node, vi );
1277         }
1278         else
1279         {
1280             if( ci >= 0 )
1281                 split = find_split_cat_reg( node, vi );
1282             else
1283                 split = find_split_ord_reg( node, vi );
1284         }
1285
1286         if( split )
1287         {
1288             if( !best_split || best_split->quality < split->quality )
1289                 CV_SWAP( best_split, split, t );
1290             if( split )
1291                 cvSetRemoveByPtr( data->split_heap, split );
1292         }
1293     }
1294
1295     return best_split;
1296 }
1297
1298
1299 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
1300 {
1301     const float epsilon = FLT_EPSILON*2;
1302     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1303     const int* responses = data->get_class_labels(node);
1304     int n = node->sample_count;
1305     int n1 = node->get_num_valid(vi);
1306     int m = data->get_num_classes();
1307     const int* rc0 = data->counts->data.i;
1308     int* lc = (int*)(rc0 + m);
1309     int* rc = lc + m;
1310     int i, best_i = -1;
1311     double lsum2 = 0, rsum2 = 0, best_val = 0;
1312     const double* priors = data->have_priors ? data->priors->data.db : 0;
1313
1314     // init arrays of class instance counters on both sides of the split
1315     for( i = 0; i < m; i++ )
1316     {
1317         lc[i] = 0;
1318         rc[i] = rc0[i];
1319     }
1320
1321     // compensate for missing values
1322     for( i = n1; i < n; i++ )
1323         rc[responses[sorted[i].i]]--;
1324
1325     if( !priors )
1326     {
1327         int L = 0, R = n1;
1328
1329         for( i = 0; i < m; i++ )
1330             rsum2 += (double)rc[i]*rc[i];
1331
1332         for( i = 0; i < n1 - 1; i++ )
1333         {
1334             int idx = responses[sorted[i].i];
1335             int lv, rv;
1336             L++; R--;
1337             lv = lc[idx]; rv = rc[idx];
1338             lsum2 += lv*2 + 1;
1339             rsum2 -= rv*2 - 1;
1340             lc[idx] = lv + 1; rc[idx] = rv - 1;
1341
1342             if( sorted[i].val + epsilon < sorted[i+1].val )
1343             {
1344                 double val = lsum2/L + rsum2/R;
1345                 if( best_val < val )
1346                 {
1347                     best_val = val;
1348                     best_i = i;
1349                 }
1350             }
1351         }
1352     }
1353     else
1354     {
1355         double L = 0, R = 0;
1356         for( i = 0; i < m; i++ )
1357         {
1358             double wv = rc[i]*priors[i];
1359             R += wv;
1360             rsum2 += wv*wv;
1361         }
1362
1363         for( i = 0; i < n1 - 1; i++ )
1364         {
1365             int idx = responses[sorted[i].i];
1366             int lv, rv;
1367             double p = priors[idx], p2 = p*p;
1368             L += p; R -= p;
1369             lv = lc[idx]; rv = rc[idx];
1370             lsum2 += p2*(lv*2 + 1);
1371             rsum2 -= p2*(rv*2 - 1);
1372             lc[idx] = lv + 1; rc[idx] = rv - 1;
1373
1374             if( sorted[i].val + epsilon < sorted[i+1].val )
1375             {
1376                 double val = lsum2/L + rsum2/R;
1377                 if( best_val < val )
1378                 {
1379                     best_val = val;
1380                     best_i = i;
1381                 }
1382             }
1383         }
1384     }
1385
1386     return best_i >= 0 ? data->new_split_ord( vi,
1387         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1388         0, (float)best_val ) : 0;
1389 }
1390
1391
1392 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1393                                 int* csums, int k, int* labels )
1394 {
1395     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
1396     int iters = 0, max_iters = 100;
1397     int i, j, idx;
1398     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1399     double *v_weights = buf, *c_weights = buf + k;
1400     bool modified = true;
1401     CvRNG* r = &data->rng;
1402
1403     // assign labels randomly
1404     for( i = idx = 0; i < n; i++ )
1405     {
1406         int sum = 0;
1407         const int* v = vectors + i*m;
1408         labels[i] = idx++;
1409         idx &= idx < k ? -1 : 0;
1410
1411         // compute weight of each vector
1412         for( j = 0; j < m; j++ )
1413             sum += v[j];
1414         v_weights[i] = sum ? 1./sum : 0.;
1415     }
1416
1417     for( i = 0; i < n; i++ )
1418     {
1419         int i1 = cvRandInt(r) % n;
1420         int i2 = cvRandInt(r) % n;
1421         CV_SWAP( labels[i1], labels[i2], j );
1422     }
1423
1424     for( iters = 0; iters <= max_iters; iters++ )
1425     {
1426         // calculate csums
1427         for( i = 0; i < k; i++ )
1428         {
1429             for( j = 0; j < m; j++ )
1430                 csums[i*m + j] = 0;
1431         }
1432
1433         for( i = 0; i < n; i++ )
1434         {
1435             const int* v = vectors + i*m;
1436             int* s = csums + labels[i]*m;
1437             for( j = 0; j < m; j++ )
1438                 s[j] += v[j];
1439         }
1440
1441         // exit the loop here, when we have up-to-date csums
1442         if( iters == max_iters || !modified )
1443             break;
1444
1445         modified = false;
1446
1447         // calculate weight of each cluster
1448         for( i = 0; i < k; i++ )
1449         {
1450             const int* s = csums + i*m;
1451             int sum = 0;
1452             for( j = 0; j < m; j++ )
1453                 sum += s[j];
1454             c_weights[i] = sum ? 1./sum : 0;
1455         }
1456
1457         // now for each vector determine the closest cluster
1458         for( i = 0; i < n; i++ )
1459         {
1460             const int* v = vectors + i*m;
1461             double alpha = v_weights[i];
1462             double min_dist2 = DBL_MAX;
1463             int min_idx = -1;
1464
1465             for( idx = 0; idx < k; idx++ )
1466             {
1467                 const int* s = csums + idx*m;
1468                 double dist2 = 0., beta = c_weights[idx];
1469                 for( j = 0; j < m; j++ )
1470                 {
1471                     double t = v[j]*alpha - s[j]*beta;
1472                     dist2 += t*t;
1473                 }
1474                 if( min_dist2 > dist2 )
1475                 {
1476                     min_dist2 = dist2;
1477                     min_idx = idx;
1478                 }
1479             }
1480
1481             if( min_idx != labels[i] )
1482                 modified = true;
1483             labels[i] = min_idx;
1484         }
1485     }
1486 }
1487
1488
1489 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1490 {
1491     CvDTreeSplit* split;
1492     const int* labels = data->get_cat_var_data(node, vi);
1493     const int* responses = data->get_class_labels(node);
1494     int ci = data->get_var_type(vi);
1495     int n = node->sample_count;
1496     int m = data->get_num_classes();
1497     int _mi = data->cat_count->data.i[ci], mi = _mi;
1498     const int* rc0 = data->counts->data.i;
1499     int* lc = (int*)(rc0 + m);
1500     int* rc = lc + m;
1501     int* _cjk = rc + m*2, *cjk = _cjk;
1502     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1503     int* cluster_labels = 0;
1504     int** int_ptr = 0;
1505     int i, j, k, idx;
1506     double L = 0, R = 0;
1507     double best_val = 0;
1508     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1509     const double* priors = data->priors->data.db;
1510
1511     // init array of counters:
1512     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1513     for( j = -1; j < mi; j++ )
1514         for( k = 0; k < m; k++ )
1515             cjk[j*m + k] = 0;
1516
1517     for( i = 0; i < n; i++ )
1518     {
1519         int j = labels[i];
1520         int k = responses[i];
1521         cjk[j*m + k]++;
1522     }
1523
1524     if( m > 2 )
1525     {
1526         if( mi > data->params.max_categories )
1527         {
1528             mi = MIN(data->params.max_categories, n);
1529             cjk += _mi*m;
1530             cluster_labels = cjk + mi*m;
1531             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1532         }
1533         subset_i = 1;
1534         subset_n = 1 << mi;
1535     }
1536     else
1537     {
1538         assert( m == 2 );
1539         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1540         for( j = 0; j < mi; j++ )
1541             int_ptr[j] = cjk + j*2 + 1;
1542         icvSortIntPtr( int_ptr, mi, 0 );
1543         subset_i = 0;
1544         subset_n = mi;
1545     }
1546
1547     for( k = 0; k < m; k++ )
1548     {
1549         int sum = 0;
1550         for( j = 0; j < mi; j++ )
1551             sum += cjk[j*m + k];
1552         rc[k] = sum;
1553         lc[k] = 0;
1554     }
1555
1556     for( j = 0; j < mi; j++ )
1557     {
1558         double sum = 0;
1559         for( k = 0; k < m; k++ )
1560             sum += cjk[j*m + k]*priors[k];
1561         c_weights[j] = sum;
1562         R += c_weights[j];
1563     }
1564
1565     for( ; subset_i < subset_n; subset_i++ )
1566     {
1567         double weight;
1568         int* crow;
1569         double lsum2 = 0, rsum2 = 0;
1570
1571         if( m == 2 )
1572             idx = (int)(int_ptr[subset_i] - cjk)/2;
1573         else
1574         {
1575             int graycode = (subset_i>>1)^subset_i;
1576             int diff = graycode ^ prevcode;
1577
1578             // determine index of the changed bit.
1579             Cv32suf u;
1580             idx = diff >= (1 << 16) ? 16 : 0;
1581             u.f = (float)(((diff >> 16) | diff) & 65535);
1582             idx += (u.i >> 23) - 127;
1583             subtract = graycode < prevcode;
1584             prevcode = graycode;
1585         }
1586
1587         crow = cjk + idx*m;
1588         weight = c_weights[idx];
1589         if( weight < FLT_EPSILON )
1590             continue;
1591
1592         if( !subtract )
1593         {
1594             for( k = 0; k < m; k++ )
1595             {
1596                 int t = crow[k];
1597                 int lval = lc[k] + t;
1598                 int rval = rc[k] - t;
1599                 double p = priors[k], p2 = p*p;
1600                 lsum2 += p2*lval*lval;
1601                 rsum2 += p2*rval*rval;
1602                 lc[k] = lval; rc[k] = rval;
1603             }
1604             L += weight;
1605             R -= weight;
1606         }
1607         else
1608         {
1609             for( k = 0; k < m; k++ )
1610             {
1611                 int t = crow[k];
1612                 int lval = lc[k] - t;
1613                 int rval = rc[k] + t;
1614                 double p = priors[k], p2 = p*p;
1615                 lsum2 += p2*lval*lval;
1616                 rsum2 += p2*rval*rval;
1617                 lc[k] = lval; rc[k] = rval;
1618             }
1619             L -= weight;
1620             R += weight;
1621         }
1622
1623         if( L > FLT_EPSILON && R > FLT_EPSILON )
1624         {
1625             double val = lsum2/L + rsum2/R;
1626             if( best_val < val )
1627             {
1628                 best_val = val;
1629                 best_subset = subset_i;
1630             }
1631         }
1632     }
1633
1634     if( best_subset < 0 )
1635         return 0;
1636
1637     split = data->new_split_cat( vi, (float)best_val );
1638
1639     if( m == 2 )
1640     {
1641         for( i = 0; i <= best_subset; i++ )
1642         {
1643             idx = (int)(int_ptr[i] - cjk) >> 1;
1644             split->subset[idx >> 5] |= 1 << (idx & 31);
1645         }
1646     }
1647     else
1648     {
1649         for( i = 0; i < _mi; i++ )
1650         {
1651             idx = cluster_labels ? cluster_labels[i] : i;
1652             if( best_subset & (1 << idx) )
1653                 split->subset[i >> 5] |= 1 << (i & 31);
1654         }
1655     }
1656
1657     return split;
1658 }
1659
1660
1661 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1662 {
1663     const float epsilon = FLT_EPSILON*2;
1664     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1665     const float* responses = data->get_ord_responses(node);
1666     int n = node->sample_count;
1667     int n1 = node->get_num_valid(vi);
1668     int i, best_i = -1;
1669     double best_val = 0, lsum = 0, rsum = node->value*n;
1670     int L = 0, R = n1;
1671
1672     // compensate for missing values
1673     for( i = n1; i < n; i++ )
1674         rsum -= responses[sorted[i].i];
1675
1676     // find the optimal split
1677     for( i = 0; i < n1 - 1; i++ )
1678     {
1679         float val = responses[sorted[i].i];
1680         L++; R--;
1681         lsum += val;
1682         rsum -= val;
1683
1684         if( sorted[i].val + epsilon < sorted[i+1].val )
1685         {
1686             double val = lsum*lsum/L + rsum*rsum/R;
1687             if( best_val < val )
1688             {
1689                 best_val = val;
1690                 best_i = i;
1691             }
1692         }
1693     }
1694
1695     return best_i >= 0 ? data->new_split_ord( vi,
1696         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1697         0, (float)best_val ) : 0;
1698 }
1699
1700
1701 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1702 {
1703     CvDTreeSplit* split;
1704     const int* labels = data->get_cat_var_data(node, vi);
1705     const float* responses = data->get_ord_responses(node);
1706     int ci = data->get_var_type(vi);
1707     int n = node->sample_count;
1708     int mi = data->cat_count->data.i[ci];
1709     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1710     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1711     double** sum_ptr = 0;
1712     int i, L = 0, R = 0;
1713     double best_val = 0, lsum = 0, rsum = 0;
1714     int best_subset = -1, subset_i;
1715
1716     for( i = -1; i < mi; i++ )
1717         sum[i] = counts[i] = 0;
1718
1719     // calculate sum response and weight of each category of the input var
1720     for( i = 0; i < n; i++ )
1721     {
1722         int idx = labels[i];
1723         double s = sum[idx] + responses[i];
1724         int nc = counts[idx] + 1;
1725         sum[idx] = s;
1726         counts[idx] = nc;
1727     }
1728
1729     // calculate average response in each category
1730     for( i = 0; i < mi; i++ )
1731     {
1732         R += counts[i];
1733         rsum += sum[i];
1734         sum[i] /= MAX(counts[i],1);
1735         sum_ptr[i] = sum + i;
1736     }
1737
1738     icvSortDblPtr( sum_ptr, mi, 0 );
1739
1740     // revert back to unnormalized sum
1741     // (there should be a very little loss of accuracy)
1742     for( i = 0; i < mi; i++ )
1743         sum[i] *= counts[i];
1744
1745     for( subset_i = 0; subset_i < mi-1; subset_i++ )
1746     {
1747         int idx = (int)(sum_ptr[subset_i] - sum);
1748         int ni = counts[idx];
1749
1750         if( ni )
1751         {
1752             double s = sum[idx];
1753             lsum += s; L += ni;
1754             rsum -= s; R -= ni;
1755             
1756             if( L && R )
1757             {
1758                 double val = lsum*lsum/L + rsum*rsum/R;
1759                 if( best_val < val )
1760                 {
1761                     best_val = val;
1762                     best_subset = subset_i;
1763                 }
1764             }
1765         }
1766     }
1767
1768     if( best_subset < 0 )
1769         return 0;
1770
1771     split = data->new_split_cat( vi, (float)best_val );
1772     for( i = 0; i <= best_subset; i++ )
1773     {
1774         int idx = (int)(sum_ptr[i] - sum);
1775         split->subset[idx >> 5] |= 1 << (idx & 31);
1776     }
1777
1778     return split;
1779 }
1780
1781
1782 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
1783 {
1784     const float epsilon = FLT_EPSILON*2;
1785     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1786     const char* dir = (char*)data->direction->data.ptr;
1787     int n1 = node->get_num_valid(vi);
1788     // LL - number of samples that both the primary and the surrogate splits send to the left
1789     // LR - ... primary split sends to the left and the surrogate split sends to the right
1790     // RL - ... primary split sends to the right and the surrogate split sends to the left
1791     // RR - ... both send to the right
1792     int i, best_i = -1, best_inversed = 0;
1793     double best_val; 
1794
1795     if( !data->have_priors )
1796     {
1797         int LL = 0, RL = 0, LR, RR;
1798         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
1799         int sum = 0, sum_abs = 0;
1800         
1801         for( i = 0; i < n1; i++ )
1802         {
1803             int d = dir[sorted[i].i];
1804             sum += d; sum_abs += d & 1;
1805         }
1806
1807         // sum_abs = R + L; sum = R - L
1808         RR = (sum_abs + sum) >> 1;
1809         LR = (sum_abs - sum) >> 1;
1810
1811         // initially all the samples are sent to the right by the surrogate split,
1812         // LR of them are sent to the left by primary split, and RR - to the right.
1813         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1814         for( i = 0; i < n1 - 1; i++ )
1815         {
1816             int d = dir[sorted[i].i];
1817
1818             if( d < 0 )
1819             {
1820                 LL++; LR--;
1821                 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1822                 {
1823                     best_val = LL + RR;
1824                     best_i = i; best_inversed = 0;
1825                 }
1826             }
1827             else if( d > 0 )
1828             {
1829                 RL++; RR--;
1830                 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1831                 {
1832                     best_val = RL + LR;
1833                     best_i = i; best_inversed = 1;
1834                 }
1835             }
1836         }
1837         best_val = _best_val;
1838     }
1839     else
1840     {
1841         double LL = 0, RL = 0, LR, RR;
1842         double worst_val = node->maxlr;
1843         double sum = 0, sum_abs = 0;
1844         const double* priors = data->priors->data.db;
1845         const int* responses = data->get_class_labels(node);
1846         best_val = worst_val;
1847         
1848         for( i = 0; i < n1; i++ )
1849         {
1850             int idx = sorted[i].i;
1851             double w = priors[responses[idx]];
1852             int d = dir[idx];
1853             sum += d*w; sum_abs += (d & 1)*w;
1854         }
1855
1856         // sum_abs = R + L; sum = R - L
1857         RR = (sum_abs + sum)*0.5;
1858         LR = (sum_abs - sum)*0.5;
1859
1860         // initially all the samples are sent to the right by the surrogate split,
1861         // LR of them are sent to the left by primary split, and RR - to the right.
1862         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1863         for( i = 0; i < n1 - 1; i++ )
1864         {
1865             int idx = sorted[i].i;
1866             double w = priors[responses[idx]];
1867             int d = dir[idx];
1868
1869             if( d < 0 )
1870             {
1871                 LL += w; LR -= w;
1872                 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1873                 {
1874                     best_val = LL + RR;
1875                     best_i = i; best_inversed = 0;
1876                 }
1877             }
1878             else if( d > 0 )
1879             {
1880                 RL += w; RR -= w;
1881                 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1882                 {
1883                     best_val = RL + LR;
1884                     best_i = i; best_inversed = 1;
1885                 }
1886             }
1887         }
1888     }
1889
1890     return best_i >= 0 ? data->new_split_ord( vi,
1891         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1892         best_inversed, (float)best_val ) : 0;
1893 }
1894
1895
1896 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
1897 {
1898     const int* labels = data->get_cat_var_data(node, vi);
1899     const char* dir = (char*)data->direction->data.ptr;
1900     int n = node->sample_count;
1901     // LL - number of samples that both the primary and the surrogate splits send to the left
1902     // LR - ... primary split sends to the left and the surrogate split sends to the right
1903     // RL - ... primary split sends to the right and the surrogate split sends to the left
1904     // RR - ... both send to the right
1905     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
1906     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
1907     double best_val = 0;
1908     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
1909     double* rc = lc + mi + 1;
1910     
1911     for( i = -1; i < mi; i++ )
1912         lc[i] = rc[i] = 0;
1913
1914     // for each category calculate the weight of samples
1915     // sent to the left (lc) and to the right (rc) by the primary split
1916     if( !data->have_priors )
1917     {
1918         int* _lc = data->counts->data.i + 1;
1919         int* _rc = _lc + mi + 1;
1920
1921         for( i = -1; i < mi; i++ )
1922             _lc[i] = _rc[i] = 0;
1923
1924         for( i = 0; i < n; i++ )
1925         {
1926             int idx = labels[i];
1927             int d = dir[i];
1928             int sum = _lc[idx] + d;
1929             int sum_abs = _rc[idx] + (d & 1);
1930             _lc[idx] = sum; _rc[idx] = sum_abs;
1931         }
1932
1933         for( i = 0; i < mi; i++ )
1934         {
1935             int sum = _lc[i];
1936             int sum_abs = _rc[i];
1937             lc[i] = (sum_abs - sum) >> 1;
1938             rc[i] = (sum_abs + sum) >> 1;
1939         }
1940     }
1941     else
1942     {
1943         const double* priors = data->priors->data.db;
1944         const int* responses = data->get_class_labels(node);
1945
1946         for( i = 0; i < n; i++ )
1947         {
1948             int idx = labels[i];
1949             double w = priors[responses[i]];
1950             int d = dir[i];
1951             double sum = lc[idx] + d*w;
1952             double sum_abs = rc[idx] + (d & 1)*w;
1953             lc[idx] = sum; rc[idx] = sum_abs;
1954         }
1955
1956         for( i = 0; i < mi; i++ )
1957         {
1958             double sum = lc[i];
1959             double sum_abs = rc[i];
1960             lc[i] = (sum_abs - sum) * 0.5;
1961             rc[i] = (sum_abs + sum) * 0.5;
1962         }
1963     }
1964
1965     // 2. now form the split.
1966     // in each category send all the samples to the same direction as majority
1967     for( i = 0; i < mi; i++ )
1968     {
1969         double lval = lc[i], rval = rc[i];
1970         if( lval > rval )
1971         {
1972             split->subset[i >> 5] |= 1 << (i & 31);
1973             best_val += lval;
1974         }
1975         else
1976             best_val += rval;
1977     }
1978
1979     split->quality = (float)best_val;
1980     if( split->quality <= node->maxlr )
1981         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
1982
1983     return split;
1984 }
1985
1986
1987 void CvDTree::calc_node_value( CvDTreeNode* node )
1988 {
1989     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
1990     const int* cv_labels = data->get_cv_labels(node);
1991
1992     if( data->is_classifier )
1993     {
1994         // in case of classification tree:
1995         //  * node value is the label of the class that has the largest weight in the node.
1996         //  * node risk is the weighted number of misclassified samples,
1997         //  * j-th cross-validation fold value and risk are calculated as above,
1998         //    but using the samples with cv_labels(*)!=j.
1999         //  * j-th cross-validation fold error is calculated as the weighted number of
2000         //    misclassified samples with cv_labels(*)==j.
2001
2002         // compute the number of instances of each class
2003         int* cls_count = data->counts->data.i;
2004         const int* responses = data->get_class_labels(node);
2005         int m = data->get_num_classes();
2006         int* cv_cls_count = cls_count + m;
2007         double max_val = -1, total_weight = 0;
2008         int max_k = -1;
2009         double* priors = data->priors->data.db;
2010
2011         for( k = 0; k < m; k++ )
2012             cls_count[k] = 0;
2013
2014         if( cv_n == 0 )
2015         {
2016             for( i = 0; i < n; i++ )
2017                 cls_count[responses[i]]++;
2018         }
2019         else
2020         {
2021             for( j = 0; j < cv_n; j++ )
2022                 for( k = 0; k < m; k++ )
2023                     cv_cls_count[j*m + k] = 0;
2024
2025             for( i = 0; i < n; i++ )
2026             {
2027                 j = cv_labels[i]; k = responses[i];
2028                 cv_cls_count[j*m + k]++;
2029             }
2030
2031             for( j = 0; j < cv_n; j++ )
2032                 for( k = 0; k < m; k++ )
2033                     cls_count[k] += cv_cls_count[j*m + k];
2034         }
2035
2036         for( k = 0; k < m; k++ )
2037         {
2038             double val = cls_count[k]*priors[k];
2039             total_weight += val;
2040             if( max_val < val )
2041             {
2042                 max_val = val;
2043                 max_k = k;
2044             }
2045         }
2046
2047         node->class_idx = max_k;
2048         node->value = data->cat_map->data.i[
2049             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2050         node->node_risk = total_weight - max_val;
2051
2052         for( j = 0; j < cv_n; j++ )
2053         {
2054             double sum_k = 0, sum = 0, max_val_k = 0;
2055             max_val = -1; max_k = -1;
2056
2057             for( k = 0; k < m; k++ )
2058             {
2059                 double w = priors[k];
2060                 double val_k = cv_cls_count[j*m + k]*w;
2061                 double val = cls_count[k]*w - val_k;
2062                 sum_k += val_k;
2063                 sum += val;
2064                 if( max_val < val )
2065                 {
2066                     max_val = val;
2067                     max_val_k = val_k;
2068                     max_k = k;
2069                 }
2070             }
2071
2072             node->cv_Tn[j] = INT_MAX;
2073             node->cv_node_risk[j] = sum - max_val;
2074             node->cv_node_error[j] = sum_k - max_val_k;
2075         }
2076     }
2077     else
2078     {
2079         // in case of regression tree:
2080         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2081         //    n is the number of samples in the node.
2082         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2083         //  * j-th cross-validation fold value and risk are calculated as above,
2084         //    but using the samples with cv_labels(*)!=j.
2085         //  * j-th cross-validation fold error is calculated
2086         //    using samples with cv_labels(*)==j as the test subset:
2087         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2088         //    where node_value_j is the node value calculated
2089         //    as described in the previous bullet, and summation is done
2090         //    over the samples with cv_labels(*)==j.
2091
2092         double sum = 0, sum2 = 0;
2093         const float* values = data->get_ord_responses(node);
2094         double *cv_sum = 0, *cv_sum2 = 0;
2095         int* cv_count = 0;
2096         
2097         if( cv_n == 0 )
2098         {
2099             // if cross-validation is not used, we even do not compute node_risk
2100             // (so the tree sequence T1>...>{root} may not be built).
2101             for( i = 0; i < n; i++ )
2102                 sum += values[i];
2103         }
2104         else
2105         {
2106             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2107             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2108             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2109
2110             for( j = 0; j < cv_n; j++ )
2111             {
2112                 cv_sum[j] = cv_sum2[j] = 0.;
2113                 cv_count[j] = 0;
2114             }
2115
2116             for( i = 0; i < n; i++ )
2117             {
2118                 j = cv_labels[i];
2119                 double t = values[i];
2120                 double s = cv_sum[j] + t;
2121                 double s2 = cv_sum2[j] + t*t;
2122                 int nc = cv_count[j] + 1;
2123                 cv_sum[j] = s;
2124                 cv_sum2[j] = s2;
2125                 cv_count[j] = nc;
2126             }
2127
2128             for( j = 0; j < cv_n; j++ )
2129             {
2130                 sum += cv_sum[j];
2131                 sum2 += cv_sum2[j];
2132             }
2133
2134             node->node_risk = sum2 - (sum/n)*sum;
2135         }
2136
2137         node->value = sum/n;
2138
2139         for( j = 0; j < cv_n; j++ )
2140         {
2141             double s = cv_sum[j], si = sum - s;
2142             double s2 = cv_sum2[j], s2i = sum2 - s2;
2143             int c = cv_count[j], ci = n - c;
2144             double r = si/MAX(ci,1);
2145             node->cv_node_risk[j] = s2i - r*r*ci;
2146             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2147             node->cv_Tn[j] = INT_MAX;
2148         }
2149     }
2150 }
2151
2152
2153 void CvDTree::complete_node_dir( CvDTreeNode* node )
2154 {
2155     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2156     int nz = n - node->get_num_valid(node->split->var_idx);
2157     char* dir = (char*)data->direction->data.ptr;
2158
2159     // try to complete direction using surrogate splits
2160     if( nz && data->params.use_surrogates )
2161     {
2162         CvDTreeSplit* split = node->split->next;
2163         for( ; split != 0 && nz; split = split->next )
2164         {
2165             int inversed_mask = split->inversed ? -1 : 0;
2166             vi = split->var_idx;
2167
2168             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2169             {
2170                 const int* labels = data->get_cat_var_data(node, vi);
2171                 const int* subset = split->subset;
2172
2173                 for( i = 0; i < n; i++ )
2174                 {
2175                     int idx;
2176                     if( !dir[i] && (idx = labels[i]) >= 0 )
2177                     {
2178                         int d = DTREE_CAT_DIR(idx,subset);
2179                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2180                         if( --nz )
2181                             break;
2182                     }
2183                 }
2184             }
2185             else // split on ordered var
2186             {
2187                 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2188                 int split_point = split->ord.split_point;
2189                 int n1 = node->get_num_valid(vi);
2190
2191                 assert( 0 <= split_point && split_point < n-1 );
2192
2193                 for( i = 0; i < n1; i++ )
2194                 {
2195                     int idx = sorted[i].i;
2196                     if( !dir[idx] )
2197                     {
2198                         int d = i <= split_point ? -1 : 1;
2199                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2200                         if( --nz )
2201                             break;
2202                     }
2203                 }
2204             }
2205         }
2206     }
2207
2208     // find the default direction for the rest
2209     if( nz )
2210     {
2211         for( i = nr = 0; i < n; i++ )
2212             nr += dir[i] > 0;
2213         nl = n - nr - nz;
2214         d0 = nl > nr ? -1 : nr > nl;
2215     }
2216
2217     // make sure that every sample is directed either to the left or to the right
2218     for( i = 0; i < n; i++ )
2219     {
2220         int d = dir[i];
2221         if( !d )
2222         {
2223             d = d0;
2224             if( !d )
2225                 d = d1, d1 = -d1;
2226         }
2227         d = d > 0;
2228         dir[i] = (char)d; // remap (-1,1) to (0,1)
2229     }
2230 }
2231
2232
2233 void CvDTree::split_node_data( CvDTreeNode* node )
2234 {
2235     int vi, i, n = node->sample_count, nl, nr;
2236     char* dir = (char*)data->direction->data.ptr;
2237     CvDTreeNode *left = 0, *right = 0;
2238     int* new_idx = data->split_buf->data.i;
2239     int new_buf_idx = data->get_child_buf_idx( node );
2240
2241     complete_node_dir(node);
2242
2243     for( i = nl = nr = 0; i < n; i++ )
2244     {
2245         int d = dir[i];
2246         // initialize new indices for splitting ordered variables
2247         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2248         nr += d;
2249         nl += d^1;
2250     }
2251
2252     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2253     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2254         (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
2255
2256     // split ordered variables, keep both halves sorted.
2257     for( vi = 0; vi < data->var_count; vi++ )
2258     {
2259         int ci = data->get_var_type(vi);
2260         int n1 = node->get_num_valid(vi);
2261         CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2262         CvPair32s32f tl, tr;
2263
2264         if( ci >= 0 )
2265             continue;
2266
2267         src = data->get_ord_var_data(node, vi);
2268         ldst0 = ldst = data->get_ord_var_data(left, vi);
2269         rdst0 = rdst = data->get_ord_var_data(right, vi);
2270         tl = ldst0[nl]; tr = rdst0[nr];
2271
2272         // split sorted
2273         for( i = 0; i < n1; i++ )
2274         {
2275             int idx = src[i].i;
2276             float val = src[i].val;
2277             int d = dir[idx];
2278             idx = new_idx[idx];
2279             ldst->i = rdst->i = idx;
2280             ldst->val = rdst->val = val;
2281             ldst += d^1;
2282             rdst += d;
2283         }
2284
2285         left->set_num_valid(vi, (int)(ldst - ldst0));
2286         right->set_num_valid(vi, (int)(rdst - rdst0));
2287
2288         // split missing
2289         for( ; i < n; i++ )
2290         {
2291             int idx = src[i].i;
2292             int d = dir[idx];
2293             idx = new_idx[idx];
2294             ldst->i = rdst->i = idx;
2295             ldst->val = rdst->val = ord_nan;
2296             ldst += d^1;
2297             rdst += d;
2298         }
2299
2300         ldst0[nl] = tl; rdst0[nr] = tr;
2301     }
2302
2303     // split categorical vars, responses and cv_labels using new_idx relocation table
2304     for( vi = 0; vi <= data->var_count + data->have_cv_labels + data->have_weights; vi++ )
2305     {
2306         int ci = data->get_var_type(vi);
2307         int n1 = node->get_num_valid(vi), nr1 = 0;
2308         int *src, *ldst0, *rdst0, *ldst, *rdst;
2309         int tl, tr;
2310
2311         if( ci < 0 )
2312             continue;
2313
2314         src = data->get_cat_var_data(node, vi);
2315         ldst0 = ldst = data->get_cat_var_data(left, vi);
2316         rdst0 = rdst = data->get_cat_var_data(right, vi);
2317         tl = ldst0[nl]; tr = rdst0[nr];
2318
2319         for( i = 0; i < n; i++ )
2320         {
2321             int d = dir[i];
2322             int val = src[i];
2323             *ldst = *rdst = val;
2324             ldst += d^1;
2325             rdst += d;
2326             nr1 += (val >= 0)&d;
2327         }
2328
2329         if( vi < data->var_count )
2330         {
2331             left->set_num_valid(vi, n1 - nr1);
2332             right->set_num_valid(vi, nr1);
2333         }
2334
2335         ldst0[nl] = tl; rdst0[nr] = tr;
2336     }
2337
2338     // deallocate the parent node data that is not needed anymore
2339     data->free_node_data(node);
2340 }
2341
2342
2343 void CvDTree::prune_cv()
2344 {
2345     CvMat* ab = 0;
2346     CvMat* temp = 0;
2347     CvMat* err_jk = 0;
2348     
2349     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2350     // 2. choose the best tree index (if need, apply 1SE rule).
2351     // 3. store the best index and cut the branches.
2352
2353     CV_FUNCNAME( "CvDTree::prune_cv" );
2354
2355     __BEGIN__;
2356
2357     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2358     // currently, 1SE for regression is not implemented
2359     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2360     double* err;
2361     double min_err = 0, min_err_se = 0;
2362     int min_idx = -1;
2363     
2364     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2365
2366     // build the main tree sequence, calculate alpha's
2367     for(;;tree_count++)
2368     {
2369         double min_alpha = update_tree_rnc(tree_count, -1);
2370         if( cut_tree(tree_count, -1, min_alpha) )
2371             break;
2372
2373         if( ab->cols <= tree_count )
2374         {
2375             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2376             for( ti = 0; ti < ab->cols; ti++ )
2377                 temp->data.db[ti] = ab->data.db[ti];
2378             cvReleaseMat( &ab );
2379             ab = temp;
2380             temp = 0;
2381         }
2382
2383         ab->data.db[tree_count] = min_alpha;
2384     }
2385
2386     ab->data.db[0] = 0.;
2387     for( ti = 1; ti < tree_count-1; ti++ )
2388         ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2389     ab->data.db[tree_count-1] = DBL_MAX*0.5;
2390
2391     CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2392     err = err_jk->data.db;
2393
2394     for( j = 0; j < cv_n; j++ )
2395     {
2396         int tj = 0, tk = 0;
2397         for( ; tk < tree_count; tj++ )
2398         {
2399             double min_alpha = update_tree_rnc(tj, j);
2400             if( cut_tree(tj, j, min_alpha) )
2401                 min_alpha = DBL_MAX;
2402
2403             for( ; tk < tree_count; tk++ )
2404             {
2405                 if( ab->data.db[tk] > min_alpha )
2406                     break;
2407                 err[j*tree_count + tk] = root->tree_error;
2408             }
2409         }
2410     }
2411
2412     for( ti = 0; ti < tree_count; ti++ )
2413     {
2414         double sum_err = 0;
2415         for( j = 0; j < cv_n; j++ )
2416             sum_err += err[j*tree_count + ti];
2417         if( ti == 0 || sum_err < min_err )
2418         {
2419             min_err = sum_err;
2420             min_idx = ti;
2421             if( use_1se )
2422                 min_err_se = sqrt( sum_err*(n - sum_err) );
2423         }
2424         else if( sum_err < min_err + min_err_se )
2425             min_idx = ti;
2426     }
2427
2428     pruned_tree_idx = min_idx;
2429     free_prune_data(data->params.truncate_pruned_tree != 0);
2430
2431     __END__;
2432
2433     cvReleaseMat( &err_jk );
2434     cvReleaseMat( &ab );
2435     cvReleaseMat( &temp );
2436 }
2437
2438
2439 double CvDTree::update_tree_rnc( int T, int fold )
2440 {
2441     CvDTreeNode* node = root;
2442     double min_alpha = DBL_MAX;
2443     
2444     for(;;)
2445     {
2446         CvDTreeNode* parent;
2447         for(;;)
2448         {
2449             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2450             if( t <= T || !node->left )
2451             {
2452                 node->complexity = 1;
2453                 node->tree_risk = node->node_risk;
2454                 node->tree_error = 0.;
2455                 if( fold >= 0 )
2456                 {
2457                     node->tree_risk = node->cv_node_risk[fold];
2458                     node->tree_error = node->cv_node_error[fold];
2459                 }
2460                 break;
2461             }
2462             node = node->left;
2463         }
2464         
2465         for( parent = node->parent; parent && parent->right == node;
2466             node = parent, parent = parent->parent )
2467         {
2468             parent->complexity += node->complexity;
2469             parent->tree_risk += node->tree_risk;
2470             parent->tree_error += node->tree_error;
2471
2472             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2473                 - parent->tree_risk)/(parent->complexity - 1);
2474             min_alpha = MIN( min_alpha, parent->alpha );
2475         }
2476
2477         if( !parent )
2478             break;
2479
2480         parent->complexity = node->complexity;
2481         parent->tree_risk = node->tree_risk;
2482         parent->tree_error = node->tree_error;
2483         node = parent->right;
2484     }
2485
2486     return min_alpha;
2487 }
2488
2489
2490 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2491 {
2492     CvDTreeNode* node = root;
2493     if( !node->left )
2494         return 1;
2495
2496     for(;;)
2497     {
2498         CvDTreeNode* parent;
2499         for(;;)
2500         {
2501             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2502             if( t <= T || !node->left )
2503                 break;
2504             if( node->alpha <= min_alpha + FLT_EPSILON )
2505             {
2506                 if( fold >= 0 )
2507                     node->cv_Tn[fold] = T;
2508                 else
2509                     node->Tn = T;
2510                 if( node == root )
2511                     return 1;
2512                 break;
2513             }
2514             node = node->left;
2515         }
2516         
2517         for( parent = node->parent; parent && parent->right == node;
2518             node = parent, parent = parent->parent )
2519             ;
2520
2521         if( !parent )
2522             break;
2523
2524         node = parent->right;
2525     }
2526
2527     return 0;
2528 }
2529
2530
2531 void CvDTree::free_prune_data(bool cut_tree)
2532 {
2533     CvDTreeNode* node = root;
2534     
2535     for(;;)
2536     {
2537         CvDTreeNode* parent;
2538         for(;;)
2539         {
2540             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2541             // as we will clear the whole cross-validation heap at the end
2542             node->cv_Tn = 0;
2543             node->cv_node_error = node->cv_node_risk = 0;
2544             if( !node->left )
2545                 break;
2546             node = node->left;
2547         }
2548         
2549         for( parent = node->parent; parent && parent->right == node;
2550             node = parent, parent = parent->parent )
2551         {
2552             if( cut_tree && parent->Tn <= pruned_tree_idx )
2553             {
2554                 data->free_node( parent->left );
2555                 data->free_node( parent->right );
2556                 parent->left = parent->right = 0;
2557             }
2558         }
2559
2560         if( !parent )
2561             break;
2562
2563         node = parent->right;
2564     }
2565
2566     if( data->cv_heap )
2567         cvClearSet( data->cv_heap );
2568 }
2569
2570
2571 void CvDTree::free_tree()
2572 {
2573     if( root && data && data->shared )
2574     {
2575         pruned_tree_idx = INT_MIN;
2576         free_prune_data(true);
2577         data->free_node(root);
2578         root = 0;
2579     }
2580 }
2581
2582
2583 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2584     const CvMat* _missing, bool preprocessed_input ) const
2585 {
2586     CvDTreeNode* result = 0;
2587     int* catbuf = 0;
2588
2589     CV_FUNCNAME( "CvDTree::predict" );
2590
2591     __BEGIN__;
2592
2593     int i, step, mstep = 0;
2594     const float* sample;
2595     const uchar* m = 0;
2596     CvDTreeNode* node = root;
2597     const int* vtype;
2598     const int* vidx;
2599     const int* cmap;
2600     const int* cofs;
2601
2602     if( !node )
2603         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2604
2605     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2606         _sample->cols != 1 && _sample->rows != 1 ||
2607         _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2608         _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2609             CV_ERROR( CV_StsBadArg,
2610         "the input sample must be 1d floating-point vector with the same "
2611         "number of elements as the total number of variables used for training" );
2612
2613     sample = _sample->data.fl;
2614     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
2615
2616     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2617     {
2618         int n = data->cat_count->cols;
2619         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2620         for( i = 0; i < n; i++ )
2621             catbuf[i] = -1;
2622     }
2623
2624     if( _missing )
2625     {
2626         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2627         !CV_ARE_SIZES_EQ(_missing, _sample) )
2628             CV_ERROR( CV_StsBadArg,
2629         "the missing data mask must be 8-bit vector of the same size as input sample" );
2630         m = _missing->data.ptr;
2631         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2632     }
2633
2634     vtype = data->var_type->data.i;
2635     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2636     cmap = data->cat_map->data.i;
2637     cofs = data->cat_ofs->data.i;
2638
2639     while( node->Tn > pruned_tree_idx && node->left )
2640     {
2641         CvDTreeSplit* split = node->split;
2642         int dir = 0;
2643         for( ; !dir && split != 0; split = split->next )
2644         {
2645             int vi = split->var_idx;
2646             int ci = vtype[vi];
2647             int i = vidx ? vidx[vi] : vi;
2648             float val = sample[i*step];
2649             if( m && m[i*mstep] )
2650                 continue;
2651             if( ci < 0 ) // ordered
2652                 dir = val <= split->ord.c ? -1 : 1;
2653             else // categorical
2654             {
2655                 int c;
2656                 if( preprocessed_input )
2657                     c = cvRound(val);
2658                 else
2659                 {
2660                     c = catbuf[ci];
2661                     if( c < 0 )
2662                     {
2663                         int a = c = cofs[ci];
2664                         int b = cofs[ci+1];
2665                         int ival = cvRound(val);
2666                         if( ival != val )
2667                             CV_ERROR( CV_StsBadArg,
2668                             "one of input categorical variable is not an integer" );
2669
2670                         while( a < b )
2671                         {
2672                             c = (a + b) >> 1;
2673                             if( ival < cmap[c] )
2674                                 b = c;
2675                             else if( ival > cmap[c] )
2676                                 a = c+1;
2677                             else
2678                                 break;
2679                         }
2680
2681                         if( c < 0 || ival != cmap[c] )
2682                             continue;
2683
2684                         catbuf[ci] = c -= cofs[ci];
2685                     }
2686                 }
2687                 dir = DTREE_CAT_DIR(c, split->subset);
2688             }
2689
2690             if( split->inversed )
2691                 dir = -dir;
2692         }
2693
2694         if( !dir )
2695         {
2696             double diff = node->right->sample_count - node->left->sample_count;
2697             dir = diff < 0 ? -1 : 1;
2698         }
2699         node = dir < 0 ? node->left : node->right;
2700     }
2701
2702     result = node;
2703
2704     __END__;
2705
2706     return result;
2707 }
2708
2709
2710 const CvMat* CvDTree::get_var_importance()
2711 {
2712     if( !var_importance )
2713     {
2714         CvDTreeNode* node = root;
2715         double* importance;
2716         if( !node )
2717             return 0;
2718         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
2719         cvZero( var_importance );
2720         importance = var_importance->data.db;
2721
2722         for(;;)
2723         {
2724             CvDTreeNode* parent;
2725             for( ;; node = node->left )
2726             {
2727                 CvDTreeSplit* split = node->split;
2728                 
2729                 if( !node->left || node->Tn <= pruned_tree_idx )
2730                     break;
2731                 
2732                 for( ; split != 0; split = split->next )
2733                     importance[split->var_idx] += split->quality;
2734             }
2735
2736             for( parent = node->parent; parent && parent->right == node;
2737                 node = parent, parent = parent->parent )
2738                 ;
2739
2740             if( !parent )
2741                 break;
2742
2743             node = parent->right;
2744         }
2745     }
2746
2747     return var_importance;
2748 }
2749
2750
2751 void CvDTree::save( const char* filename, const char* name )
2752 {
2753     CvFileStorage* fs = 0;
2754     
2755     CV_FUNCNAME( "CvDTree::save" );
2756
2757     __BEGIN__;
2758
2759     CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_WRITE ));
2760     if( !fs )
2761         CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2762
2763     write( fs, name ? name : "my_dtree" );
2764
2765     __END__;
2766
2767     cvReleaseFileStorage( &fs );
2768 }
2769
2770
2771 void CvDTree::write_train_data_params( CvFileStorage* fs )
2772 {
2773     CV_FUNCNAME( "CvDTree::write_train_data_params" );
2774
2775     __BEGIN__;
2776
2777     int vi, vcount = data->var_count;
2778
2779     cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
2780     cvWriteInt( fs, "var_all", data->var_all );
2781     cvWriteInt( fs, "var_count", data->var_count );
2782     cvWriteInt( fs, "ord_var_count", data->ord_var_count );
2783     cvWriteInt( fs, "cat_var_count", data->cat_var_count );
2784
2785     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
2786     cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
2787
2788     if( data->is_classifier )
2789     {
2790         cvWriteInt( fs, "max_categories", data->params.max_categories );
2791     }
2792     else
2793     {
2794         cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
2795     }
2796
2797     cvWriteInt( fs, "max_depth", data->params.max_depth );
2798     cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
2799     cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
2800     
2801     if( data->params.cv_folds > 1 )
2802     {
2803         cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
2804         cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
2805     }
2806
2807     if( data->priors )
2808         cvWrite( fs, "priors", data->priors );
2809
2810     cvEndWriteStruct( fs );
2811
2812     if( data->var_idx )
2813         cvWrite( fs, "var_idx", data->var_idx );
2814     
2815     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
2816
2817     for( vi = 0; vi < vcount; vi++ )
2818         cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
2819
2820     cvEndWriteStruct( fs );
2821
2822     if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
2823     {
2824         CV_ASSERT( data->cat_count != 0 );
2825         cvWrite( fs, "cat_count", data->cat_count );
2826         cvWrite( fs, "cat_map", data->cat_map );
2827     }
2828
2829     __END__;
2830 }
2831
2832
2833 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
2834 {
2835     int ci;
2836     
2837     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
2838     cvWriteInt( fs, "var", split->var_idx );
2839     cvWriteReal( fs, "quality", split->quality );
2840
2841     ci = data->get_var_type(split->var_idx);
2842     if( ci >= 0 ) // split on a categorical var
2843     {
2844         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
2845         for( i = 0; i < n; i++ )
2846             to_right += DTREE_CAT_DIR(i,split->subset) > 0;
2847
2848         // ad-hoc rule when to use inverse categorical split notation
2849         // to achieve more compact and clear representation
2850         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
2851         
2852         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
2853                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
2854         for( i = 0; i < n; i++ )
2855         {
2856             int dir = DTREE_CAT_DIR(i,split->subset);
2857             if( dir*default_dir < 0 )
2858                 cvWriteInt( fs, 0, i );
2859         }
2860         cvEndWriteStruct( fs );
2861     }
2862     else
2863         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
2864
2865     cvEndWriteStruct( fs );
2866 }
2867
2868
2869 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
2870 {
2871     CvDTreeSplit* split;
2872     
2873     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2874
2875     cvWriteInt( fs, "depth", node->depth );
2876     cvWriteInt( fs, "sample_count", node->sample_count );
2877     cvWriteReal( fs, "value", node->value );
2878     
2879     if( data->is_classifier )
2880         cvWriteInt( fs, "norm_class_idx", node->class_idx );
2881
2882     cvWriteInt( fs, "Tn", node->Tn );
2883     cvWriteInt( fs, "complexity", node->complexity );
2884     cvWriteReal( fs, "alpha", node->alpha );
2885     cvWriteReal( fs, "node_risk", node->node_risk );
2886     cvWriteReal( fs, "tree_risk", node->tree_risk );
2887     cvWriteReal( fs, "tree_error", node->tree_error );
2888
2889     if( node->left )
2890     {
2891         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
2892
2893         for( split = node->split; split != 0; split = split->next )
2894             write_split( fs, split );
2895
2896         cvEndWriteStruct( fs );
2897     }
2898
2899     cvEndWriteStruct( fs );
2900 }
2901
2902
2903 void CvDTree::write_tree_nodes( CvFileStorage* fs )
2904 {
2905     CV_FUNCNAME( "CvDTree::write_tree_nodes" );
2906
2907     __BEGIN__;
2908
2909     CvDTreeNode* node = root;
2910
2911     // traverse the tree and save all the nodes in depth-first order
2912     for(;;)
2913     {
2914         CvDTreeNode* parent;
2915         for(;;)
2916         {
2917             write_node( fs, node );
2918             if( !node->left )
2919                 break;
2920             node = node->left;
2921         }
2922         
2923         for( parent = node->parent; parent && parent->right == node;
2924             node = parent, parent = parent->parent )
2925             ;
2926
2927         if( !parent )
2928             break;
2929
2930         node = parent->right;
2931     }
2932
2933     __END__;
2934 }
2935
2936
2937 void CvDTree::write( CvFileStorage* fs, const char* name )
2938 {
2939     CV_FUNCNAME( "CvDTree::write" );
2940
2941     __BEGIN__;
2942
2943     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
2944
2945     write_train_data_params( fs );
2946
2947     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
2948     get_var_importance();
2949     cvWrite( fs, "var_importance", var_importance );
2950
2951     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
2952     write_tree_nodes( fs );
2953     cvEndWriteStruct( fs );
2954
2955     cvEndWriteStruct( fs );
2956
2957     __END__;
2958 }
2959
2960
2961 void CvDTree::load( const char* filename, const char* name )
2962 {
2963     CvFileStorage* fs = 0;
2964     
2965     CV_FUNCNAME( "CvDTree::load" );
2966
2967     __BEGIN__;
2968
2969     CvFileNode* tree = 0;
2970
2971     CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_READ ));
2972     if( !fs )
2973         CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2974
2975     if( name )
2976         tree = cvGetFileNodeByName( fs, 0, name );
2977     else
2978     {
2979         CvFileNode* root = cvGetRootFileNode( fs );
2980         if( root->data.seq->total > 0 )
2981             tree = (CvFileNode*)cvGetSeqElem( root->data.seq, 0 );
2982     }
2983
2984     read( fs, tree );
2985
2986     __END__;
2987
2988     cvReleaseFileStorage( &fs );
2989 }
2990
2991
2992 void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
2993 {
2994     CV_FUNCNAME( "CvDTree::read_train_data_params" );
2995
2996     __BEGIN__;
2997     
2998     CvDTreeParams params;
2999     CvFileNode *tparams_node, *vartype_node;
3000     CvSeqReader reader;
3001     int is_classifier, vi, cat_var_count, ord_var_count;
3002     int max_split_size, tree_block_size;
3003
3004     data = new CvDTreeTrainData;
3005
3006     data->is_classifier = is_classifier = cvReadIntByName( fs, node, "is_classifier" ) != 0;
3007     data->var_all = cvReadIntByName( fs, node, "var_all" );
3008     data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
3009     data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
3010     data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
3011
3012     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
3013
3014     if( tparams_node ) // training parameters are not necessary
3015     {
3016         data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
3017
3018         if( is_classifier )
3019         {
3020             data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
3021         }
3022         else
3023         {
3024             data->params.regression_accuracy =
3025                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
3026         }
3027
3028         data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
3029         data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
3030         data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
3031     
3032         if( data->params.cv_folds > 1 )
3033         {
3034             data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
3035             data->params.truncate_pruned_tree =
3036                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
3037         }
3038
3039         data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
3040         if( data->priors && !CV_IS_MAT(data->priors) )
3041             CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
3042     }
3043
3044     CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
3045     if( data->var_idx )
3046     {
3047         if( !CV_IS_MAT(data->var_idx) ||
3048             data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
3049             data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
3050             CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
3051             CV_ERROR( CV_StsParseError,
3052                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
3053
3054         for( vi = 0; vi < data->var_count; vi++ )
3055             if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
3056                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
3057     }
3058     
3059     ////// read var type
3060     CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
3061
3062     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
3063     if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
3064         vartype_node->data.seq->total != data->var_count )
3065         CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3066
3067     cvStartReadSeq( vartype_node->data.seq, &reader );
3068
3069     cat_var_count = 0;
3070     ord_var_count = -1;
3071     for( vi = 0; vi < data->var_count; vi++ )
3072     {
3073         CvFileNode* n = (CvFileNode*)reader.ptr;
3074         if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
3075             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3076         data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
3077         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3078     }
3079     
3080     ord_var_count = ~ord_var_count;
3081     if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
3082         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
3083     //////
3084
3085     if( data->cat_var_count > 0 || is_classifier )
3086     {
3087         int ccount, max_c_count = 0, total_c_count = 0;
3088         CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
3089         CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
3090
3091         if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
3092             data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
3093             CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
3094             data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
3095             data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
3096             CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
3097             CV_ERROR( CV_StsParseError,
3098             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
3099
3100         ccount = cat_var_count + is_classifier;
3101
3102         CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
3103         data->cat_ofs->data.i[0] = 0;
3104
3105         for( vi = 0; vi < ccount; vi++ )
3106         {
3107             int val = data->cat_count->data.i[vi];
3108             if( val <= 0 )
3109                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
3110             max_c_count = MAX( max_c_count, val );
3111             data->cat_ofs->data.i[vi+1] = total_c_count += val;
3112         }
3113
3114         if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
3115             CV_ERROR( CV_StsBadSize,
3116             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
3117         
3118         data->max_c_count = max_c_count;
3119     }
3120
3121     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
3122         (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
3123
3124     tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
3125     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
3126     CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
3127     CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
3128             sizeof(CvDTreeNode), data->tree_storage ));
3129     CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
3130             max_split_size, data->tree_storage ));
3131
3132     __END__;
3133 }
3134
3135
3136 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3137 {
3138     CvDTreeSplit* split = 0;
3139     
3140     CV_FUNCNAME( "CvDTree::read_split" );
3141
3142     __BEGIN__;
3143
3144     int vi, ci;
3145     
3146     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3147         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3148
3149     vi = cvReadIntByName( fs, fnode, "var", -1 );
3150     if( (unsigned)vi >= (unsigned)data->var_count )
3151         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3152
3153     ci = data->get_var_type(vi);
3154     if( ci >= 0 ) // split on categorical var
3155     {
3156         int i, n = data->cat_count->data.i[ci], inversed = 0;
3157         CvSeqReader reader;
3158         CvFileNode* inseq;
3159         split = data->new_split_cat( vi, 0 );
3160         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3161         if( !inseq )
3162         {
3163             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3164             inversed = 1;
3165         }
3166         if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
3167             CV_ERROR( CV_StsParseError,
3168             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3169
3170         cvStartReadSeq( inseq->data.seq, &reader );
3171
3172         for( i = 0; i < reader.seq->total; i++ )
3173         {
3174             CvFileNode* inode = (CvFileNode*)reader.ptr;
3175             int val = inode->data.i;
3176             if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3177                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3178
3179             split->subset[val >> 5] |= 1 << (val & 31);
3180             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3181         }
3182
3183         // for categorical splits we do not use inversed splits,
3184         // instead we inverse the variable set in the split
3185         if( inversed )
3186             for( i = 0; i < (n + 31) >> 5; i++ )
3187                 split->subset[i] ^= -1;
3188     }
3189     else
3190     {
3191         CvFileNode* cmp_node;
3192         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3193
3194         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3195         if( !cmp_node )
3196         {
3197             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3198             split->inversed = 1;
3199         }
3200
3201         split->ord.c = (float)cvReadReal( cmp_node );
3202     }
3203         
3204     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3205
3206     __END__;
3207     
3208     return split;
3209 }
3210
3211
3212 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3213 {
3214     CvDTreeNode* node = 0;
3215     
3216     CV_FUNCNAME( "CvDTree::read_node" );
3217
3218     __BEGIN__;
3219
3220     CvFileNode* splits;
3221     int i, depth;
3222
3223     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3224         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3225
3226     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3227     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3228     if( depth != node->depth )
3229         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3230
3231     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3232     node->value = cvReadRealByName( fs, fnode, "value" );
3233     if( data->is_classifier )
3234         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3235
3236     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3237     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3238     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3239     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3240     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3241     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3242
3243     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3244     if( splits )
3245     {
3246         CvSeqReader reader;
3247         CvDTreeSplit* last_split = 0;
3248
3249         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3250             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3251
3252         cvStartReadSeq( splits->data.seq, &reader );
3253         for( i = 0; i < reader.seq->total; i++ )
3254         {
3255             CvDTreeSplit* split;
3256             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3257             if( !last_split )
3258                 node->split = last_split = split;
3259             else
3260                 last_split = last_split->next = split;
3261
3262             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3263         }
3264     }
3265
3266     __END__;
3267     
3268     return node;
3269 }
3270
3271
3272 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3273 {
3274     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3275
3276     __BEGIN__;
3277
3278     CvSeqReader reader;
3279     CvDTreeNode _root;
3280     CvDTreeNode* parent = &_root;
3281     int i;
3282     parent->left = parent->right = parent->parent = 0;
3283
3284     cvStartReadSeq( fnode->data.seq, &reader );
3285
3286     for( i = 0; i < reader.seq->total; i++ )
3287     {
3288         CvDTreeNode* node;
3289         
3290         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3291         if( !parent->left )
3292             parent->left = node;
3293         else
3294             parent->right = node;
3295         if( node->split )
3296             parent = node;
3297         else
3298         {
3299             while( parent && parent->right )
3300                 parent = parent->parent;
3301         }
3302
3303         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3304     }
3305
3306     root = _root.left;
3307
3308     __END__;
3309 }
3310
3311
3312 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3313 {
3314     CV_FUNCNAME( "CvDTree::read" );
3315
3316     __BEGIN__;
3317
3318     CvFileNode* tree_nodes;
3319
3320     clear();
3321     read_train_data_params( fs, fnode );
3322
3323     tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
3324     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3325         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3326
3327     pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
3328
3329     read_tree_nodes( fs, tree_nodes );
3330     get_var_importance(); // recompute variable importance
3331
3332     __END__;
3333 }
3334
3335 /* End of file. */