]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
fixed minor compiler warning
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 static const float ord_nan = FLT_MAX*0.5f;
44 static const int min_block_size = 1 << 16;
45 static const int block_size_delta = 1 << 10;
46
47 CvDTreeTrainData::CvDTreeTrainData()
48 {
49     var_idx = var_type = cat_count = cat_ofs = cat_map =
50         priors = counts = buf = direction = split_buf = 0;
51     tree_storage = temp_storage = 0;
52
53     clear();
54 }
55
56
57 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
58                       const CvMat* _responses, const CvMat* _var_idx,
59                       const CvMat* _sample_idx, const CvMat* _var_type,
60                       const CvMat* _missing_mask,
61                       CvDTreeParams _params, bool _shared, bool _add_weights )
62 {
63     var_idx = var_type = cat_count = cat_ofs = cat_map =
64         priors = counts = buf = direction = split_buf = 0;
65     tree_storage = temp_storage = 0;
66     
67     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
68               _var_type, _missing_mask, _params, _shared, _add_weights );
69 }
70
71
72 CvDTreeTrainData::~CvDTreeTrainData()
73 {
74     clear();
75 }
76
77
78 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
79 {
80     bool ok = false;
81     
82     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
83
84     __BEGIN__;
85
86     // set parameters
87     params = _params;
88
89     if( params.max_categories < 2 )
90         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
91     params.max_categories = MIN( params.max_categories, 15 );
92
93     if( params.max_depth < 0 )
94         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
95     params.max_depth = MIN( params.max_depth, 25 );
96
97     params.min_sample_count = MAX(params.min_sample_count,1);
98
99     if( params.cv_folds < 0 )
100         CV_ERROR( CV_StsOutOfRange,
101         "params.cv_folds should be =0 (the tree is not pruned) "
102         "or n>0 (tree is pruned using n-fold cross-validation)" );
103
104     if( params.cv_folds == 1 )
105         params.cv_folds = 0;
106
107     if( params.regression_accuracy < 0 )
108         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
109
110     ok = true;
111
112     __END__;
113
114     return ok;
115 }
116
117
118 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
119 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
120 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
121
122 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
123 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
124
125 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
126     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
127     const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
128     bool _shared, bool _add_weights )
129 {
130     CvMat* sample_idx = 0;
131     CvMat* var_type0 = 0;
132     CvMat* tmp_map = 0;
133     int** int_ptr = 0;
134
135     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
136
137     __BEGIN__;
138
139     int sample_all = 0, r_type = 0, cv_n;
140     int total_c_count = 0;
141     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
142     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
143     int vi, i;
144     char err[100];
145     const int *sidx = 0, *vidx = 0;
146
147     clear();
148
149     var_all = 0;
150     rng = cvRNG(-1);
151
152     CV_CALL( set_params( _params ));
153
154     // check parameter types and sizes
155     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
156     if( _tflag == CV_ROW_SAMPLE )
157     {
158         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
159         dv_step = 1;
160         if( _missing_mask )
161             ms_step = _missing_mask->step, mv_step = 1;
162     }
163     else
164     {
165         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
166         ds_step = 1;
167         if( _missing_mask )
168             mv_step = _missing_mask->step, ms_step = 1;
169     }
170
171     sample_count = sample_all;
172     var_count = var_all;
173
174     if( _sample_idx )
175     {
176         CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
177         sidx = sample_idx->data.i;
178         sample_count = sample_idx->rows + sample_idx->cols - 1;
179     }
180
181     if( _var_idx )
182     {
183         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
184         vidx = var_idx->data.i;
185         var_count = var_idx->rows + var_idx->cols - 1;
186     }
187
188     if( !CV_IS_MAT(_responses) ||
189         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
190          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
191         _responses->rows != 1 && _responses->cols != 1 ||
192         _responses->rows + _responses->cols - 1 != sample_all )
193         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
194                   "floating-point vector containing as many elements as "
195                   "the total number of samples in the training data matrix" );
196
197     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
198     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
199
200     cat_var_count = 0;
201     ord_var_count = -1;
202
203     is_classifier = r_type == CV_VAR_CATEGORICAL;
204
205     // step 0. calc the number of categorical vars
206     for( vi = 0; vi < var_count; vi++ )
207     {
208         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
209             cat_var_count++ : ord_var_count--;
210     }
211
212     ord_var_count = ~ord_var_count;
213     cv_n = params.cv_folds;
214     // set the two last elements of var_type array to be able
215     // to locate responses and cross-validation labels using
216     // the corresponding get_* functions.
217     var_type->data.i[var_count] = cat_var_count;
218     var_type->data.i[var_count+1] = cat_var_count+1;
219
220     // in case of single ordered predictor we need dummy cv_labels
221     // for safe split_node_data() operation
222     have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
223     have_weights = _add_weights;
224
225     buf_size = (ord_var_count*2 + cat_var_count + 1 +
226         (have_cv_labels ? 1 : 0) + (have_weights ? 1 : 0))*sample_count + 2;
227     shared = _shared;
228     buf_count = shared ? 3 : 2;
229     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
230     CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
231     CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
232     CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
233
234     // now calculate the maximum size of split,
235     // create memory storage that will keep nodes and splits of the decision tree
236     // allocate root node and the buffer for the whole training data
237     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
238         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
239     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
240     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
241     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
242     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
243
244     temp_block_size = nv_size = var_count*sizeof(int);
245     if( cv_n )
246     {
247         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
248             CV_ERROR( CV_StsOutOfRange,
249                 "The many folds in cross-validation for such a small dataset" );
250
251         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
252         temp_block_size = MAX(temp_block_size, cv_size);
253     }
254
255     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
256     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
257     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
258     if( cv_size )
259         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
260
261     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
262     CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
263
264     max_c_count = 1;
265
266     // transform the training data to convenient representation
267     for( vi = 0; vi <= var_count; vi++ )
268     {
269         int ci;
270         const uchar* mask = 0;
271         int m_step = 0, step;
272         const int* idata = 0;
273         const float* fdata = 0;
274         int num_valid = 0;
275
276         if( vi < var_count ) // analyze i-th input variable
277         {
278             int vi0 = vidx ? vidx[vi] : vi;
279             ci = get_var_type(vi);
280             step = ds_step; m_step = ms_step;
281             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
282                 idata = _train_data->data.i + vi0*dv_step;
283             else
284                 fdata = _train_data->data.fl + vi0*dv_step;
285             if( _missing_mask )
286                 mask = _missing_mask->data.ptr + vi0*mv_step;
287         }
288         else // analyze _responses
289         {
290             ci = cat_var_count;
291             step = CV_IS_MAT_CONT(_responses->type) ?
292                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
293             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
294                 idata = _responses->data.i;
295             else
296                 fdata = _responses->data.fl;
297         }
298
299         if( vi < var_count && ci >= 0 ||
300             vi == var_count && is_classifier ) // process categorical variable or response
301         {
302             int c_count, prev_label;
303             int* c_map, *dst = get_cat_var_data( data_root, vi );
304
305             // copy data
306             for( i = 0; i < sample_count; i++ )
307             {
308                 int val = INT_MAX, si = sidx ? sidx[i] : i;
309                 if( !mask || !mask[si*m_step] )
310                 {
311                     if( idata )
312                         val = idata[si*step];
313                     else
314                     {
315                         float t = fdata[si*step];
316                         val = cvRound(t);
317                         if( val != t )
318                         {
319                             sprintf( err, "%d-th value of %d-th (categorical) "
320                                 "variable is not an integer", i, vi );
321                             CV_ERROR( CV_StsBadArg, err );
322                         }
323                     }
324
325                     if( val == INT_MAX )
326                     {
327                         sprintf( err, "%d-th value of %d-th (categorical) "
328                             "variable is too large", i, vi );
329                         CV_ERROR( CV_StsBadArg, err );
330                     }
331                     num_valid++;
332                 }
333                 dst[i] = val;
334                 int_ptr[i] = dst + i;
335             }
336
337             // sort all the values, including the missing measurements
338             // that should all move to the end
339             icvSortIntPtr( int_ptr, sample_count, 0 );
340             //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
341
342             c_count = num_valid > 0;
343
344             // count the categories
345             for( i = 1; i < num_valid; i++ )
346                 c_count += *int_ptr[i] != *int_ptr[i-1];
347
348             if( vi > 0 )
349                 max_c_count = MAX( max_c_count, c_count );
350             cat_count->data.i[ci] = c_count;
351             cat_ofs->data.i[ci] = total_c_count;
352
353             // resize cat_map, if need
354             if( cat_map->cols < total_c_count + c_count )
355             {
356                 tmp_map = cat_map;
357                 CV_CALL( cat_map = cvCreateMat( 1,
358                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
359                 for( i = 0; i < total_c_count; i++ )
360                     cat_map->data.i[i] = tmp_map->data.i[i];
361                 cvReleaseMat( &tmp_map );
362             }
363
364             c_map = cat_map->data.i + total_c_count;
365             total_c_count += c_count;
366
367             // compact the class indices and build the map
368             prev_label = ~*int_ptr[0];
369             c_count = -1;
370
371             for( i = 0; i < num_valid; i++ )
372             {
373                 int cur_label = *int_ptr[i];
374                 if( cur_label != prev_label )
375                     c_map[++c_count] = prev_label = cur_label;
376                 *int_ptr[i] = c_count;
377             }
378
379             // replace labels for missing values with -1
380             for( ; i < sample_count; i++ )
381                 *int_ptr[i] = -1;
382         }
383         else if( ci < 0 ) // process ordered variable
384         {
385             CvPair32s32f* dst = get_ord_var_data( data_root, vi );
386
387             for( i = 0; i < sample_count; i++ )
388             {
389                 float val = ord_nan;
390                 int si = sidx ? sidx[i] : i;
391                 if( !mask || !mask[si*m_step] )
392                 {
393                     if( idata )
394                         val = (float)idata[si*step];
395                     else
396                         val = fdata[si*step];
397
398                     if( fabs(val) >= ord_nan )
399                     {
400                         sprintf( err, "%d-th value of %d-th (ordered) "
401                             "variable (=%g) is too large", i, vi, val );
402                         CV_ERROR( CV_StsBadArg, err );
403                     }
404                     num_valid++;
405                 }
406                 dst[i].i = i;
407                 dst[i].val = val;
408             }
409
410             icvSortPairs( dst, sample_count, 0 );
411         }
412         else // special case: process ordered response,
413              // it will be stored similarly to categorical vars (i.e. no pairs)
414         {
415             float* dst = get_ord_responses( data_root );
416
417             for( i = 0; i < sample_count; i++ )
418             {
419                 float val = ord_nan;
420                 int si = sidx ? sidx[i] : i;
421                 if( idata )
422                     val = (float)idata[si*step];
423                 else
424                     val = fdata[si*step];
425
426                 if( fabs(val) >= ord_nan )
427                 {
428                     sprintf( err, "%d-th value of %d-th (ordered) "
429                         "variable (=%g) is out of range", i, vi, val );
430                     CV_ERROR( CV_StsBadArg, err );
431                 }
432                 dst[i] = val;
433             }
434
435             cat_count->data.i[cat_var_count] = 0;
436             cat_ofs->data.i[cat_var_count] = total_c_count;
437             num_valid = sample_count;
438         }
439
440         if( vi < var_count )
441             data_root->set_num_valid(vi, num_valid);
442     }
443
444     if( cv_n )
445     {
446         int* dst = get_cv_labels(data_root);
447         CvRNG* r = &rng;
448
449         for( i = vi = 0; i < sample_count; i++ )
450         {
451             dst[i] = vi++;
452             vi &= vi < cv_n ? -1 : 0;
453         }
454
455         for( i = 0; i < sample_count; i++ )
456         {
457             int a = cvRandInt(r) % sample_count;
458             int b = cvRandInt(r) % sample_count;
459             CV_SWAP( dst[a], dst[b], vi );
460         }
461     }
462
463     cat_map->cols = MAX( total_c_count, 1 );
464
465     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
466         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
467     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
468
469     have_priors = is_classifier && params.priors;
470     if( is_classifier )
471     {
472         int m = get_num_classes(), rows = 4;
473         double sum = 0;
474         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
475         for( i = 0; i < m; i++ )
476         {
477             double val = have_priors ? params.priors[i] : 1.;
478             if( val <= 0 )
479                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
480             priors->data.db[i] = val;
481             sum += val;
482         }
483         // normalize weights
484         if( have_priors )
485             cvScale( priors, priors, 1./sum );
486
487         if( cat_var_count > 0 || params.cv_folds > 0 )
488         {
489             // need storage for cjk (see find_split_cat_gini) and risks/errors
490             rows += MAX( max_c_count, params.cv_folds ) + 1;
491             // add buffer for k-means clustering
492             if( m > 2 && max_c_count > params.max_categories )
493                 rows += params.max_categories + (max_c_count+m-1)/m;
494         }
495
496         CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
497     }
498
499     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
500     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
501
502     __END__;
503
504     cvFree( &int_ptr );
505     cvReleaseMat( &sample_idx );
506     cvReleaseMat( &var_type0 );
507     cvReleaseMat( &tmp_map );
508 }
509
510
511 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
512 {
513     CvDTreeNode* root = 0;
514     CvMat* isubsample_idx = 0;
515     CvMat* subsample_co = 0;
516     
517     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
518
519     __BEGIN__;
520
521     if( !data_root )
522         CV_ERROR( CV_StsError, "No training data has been set" );
523     
524     if( _subsample_idx )
525         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
526
527     if( !isubsample_idx )
528     {
529         // make a copy of the root node
530         CvDTreeNode temp;
531         int i;
532         root = new_node( 0, 1, 0, 0 );
533         temp = *root;
534         *root = *data_root;
535         root->num_valid = temp.num_valid;
536         if( root->num_valid )
537         {
538             for( i = 0; i < var_count; i++ )
539                 root->num_valid[i] = data_root->num_valid[i];
540         }
541         root->cv_Tn = temp.cv_Tn;
542         root->cv_node_risk = temp.cv_node_risk;
543         root->cv_node_error = temp.cv_node_error;
544     }
545     else
546     {
547         int* sidx = isubsample_idx->data.i;
548         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
549         int* co, cur_ofs = 0;
550         int vi, i, total = data_root->sample_count;
551         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
552         root = new_node( 0, count, 1, 0 );
553
554         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
555         cvZero( subsample_co );
556         co = subsample_co->data.i;
557         for( i = 0; i < count; i++ )
558             co[sidx[i]*2]++;
559         for( i = 0; i < total; i++ )
560         {
561             if( co[i*2] )
562             {
563                 co[i*2+1] = cur_ofs;
564                 cur_ofs += co[i*2];
565             }
566             else
567                 co[i*2+1] = -1;
568         }
569
570         for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
571         {
572             int ci = get_var_type(vi);
573
574             if( ci >= 0 || vi >= var_count )
575             {
576                 const int* src = get_cat_var_data( data_root, vi );
577                 int* dst = get_cat_var_data( root, vi );
578                 int num_valid = 0;
579
580                 for( i = 0; i < count; i++ )
581                 {
582                     int val = src[sidx[i]];
583                     dst[i] = val;
584                     num_valid += val >= 0;
585                 }
586
587                 if( vi < var_count )
588                     root->set_num_valid(vi, num_valid);
589             }
590             else
591             {
592                 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
593                 CvPair32s32f* dst = get_ord_var_data( root, vi );
594                 int j = 0, idx, count_i;
595                 int num_valid = data_root->get_num_valid(vi);
596
597                 for( i = 0; i < num_valid; i++ )
598                 {
599                     idx = src[i].i;
600                     count_i = co[idx*2];
601                     if( count_i )
602                     {
603                         float val = src[i].val;
604                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
605                         {
606                             dst[j].val = val;
607                             dst[j].i = cur_ofs;
608                         }
609                     }
610                 }
611
612                 root->set_num_valid(vi, j);
613
614                 for( ; i < total; i++ )
615                 {
616                     idx = src[i].i;
617                     count_i = co[idx*2];
618                     if( count_i )
619                     {
620                         float val = src[i].val;
621                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
622                         {
623                             dst[j].val = val;
624                             dst[j].i = cur_ofs;
625                         }
626                     }
627                 }
628             }
629         }
630     }
631
632     __END__;
633
634     cvReleaseMat( &isubsample_idx );
635     cvReleaseMat( &subsample_co );
636
637     return root;
638 }
639
640
641 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
642                                     float* values, uchar* missing,
643                                     float* responses, bool get_class_idx )
644 {
645     CvMat* subsample_idx = 0;
646     CvMat* subsample_co = 0;
647     
648     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
649
650     __BEGIN__;
651
652     int i, vi, total = sample_count, count = total, cur_ofs = 0;
653     int* sidx = 0;
654     int* co = 0;
655
656     if( _subsample_idx )
657     {
658         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
659         sidx = subsample_idx->data.i;
660         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
661         co = subsample_co->data.i;
662         cvZero( subsample_co );
663         count = subsample_idx->cols + subsample_idx->rows - 1;
664         for( i = 0; i < count; i++ )
665             co[sidx[i]*2]++;
666         for( i = 0; i < total; i++ )
667         {
668             int count_i = co[i*2];
669             if( count_i )
670             {
671                 co[i*2+1] = cur_ofs*var_count;
672                 cur_ofs += count_i;
673             }
674         }
675     }
676
677     memset( missing, 1, count*var_count );
678
679     for( vi = 0; vi < var_count; vi++ )
680     {
681         int ci = get_var_type(vi);
682         if( ci >= 0 ) // categorical
683         {
684             float* dst = values + vi;
685             uchar* m = missing + vi;
686             const int* src = get_cat_var_data(data_root, vi);
687
688             for( i = 0; i < count; i++, dst += var_count, m += var_count )
689             {
690                 int idx = sidx ? sidx[i] : i;
691                 int val = src[idx];
692                 *dst = (float)val;
693                 *m = val < 0;
694             }
695         }
696         else // ordered
697         {
698             float* dst = values + vi;
699             uchar* m = missing + vi;
700             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
701             int count1 = data_root->get_num_valid(vi);
702
703             for( i = 0; i < count1; i++ )
704             {
705                 int idx = src[i].i;
706                 int count_i = 1;
707                 if( co )
708                 {
709                     count_i = co[idx*2];
710                     cur_ofs = co[idx*2+1];
711                 }
712                 else
713                     cur_ofs = idx*var_count;
714                 if( count_i )
715                 {
716                     float val = src[i].val;
717                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
718                     {
719                         dst[cur_ofs] = val;
720                         m[cur_ofs] = 0;
721                     }
722                 }
723             }
724         }
725     }
726
727     // copy responses
728     if( is_classifier )
729     {
730         const int* src = get_class_labels(data_root);
731         for( i = 0; i < count; i++ )
732         {
733             int idx = sidx ? sidx[i] : i;
734             int val = get_class_idx ? src[idx] :
735                 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
736             responses[i] = (float)val;
737         }
738     }
739     else
740     {
741         const float* src = get_ord_responses(data_root);
742         for( i = 0; i < count; i++ )
743         {
744             int idx = sidx ? sidx[i] : i;
745             responses[i] = src[idx];
746         }
747     }
748
749     __END__;
750
751     cvReleaseMat( &subsample_idx );
752     cvReleaseMat( &subsample_co );
753 }
754
755
756 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
757                                          int storage_idx, int offset )
758 {
759     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
760
761     node->sample_count = count;
762     node->depth = parent ? parent->depth + 1 : 0;
763     node->parent = parent;
764     node->left = node->right = 0;
765     node->split = 0;
766     node->value = 0;
767     node->class_idx = 0;
768     node->maxlr = 0.;
769
770     node->buf_idx = storage_idx;
771     node->offset = offset;
772     if( nv_heap )
773         node->num_valid = (int*)cvSetNew( nv_heap );
774     else
775         node->num_valid = 0;
776     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
777     node->complexity = 0;
778
779     if( params.cv_folds > 0 && cv_heap )
780     {
781         int cv_n = params.cv_folds;
782         node->Tn = INT_MAX;
783         node->cv_Tn = (int*)cvSetNew( cv_heap );
784         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
785         node->cv_node_error = node->cv_node_risk + cv_n;
786     }
787     else
788     {
789         node->Tn = 0;
790         node->cv_Tn = 0;
791         node->cv_node_risk = 0;
792         node->cv_node_error = 0;
793     }
794
795     return node;
796 }
797
798
799 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
800                 int split_point, int inversed, float quality )
801 {
802     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
803     split->var_idx = vi;
804     split->ord.c = cmp_val;
805     split->ord.split_point = split_point;
806     split->inversed = inversed;
807     split->quality = quality;
808     split->next = 0;
809
810     return split;
811 }
812
813
814 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
815 {
816     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
817     int i, n = (max_c_count + 31)/32;
818
819     split->var_idx = vi;
820     split->inversed = 0;
821     split->quality = quality;
822     for( i = 0; i < n; i++ )
823         split->subset[i] = 0;
824     split->next = 0;
825
826     return split;
827 }
828
829
830 void CvDTreeTrainData::free_node( CvDTreeNode* node )
831 {
832     CvDTreeSplit* split = node->split;
833     free_node_data( node );
834     while( split )
835     {
836         CvDTreeSplit* next = split->next;
837         cvSetRemoveByPtr( split_heap, split );
838         split = next;
839     }
840     node->split = 0;
841     cvSetRemoveByPtr( node_heap, node );
842 }
843
844
845 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
846 {
847     if( node->num_valid )
848     {
849         cvSetRemoveByPtr( nv_heap, node->num_valid );
850         node->num_valid = 0;
851     }
852     // do not free cv_* fields, as all the cross-validation related data is released at once.
853 }
854
855
856 void CvDTreeTrainData::free_train_data()
857 {
858     cvReleaseMat( &counts );
859     cvReleaseMat( &buf );
860     cvReleaseMat( &direction );
861     cvReleaseMat( &split_buf );
862     cvReleaseMemStorage( &temp_storage );
863     cv_heap = nv_heap = 0;
864 }
865
866
867 void CvDTreeTrainData::clear()
868 {
869     free_train_data();
870
871     cvReleaseMemStorage( &tree_storage );
872
873     cvReleaseMat( &var_idx );
874     cvReleaseMat( &var_type );
875     cvReleaseMat( &cat_count );
876     cvReleaseMat( &cat_ofs );
877     cvReleaseMat( &cat_map );
878     cvReleaseMat( &priors );
879
880     node_heap = split_heap = 0;
881
882     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
883     have_cv_labels = have_priors = is_classifier = false;
884
885     buf_count = buf_size = 0;
886     shared = false;
887
888     data_root = 0;
889
890     rng = cvRNG(-1);
891 }
892
893
894 int CvDTreeTrainData::get_num_classes() const
895 {
896     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
897 }
898
899
900 int CvDTreeTrainData::get_var_type(int vi) const
901 {
902     return var_type->data.i[vi];
903 }
904
905
906 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
907 {
908     int oi = ~get_var_type(vi);
909     assert( 0 <= oi && oi < ord_var_count );
910     return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
911                            n->offset + oi*n->sample_count*2);
912 }
913
914
915 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
916 {
917     return get_cat_var_data( n, var_count );
918 }
919
920
921 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
922 {
923     return (float*)get_cat_var_data( n, var_count );
924 }
925
926
927 int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
928 {
929     return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
930 }
931
932
933 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
934 {
935     int ci = get_var_type(vi);
936     assert( 0 <= ci && ci <= cat_var_count + 1 );
937     return buf->data.i + n->buf_idx*buf->cols + n->offset +
938            (ord_var_count*2 + ci)*n->sample_count;
939 }
940
941
942 float* CvDTreeTrainData::get_weights( CvDTreeNode* n )
943 {
944     return have_weights ?
945         (float*)get_cat_var_data( n, var_count + 1 + (params.cv_folds > 0) ) : 0;
946 }
947
948
949 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
950 {
951     int idx = n->buf_idx + 1;
952     if( idx >= buf_count )
953         idx = shared ? 1 : 0;
954     return idx;
955 }
956
957
958 /////////////////////// Decision Tree /////////////////////////
959
960 CvDTree::CvDTree()
961 {
962     data = 0;
963     var_importance = 0;
964     default_model_name = "my_tree";
965
966     clear();
967 }
968
969
970 void CvDTree::clear()
971 {
972     cvReleaseMat( &var_importance );
973     if( data )
974     {
975         if( !data->shared )
976             delete data;
977         else
978             free_tree();
979         data = 0;
980     }
981     root = 0;
982     pruned_tree_idx = -1;
983 }
984
985
986 CvDTree::~CvDTree()
987 {
988     clear();
989 }
990
991
992 const CvDTreeNode* CvDTree::get_root() const
993 {
994     return root;
995 }
996
997
998 int CvDTree::get_pruned_tree_idx() const
999 {
1000     return pruned_tree_idx;
1001 }
1002
1003
1004 CvDTreeTrainData* CvDTree::get_data()
1005 {
1006     return data;
1007 }
1008
1009
1010 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1011                      const CvMat* _responses, const CvMat* _var_idx,
1012                      const CvMat* _sample_idx, const CvMat* _var_type,
1013                      const CvMat* _missing_mask, CvDTreeParams _params )
1014 {
1015     bool result = false;
1016
1017     CV_FUNCNAME( "CvDTree::train" );
1018
1019     __BEGIN__;
1020
1021     clear();
1022     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1023                                  _var_idx, _sample_idx, _var_type,
1024                                  _missing_mask, _params, false );
1025     CV_CALL( result = do_train(0));
1026
1027     __END__;
1028
1029     return result;
1030 }
1031
1032
1033 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1034 {
1035     bool result = false;
1036
1037     CV_FUNCNAME( "CvDTree::train" );
1038
1039     __BEGIN__;
1040
1041     clear();
1042     data = _data;
1043     data->shared = true;
1044     CV_CALL( result = do_train(_subsample_idx));
1045
1046     __END__;
1047
1048     return result;
1049 }
1050
1051
1052 bool CvDTree::do_train( const CvMat* _subsample_idx )
1053 {
1054     bool result = false;
1055
1056     CV_FUNCNAME( "CvDTree::do_train" );
1057
1058     __BEGIN__;
1059
1060     root = data->subsample_data( _subsample_idx );
1061
1062     CV_CALL( try_split_node(root));
1063     
1064     if( data->params.cv_folds > 0 )
1065         CV_CALL( prune_cv());
1066
1067     if( !data->shared )
1068         data->free_train_data();
1069
1070     result = true;
1071
1072     __END__;
1073
1074     return result;
1075 }
1076
1077
1078 #define DTREE_CAT_DIR(idx,subset) \
1079     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
1080
1081 void CvDTree::try_split_node( CvDTreeNode* node )
1082 {
1083     CvDTreeSplit* best_split = 0;
1084     int i, n = node->sample_count, vi;
1085     bool can_split = true;
1086     double quality_scale;
1087
1088     calc_node_value( node );
1089
1090     if( node->sample_count <= data->params.min_sample_count ||
1091         node->depth >= data->params.max_depth )
1092         can_split = false;
1093
1094     if( can_split && data->is_classifier )
1095     {
1096         // check if we have a "pure" node,
1097         // we assume that cls_count is filled by calc_node_value()
1098         int* cls_count = data->counts->data.i;
1099         int nz = 0, m = data->get_num_classes();
1100         for( i = 0; i < m; i++ )
1101             nz += cls_count[i] != 0;
1102         if( nz == 1 ) // there is only one class
1103             can_split = false;
1104     }
1105     else if( can_split )
1106     {
1107         const float* responses = data->get_ord_responses( node );
1108         float diff = responses[n-1] - responses[0];
1109         if( diff < data->params.regression_accuracy )
1110             can_split = false;
1111     }
1112
1113     if( can_split )
1114     {
1115         best_split = find_best_split(node);
1116         // TODO: check the split quality ...
1117         node->split = best_split;
1118     }
1119
1120     if( !can_split || !best_split )
1121     {
1122         data->free_node_data(node);
1123         return;
1124     }
1125
1126     quality_scale = calc_node_dir( node );
1127
1128     if( data->params.use_surrogates )
1129     {
1130         // find all the surrogate splits
1131         // and sort them by their similarity to the primary one
1132         for( vi = 0; vi < data->var_count; vi++ )
1133         {
1134             CvDTreeSplit* split;
1135             int ci = data->get_var_type(vi);
1136
1137             if( vi == best_split->var_idx )
1138                 continue;
1139
1140             if( ci >= 0 )
1141                 split = find_surrogate_split_cat( node, vi );
1142             else
1143                 split = find_surrogate_split_ord( node, vi );
1144
1145             if( split )
1146             {
1147                 // insert the split
1148                 CvDTreeSplit* prev_split = node->split;
1149                 split->quality = (float)(split->quality*quality_scale);
1150
1151                 while( prev_split->next &&
1152                        prev_split->next->quality > split->quality )
1153                     prev_split = prev_split->next;
1154                 split->next = prev_split->next;
1155                 prev_split->next = split;
1156             }
1157         }
1158     }
1159
1160     split_node_data( node );
1161     try_split_node( node->left );
1162     try_split_node( node->right );
1163 }
1164
1165
1166 // calculate direction (left(-1),right(1),missing(0))
1167 // for each sample using the best split
1168 // the function returns scale coefficients for surrogate split quality factors.
1169 // the scale is applied to normalize surrogate split quality relatively to the
1170 // best (primary) split quality. That is, if a surrogate split is absolutely
1171 // identical to the primary split, its quality will be set to the maximum value = 
1172 // quality of the primary split; otherwise, it will be lower.
1173 // besides, the function compute node->maxlr,
1174 // minimum possible quality (w/o considering the above mentioned scale)
1175 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1176 // are not discarded.
1177 double CvDTree::calc_node_dir( CvDTreeNode* node )
1178 {
1179     char* dir = (char*)data->direction->data.ptr;
1180     int i, n = node->sample_count, vi = node->split->var_idx;
1181     double L, R;
1182
1183     assert( !node->split->inversed );
1184
1185     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1186     {
1187         const int* labels = data->get_cat_var_data(node,vi);
1188         const int* subset = node->split->subset;
1189
1190         if( !data->have_priors )
1191         {
1192             int sum = 0, sum_abs = 0;
1193
1194             for( i = 0; i < n; i++ )
1195             {
1196                 int idx = labels[i];
1197                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1198                 sum += d; sum_abs += d & 1;
1199                 dir[i] = (char)d;
1200             }
1201
1202             R = (sum_abs + sum) >> 1;
1203             L = (sum_abs - sum) >> 1;
1204         }
1205         else
1206         {
1207             const int* responses = data->get_class_labels(node);
1208             const double* priors = data->priors->data.db;
1209             double sum = 0, sum_abs = 0;
1210
1211             for( i = 0; i < n; i++ )
1212             {
1213                 int idx = labels[i];
1214                 double w = priors[responses[i]];
1215                 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1216                 sum += d*w; sum_abs += (d & 1)*w;
1217                 dir[i] = (char)d;
1218             }
1219
1220             R = (sum_abs + sum) * 0.5;
1221             L = (sum_abs - sum) * 0.5;
1222         }
1223     }
1224     else // split on ordered var
1225     {
1226         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1227         int split_point = node->split->ord.split_point;
1228         int n1 = node->get_num_valid(vi);
1229
1230         assert( 0 <= split_point && split_point < n1-1 );
1231
1232         if( !data->have_priors )
1233         {
1234             for( i = 0; i <= split_point; i++ )
1235                 dir[sorted[i].i] = (char)-1;
1236             for( ; i < n1; i++ )
1237                 dir[sorted[i].i] = (char)1;
1238             for( ; i < n; i++ )
1239                 dir[sorted[i].i] = (char)0;
1240
1241             L = split_point-1;
1242             R = n1 - split_point + 1;
1243         }
1244         else
1245         {
1246             const int* responses = data->get_class_labels(node);
1247             const double* priors = data->priors->data.db;
1248             L = R = 0;
1249
1250             for( i = 0; i <= split_point; i++ )
1251             {
1252                 int idx = sorted[i].i;
1253                 double w = priors[responses[idx]];
1254                 dir[idx] = (char)-1;
1255                 L += w;
1256             }
1257
1258             for( ; i < n1; i++ )
1259             {
1260                 int idx = sorted[i].i;
1261                 double w = priors[responses[idx]];
1262                 dir[idx] = (char)1;
1263                 R += w;
1264             }
1265
1266             for( ; i < n; i++ )
1267                 dir[sorted[i].i] = (char)0;
1268         }
1269     }
1270
1271     node->maxlr = MAX( L, R );
1272     return node->split->quality/(L + R);
1273 }
1274
1275
1276 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1277 {
1278     int vi;
1279     CvDTreeSplit *best_split = 0, *split = 0, *t;
1280
1281     for( vi = 0; vi < data->var_count; vi++ )
1282     {
1283         int ci = data->get_var_type(vi);
1284         if( node->get_num_valid(vi) <= 1 )
1285             continue;
1286
1287         if( data->is_classifier )
1288         {
1289             if( ci >= 0 )
1290                 split = find_split_cat_class( node, vi );
1291             else
1292                 split = find_split_ord_class( node, vi );
1293         }
1294         else
1295         {
1296             if( ci >= 0 )
1297                 split = find_split_cat_reg( node, vi );
1298             else
1299                 split = find_split_ord_reg( node, vi );
1300         }
1301
1302         if( split )
1303         {
1304             if( !best_split || best_split->quality < split->quality )
1305                 CV_SWAP( best_split, split, t );
1306             if( split )
1307                 cvSetRemoveByPtr( data->split_heap, split );
1308         }
1309     }
1310
1311     return best_split;
1312 }
1313
1314
1315 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
1316 {
1317     const float epsilon = FLT_EPSILON*2;
1318     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1319     const int* responses = data->get_class_labels(node);
1320     int n = node->sample_count;
1321     int n1 = node->get_num_valid(vi);
1322     int m = data->get_num_classes();
1323     const int* rc0 = data->counts->data.i;
1324     int* lc = (int*)(rc0 + m);
1325     int* rc = lc + m;
1326     int i, best_i = -1;
1327     double lsum2 = 0, rsum2 = 0, best_val = 0;
1328     const double* priors = data->have_priors ? data->priors->data.db : 0;
1329
1330     // init arrays of class instance counters on both sides of the split
1331     for( i = 0; i < m; i++ )
1332     {
1333         lc[i] = 0;
1334         rc[i] = rc0[i];
1335     }
1336
1337     // compensate for missing values
1338     for( i = n1; i < n; i++ )
1339         rc[responses[sorted[i].i]]--;
1340
1341     if( !priors )
1342     {
1343         int L = 0, R = n1;
1344
1345         for( i = 0; i < m; i++ )
1346             rsum2 += (double)rc[i]*rc[i];
1347
1348         for( i = 0; i < n1 - 1; i++ )
1349         {
1350             int idx = responses[sorted[i].i];
1351             int lv, rv;
1352             L++; R--;
1353             lv = lc[idx]; rv = rc[idx];
1354             lsum2 += lv*2 + 1;
1355             rsum2 -= rv*2 - 1;
1356             lc[idx] = lv + 1; rc[idx] = rv - 1;
1357
1358             if( sorted[i].val + epsilon < sorted[i+1].val )
1359             {
1360                 double val = lsum2/L + rsum2/R;
1361                 if( best_val < val )
1362                 {
1363                     best_val = val;
1364                     best_i = i;
1365                 }
1366             }
1367         }
1368     }
1369     else
1370     {
1371         double L = 0, R = 0;
1372         for( i = 0; i < m; i++ )
1373         {
1374             double wv = rc[i]*priors[i];
1375             R += wv;
1376             rsum2 += wv*wv;
1377         }
1378
1379         for( i = 0; i < n1 - 1; i++ )
1380         {
1381             int idx = responses[sorted[i].i];
1382             int lv, rv;
1383             double p = priors[idx], p2 = p*p;
1384             L += p; R -= p;
1385             lv = lc[idx]; rv = rc[idx];
1386             lsum2 += p2*(lv*2 + 1);
1387             rsum2 -= p2*(rv*2 - 1);
1388             lc[idx] = lv + 1; rc[idx] = rv - 1;
1389
1390             if( sorted[i].val + epsilon < sorted[i+1].val )
1391             {
1392                 double val = lsum2/L + rsum2/R;
1393                 if( best_val < val )
1394                 {
1395                     best_val = val;
1396                     best_i = i;
1397                 }
1398             }
1399         }
1400     }
1401
1402     return best_i >= 0 ? data->new_split_ord( vi,
1403         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1404         0, (float)best_val ) : 0;
1405 }
1406
1407
1408 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1409                                 int* csums, int k, int* labels )
1410 {
1411     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
1412     int iters = 0, max_iters = 100;
1413     int i, j, idx;
1414     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1415     double *v_weights = buf, *c_weights = buf + k;
1416     bool modified = true;
1417     CvRNG* r = &data->rng;
1418
1419     // assign labels randomly
1420     for( i = idx = 0; i < n; i++ )
1421     {
1422         int sum = 0;
1423         const int* v = vectors + i*m;
1424         labels[i] = idx++;
1425         idx &= idx < k ? -1 : 0;
1426
1427         // compute weight of each vector
1428         for( j = 0; j < m; j++ )
1429             sum += v[j];
1430         v_weights[i] = sum ? 1./sum : 0.;
1431     }
1432
1433     for( i = 0; i < n; i++ )
1434     {
1435         int i1 = cvRandInt(r) % n;
1436         int i2 = cvRandInt(r) % n;
1437         CV_SWAP( labels[i1], labels[i2], j );
1438     }
1439
1440     for( iters = 0; iters <= max_iters; iters++ )
1441     {
1442         // calculate csums
1443         for( i = 0; i < k; i++ )
1444         {
1445             for( j = 0; j < m; j++ )
1446                 csums[i*m + j] = 0;
1447         }
1448
1449         for( i = 0; i < n; i++ )
1450         {
1451             const int* v = vectors + i*m;
1452             int* s = csums + labels[i]*m;
1453             for( j = 0; j < m; j++ )
1454                 s[j] += v[j];
1455         }
1456
1457         // exit the loop here, when we have up-to-date csums
1458         if( iters == max_iters || !modified )
1459             break;
1460
1461         modified = false;
1462
1463         // calculate weight of each cluster
1464         for( i = 0; i < k; i++ )
1465         {
1466             const int* s = csums + i*m;
1467             int sum = 0;
1468             for( j = 0; j < m; j++ )
1469                 sum += s[j];
1470             c_weights[i] = sum ? 1./sum : 0;
1471         }
1472
1473         // now for each vector determine the closest cluster
1474         for( i = 0; i < n; i++ )
1475         {
1476             const int* v = vectors + i*m;
1477             double alpha = v_weights[i];
1478             double min_dist2 = DBL_MAX;
1479             int min_idx = -1;
1480
1481             for( idx = 0; idx < k; idx++ )
1482             {
1483                 const int* s = csums + idx*m;
1484                 double dist2 = 0., beta = c_weights[idx];
1485                 for( j = 0; j < m; j++ )
1486                 {
1487                     double t = v[j]*alpha - s[j]*beta;
1488                     dist2 += t*t;
1489                 }
1490                 if( min_dist2 > dist2 )
1491                 {
1492                     min_dist2 = dist2;
1493                     min_idx = idx;
1494                 }
1495             }
1496
1497             if( min_idx != labels[i] )
1498                 modified = true;
1499             labels[i] = min_idx;
1500         }
1501     }
1502 }
1503
1504
1505 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1506 {
1507     CvDTreeSplit* split;
1508     const int* labels = data->get_cat_var_data(node, vi);
1509     const int* responses = data->get_class_labels(node);
1510     int ci = data->get_var_type(vi);
1511     int n = node->sample_count;
1512     int m = data->get_num_classes();
1513     int _mi = data->cat_count->data.i[ci], mi = _mi;
1514     const int* rc0 = data->counts->data.i;
1515     int* lc = (int*)(rc0 + m);
1516     int* rc = lc + m;
1517     int* _cjk = rc + m*2, *cjk = _cjk;
1518     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1519     int* cluster_labels = 0;
1520     int** int_ptr = 0;
1521     int i, j, k, idx;
1522     double L = 0, R = 0;
1523     double best_val = 0;
1524     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1525     const double* priors = data->priors->data.db;
1526
1527     // init array of counters:
1528     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1529     for( j = -1; j < mi; j++ )
1530         for( k = 0; k < m; k++ )
1531             cjk[j*m + k] = 0;
1532
1533     for( i = 0; i < n; i++ )
1534     {
1535         j = labels[i];
1536         k = responses[i];
1537         cjk[j*m + k]++;
1538     }
1539
1540     if( m > 2 )
1541     {
1542         if( mi > data->params.max_categories )
1543         {
1544             mi = MIN(data->params.max_categories, n);
1545             cjk += _mi*m;
1546             cluster_labels = cjk + mi*m;
1547             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1548         }
1549         subset_i = 1;
1550         subset_n = 1 << mi;
1551     }
1552     else
1553     {
1554         assert( m == 2 );
1555         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1556         for( j = 0; j < mi; j++ )
1557             int_ptr[j] = cjk + j*2 + 1;
1558         icvSortIntPtr( int_ptr, mi, 0 );
1559         subset_i = 0;
1560         subset_n = mi;
1561     }
1562
1563     for( k = 0; k < m; k++ )
1564     {
1565         int sum = 0;
1566         for( j = 0; j < mi; j++ )
1567             sum += cjk[j*m + k];
1568         rc[k] = sum;
1569         lc[k] = 0;
1570     }
1571
1572     for( j = 0; j < mi; j++ )
1573     {
1574         double sum = 0;
1575         for( k = 0; k < m; k++ )
1576             sum += cjk[j*m + k]*priors[k];
1577         c_weights[j] = sum;
1578         R += c_weights[j];
1579     }
1580
1581     for( ; subset_i < subset_n; subset_i++ )
1582     {
1583         double weight;
1584         int* crow;
1585         double lsum2 = 0, rsum2 = 0;
1586
1587         if( m == 2 )
1588             idx = (int)(int_ptr[subset_i] - cjk)/2;
1589         else
1590         {
1591             int graycode = (subset_i>>1)^subset_i;
1592             int diff = graycode ^ prevcode;
1593
1594             // determine index of the changed bit.
1595             Cv32suf u;
1596             idx = diff >= (1 << 16) ? 16 : 0;
1597             u.f = (float)(((diff >> 16) | diff) & 65535);
1598             idx += (u.i >> 23) - 127;
1599             subtract = graycode < prevcode;
1600             prevcode = graycode;
1601         }
1602
1603         crow = cjk + idx*m;
1604         weight = c_weights[idx];
1605         if( weight < FLT_EPSILON )
1606             continue;
1607
1608         if( !subtract )
1609         {
1610             for( k = 0; k < m; k++ )
1611             {
1612                 int t = crow[k];
1613                 int lval = lc[k] + t;
1614                 int rval = rc[k] - t;
1615                 double p = priors[k], p2 = p*p;
1616                 lsum2 += p2*lval*lval;
1617                 rsum2 += p2*rval*rval;
1618                 lc[k] = lval; rc[k] = rval;
1619             }
1620             L += weight;
1621             R -= weight;
1622         }
1623         else
1624         {
1625             for( k = 0; k < m; k++ )
1626             {
1627                 int t = crow[k];
1628                 int lval = lc[k] - t;
1629                 int rval = rc[k] + t;
1630                 double p = priors[k], p2 = p*p;
1631                 lsum2 += p2*lval*lval;
1632                 rsum2 += p2*rval*rval;
1633                 lc[k] = lval; rc[k] = rval;
1634             }
1635             L -= weight;
1636             R += weight;
1637         }
1638
1639         if( L > FLT_EPSILON && R > FLT_EPSILON )
1640         {
1641             double val = lsum2/L + rsum2/R;
1642             if( best_val < val )
1643             {
1644                 best_val = val;
1645                 best_subset = subset_i;
1646             }
1647         }
1648     }
1649
1650     if( best_subset < 0 )
1651         return 0;
1652
1653     split = data->new_split_cat( vi, (float)best_val );
1654
1655     if( m == 2 )
1656     {
1657         for( i = 0; i <= best_subset; i++ )
1658         {
1659             idx = (int)(int_ptr[i] - cjk) >> 1;
1660             split->subset[idx >> 5] |= 1 << (idx & 31);
1661         }
1662     }
1663     else
1664     {
1665         for( i = 0; i < _mi; i++ )
1666         {
1667             idx = cluster_labels ? cluster_labels[i] : i;
1668             if( best_subset & (1 << idx) )
1669                 split->subset[i >> 5] |= 1 << (i & 31);
1670         }
1671     }
1672
1673     return split;
1674 }
1675
1676
1677 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1678 {
1679     const float epsilon = FLT_EPSILON*2;
1680     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1681     const float* responses = data->get_ord_responses(node);
1682     int n = node->sample_count;
1683     int n1 = node->get_num_valid(vi);
1684     int i, best_i = -1;
1685     double best_val = 0, lsum = 0, rsum = node->value*n;
1686     int L = 0, R = n1;
1687
1688     // compensate for missing values
1689     for( i = n1; i < n; i++ )
1690         rsum -= responses[sorted[i].i];
1691
1692     // find the optimal split
1693     for( i = 0; i < n1 - 1; i++ )
1694     {
1695         float t = responses[sorted[i].i];
1696         L++; R--;
1697         lsum += t;
1698         rsum -= t;
1699
1700         if( sorted[i].val + epsilon < sorted[i+1].val )
1701         {
1702             double val = lsum*lsum/L + rsum*rsum/R;
1703             if( best_val < val )
1704             {
1705                 best_val = val;
1706                 best_i = i;
1707             }
1708         }
1709     }
1710
1711     return best_i >= 0 ? data->new_split_ord( vi,
1712         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1713         0, (float)best_val ) : 0;
1714 }
1715
1716
1717 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1718 {
1719     CvDTreeSplit* split;
1720     const int* labels = data->get_cat_var_data(node, vi);
1721     const float* responses = data->get_ord_responses(node);
1722     int ci = data->get_var_type(vi);
1723     int n = node->sample_count;
1724     int mi = data->cat_count->data.i[ci];
1725     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1726     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1727     double** sum_ptr = 0;
1728     int i, L = 0, R = 0;
1729     double best_val = 0, lsum = 0, rsum = 0;
1730     int best_subset = -1, subset_i;
1731
1732     for( i = -1; i < mi; i++ )
1733         sum[i] = counts[i] = 0;
1734
1735     // calculate sum response and weight of each category of the input var
1736     for( i = 0; i < n; i++ )
1737     {
1738         int idx = labels[i];
1739         double s = sum[idx] + responses[i];
1740         int nc = counts[idx] + 1;
1741         sum[idx] = s;
1742         counts[idx] = nc;
1743     }
1744
1745     // calculate average response in each category
1746     for( i = 0; i < mi; i++ )
1747     {
1748         R += counts[i];
1749         rsum += sum[i];
1750         sum[i] /= MAX(counts[i],1);
1751         sum_ptr[i] = sum + i;
1752     }
1753
1754     icvSortDblPtr( sum_ptr, mi, 0 );
1755
1756     // revert back to unnormalized sum
1757     // (there should be a very little loss of accuracy)
1758     for( i = 0; i < mi; i++ )
1759         sum[i] *= counts[i];
1760
1761     for( subset_i = 0; subset_i < mi-1; subset_i++ )
1762     {
1763         int idx = (int)(sum_ptr[subset_i] - sum);
1764         int ni = counts[idx];
1765
1766         if( ni )
1767         {
1768             double s = sum[idx];
1769             lsum += s; L += ni;
1770             rsum -= s; R -= ni;
1771             
1772             if( L && R )
1773             {
1774                 double val = lsum*lsum/L + rsum*rsum/R;
1775                 if( best_val < val )
1776                 {
1777                     best_val = val;
1778                     best_subset = subset_i;
1779                 }
1780             }
1781         }
1782     }
1783
1784     if( best_subset < 0 )
1785         return 0;
1786
1787     split = data->new_split_cat( vi, (float)best_val );
1788     for( i = 0; i <= best_subset; i++ )
1789     {
1790         int idx = (int)(sum_ptr[i] - sum);
1791         split->subset[idx >> 5] |= 1 << (idx & 31);
1792     }
1793
1794     return split;
1795 }
1796
1797
1798 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
1799 {
1800     const float epsilon = FLT_EPSILON*2;
1801     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1802     const char* dir = (char*)data->direction->data.ptr;
1803     int n1 = node->get_num_valid(vi);
1804     // LL - number of samples that both the primary and the surrogate splits send to the left
1805     // LR - ... primary split sends to the left and the surrogate split sends to the right
1806     // RL - ... primary split sends to the right and the surrogate split sends to the left
1807     // RR - ... both send to the right
1808     int i, best_i = -1, best_inversed = 0;
1809     double best_val; 
1810
1811     if( !data->have_priors )
1812     {
1813         int LL = 0, RL = 0, LR, RR;
1814         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
1815         int sum = 0, sum_abs = 0;
1816         
1817         for( i = 0; i < n1; i++ )
1818         {
1819             int d = dir[sorted[i].i];
1820             sum += d; sum_abs += d & 1;
1821         }
1822
1823         // sum_abs = R + L; sum = R - L
1824         RR = (sum_abs + sum) >> 1;
1825         LR = (sum_abs - sum) >> 1;
1826
1827         // initially all the samples are sent to the right by the surrogate split,
1828         // LR of them are sent to the left by primary split, and RR - to the right.
1829         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1830         for( i = 0; i < n1 - 1; i++ )
1831         {
1832             int d = dir[sorted[i].i];
1833
1834             if( d < 0 )
1835             {
1836                 LL++; LR--;
1837                 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1838                 {
1839                     best_val = LL + RR;
1840                     best_i = i; best_inversed = 0;
1841                 }
1842             }
1843             else if( d > 0 )
1844             {
1845                 RL++; RR--;
1846                 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1847                 {
1848                     best_val = RL + LR;
1849                     best_i = i; best_inversed = 1;
1850                 }
1851             }
1852         }
1853         best_val = _best_val;
1854     }
1855     else
1856     {
1857         double LL = 0, RL = 0, LR, RR;
1858         double worst_val = node->maxlr;
1859         double sum = 0, sum_abs = 0;
1860         const double* priors = data->priors->data.db;
1861         const int* responses = data->get_class_labels(node);
1862         best_val = worst_val;
1863         
1864         for( i = 0; i < n1; i++ )
1865         {
1866             int idx = sorted[i].i;
1867             double w = priors[responses[idx]];
1868             int d = dir[idx];
1869             sum += d*w; sum_abs += (d & 1)*w;
1870         }
1871
1872         // sum_abs = R + L; sum = R - L
1873         RR = (sum_abs + sum)*0.5;
1874         LR = (sum_abs - sum)*0.5;
1875
1876         // initially all the samples are sent to the right by the surrogate split,
1877         // LR of them are sent to the left by primary split, and RR - to the right.
1878         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1879         for( i = 0; i < n1 - 1; i++ )
1880         {
1881             int idx = sorted[i].i;
1882             double w = priors[responses[idx]];
1883             int d = dir[idx];
1884
1885             if( d < 0 )
1886             {
1887                 LL += w; LR -= w;
1888                 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1889                 {
1890                     best_val = LL + RR;
1891                     best_i = i; best_inversed = 0;
1892                 }
1893             }
1894             else if( d > 0 )
1895             {
1896                 RL += w; RR -= w;
1897                 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1898                 {
1899                     best_val = RL + LR;
1900                     best_i = i; best_inversed = 1;
1901                 }
1902             }
1903         }
1904     }
1905
1906     return best_i >= 0 ? data->new_split_ord( vi,
1907         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1908         best_inversed, (float)best_val ) : 0;
1909 }
1910
1911
1912 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
1913 {
1914     const int* labels = data->get_cat_var_data(node, vi);
1915     const char* dir = (char*)data->direction->data.ptr;
1916     int n = node->sample_count;
1917     // LL - number of samples that both the primary and the surrogate splits send to the left
1918     // LR - ... primary split sends to the left and the surrogate split sends to the right
1919     // RL - ... primary split sends to the right and the surrogate split sends to the left
1920     // RR - ... both send to the right
1921     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
1922     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
1923     double best_val = 0;
1924     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
1925     double* rc = lc + mi + 1;
1926     
1927     for( i = -1; i < mi; i++ )
1928         lc[i] = rc[i] = 0;
1929
1930     // for each category calculate the weight of samples
1931     // sent to the left (lc) and to the right (rc) by the primary split
1932     if( !data->have_priors )
1933     {
1934         int* _lc = data->counts->data.i + 1;
1935         int* _rc = _lc + mi + 1;
1936
1937         for( i = -1; i < mi; i++ )
1938             _lc[i] = _rc[i] = 0;
1939
1940         for( i = 0; i < n; i++ )
1941         {
1942             int idx = labels[i];
1943             int d = dir[i];
1944             int sum = _lc[idx] + d;
1945             int sum_abs = _rc[idx] + (d & 1);
1946             _lc[idx] = sum; _rc[idx] = sum_abs;
1947         }
1948
1949         for( i = 0; i < mi; i++ )
1950         {
1951             int sum = _lc[i];
1952             int sum_abs = _rc[i];
1953             lc[i] = (sum_abs - sum) >> 1;
1954             rc[i] = (sum_abs + sum) >> 1;
1955         }
1956     }
1957     else
1958     {
1959         const double* priors = data->priors->data.db;
1960         const int* responses = data->get_class_labels(node);
1961
1962         for( i = 0; i < n; i++ )
1963         {
1964             int idx = labels[i];
1965             double w = priors[responses[i]];
1966             int d = dir[i];
1967             double sum = lc[idx] + d*w;
1968             double sum_abs = rc[idx] + (d & 1)*w;
1969             lc[idx] = sum; rc[idx] = sum_abs;
1970         }
1971
1972         for( i = 0; i < mi; i++ )
1973         {
1974             double sum = lc[i];
1975             double sum_abs = rc[i];
1976             lc[i] = (sum_abs - sum) * 0.5;
1977             rc[i] = (sum_abs + sum) * 0.5;
1978         }
1979     }
1980
1981     // 2. now form the split.
1982     // in each category send all the samples to the same direction as majority
1983     for( i = 0; i < mi; i++ )
1984     {
1985         double lval = lc[i], rval = rc[i];
1986         if( lval > rval )
1987         {
1988             split->subset[i >> 5] |= 1 << (i & 31);
1989             best_val += lval;
1990         }
1991         else
1992             best_val += rval;
1993     }
1994
1995     split->quality = (float)best_val;
1996     if( split->quality <= node->maxlr )
1997         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
1998
1999     return split;
2000 }
2001
2002
2003 void CvDTree::calc_node_value( CvDTreeNode* node )
2004 {
2005     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2006     const int* cv_labels = data->get_cv_labels(node);
2007
2008     if( data->is_classifier )
2009     {
2010         // in case of classification tree:
2011         //  * node value is the label of the class that has the largest weight in the node.
2012         //  * node risk is the weighted number of misclassified samples,
2013         //  * j-th cross-validation fold value and risk are calculated as above,
2014         //    but using the samples with cv_labels(*)!=j.
2015         //  * j-th cross-validation fold error is calculated as the weighted number of
2016         //    misclassified samples with cv_labels(*)==j.
2017
2018         // compute the number of instances of each class
2019         int* cls_count = data->counts->data.i;
2020         const int* responses = data->get_class_labels(node);
2021         int m = data->get_num_classes();
2022         int* cv_cls_count = cls_count + m;
2023         double max_val = -1, total_weight = 0;
2024         int max_k = -1;
2025         double* priors = data->priors->data.db;
2026
2027         for( k = 0; k < m; k++ )
2028             cls_count[k] = 0;
2029
2030         if( cv_n == 0 )
2031         {
2032             for( i = 0; i < n; i++ )
2033                 cls_count[responses[i]]++;
2034         }
2035         else
2036         {
2037             for( j = 0; j < cv_n; j++ )
2038                 for( k = 0; k < m; k++ )
2039                     cv_cls_count[j*m + k] = 0;
2040
2041             for( i = 0; i < n; i++ )
2042             {
2043                 j = cv_labels[i]; k = responses[i];
2044                 cv_cls_count[j*m + k]++;
2045             }
2046
2047             for( j = 0; j < cv_n; j++ )
2048                 for( k = 0; k < m; k++ )
2049                     cls_count[k] += cv_cls_count[j*m + k];
2050         }
2051
2052         for( k = 0; k < m; k++ )
2053         {
2054             double val = cls_count[k]*priors[k];
2055             total_weight += val;
2056             if( max_val < val )
2057             {
2058                 max_val = val;
2059                 max_k = k;
2060             }
2061         }
2062
2063         node->class_idx = max_k;
2064         node->value = data->cat_map->data.i[
2065             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2066         node->node_risk = total_weight - max_val;
2067
2068         for( j = 0; j < cv_n; j++ )
2069         {
2070             double sum_k = 0, sum = 0, max_val_k = 0;
2071             max_val = -1; max_k = -1;
2072
2073             for( k = 0; k < m; k++ )
2074             {
2075                 double w = priors[k];
2076                 double val_k = cv_cls_count[j*m + k]*w;
2077                 double val = cls_count[k]*w - val_k;
2078                 sum_k += val_k;
2079                 sum += val;
2080                 if( max_val < val )
2081                 {
2082                     max_val = val;
2083                     max_val_k = val_k;
2084                     max_k = k;
2085                 }
2086             }
2087
2088             node->cv_Tn[j] = INT_MAX;
2089             node->cv_node_risk[j] = sum - max_val;
2090             node->cv_node_error[j] = sum_k - max_val_k;
2091         }
2092     }
2093     else
2094     {
2095         // in case of regression tree:
2096         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2097         //    n is the number of samples in the node.
2098         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2099         //  * j-th cross-validation fold value and risk are calculated as above,
2100         //    but using the samples with cv_labels(*)!=j.
2101         //  * j-th cross-validation fold error is calculated
2102         //    using samples with cv_labels(*)==j as the test subset:
2103         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2104         //    where node_value_j is the node value calculated
2105         //    as described in the previous bullet, and summation is done
2106         //    over the samples with cv_labels(*)==j.
2107
2108         double sum = 0, sum2 = 0;
2109         const float* values = data->get_ord_responses(node);
2110         double *cv_sum = 0, *cv_sum2 = 0;
2111         int* cv_count = 0;
2112         
2113         if( cv_n == 0 )
2114         {
2115             // if cross-validation is not used, we even do not compute node_risk
2116             // (so the tree sequence T1>...>{root} may not be built).
2117             for( i = 0; i < n; i++ )
2118                 sum += values[i];
2119         }
2120         else
2121         {
2122             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2123             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2124             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2125
2126             for( j = 0; j < cv_n; j++ )
2127             {
2128                 cv_sum[j] = cv_sum2[j] = 0.;
2129                 cv_count[j] = 0;
2130             }
2131
2132             for( i = 0; i < n; i++ )
2133             {
2134                 j = cv_labels[i];
2135                 double t = values[i];
2136                 double s = cv_sum[j] + t;
2137                 double s2 = cv_sum2[j] + t*t;
2138                 int nc = cv_count[j] + 1;
2139                 cv_sum[j] = s;
2140                 cv_sum2[j] = s2;
2141                 cv_count[j] = nc;
2142             }
2143
2144             for( j = 0; j < cv_n; j++ )
2145             {
2146                 sum += cv_sum[j];
2147                 sum2 += cv_sum2[j];
2148             }
2149
2150             node->node_risk = sum2 - (sum/n)*sum;
2151         }
2152
2153         node->value = sum/n;
2154
2155         for( j = 0; j < cv_n; j++ )
2156         {
2157             double s = cv_sum[j], si = sum - s;
2158             double s2 = cv_sum2[j], s2i = sum2 - s2;
2159             int c = cv_count[j], ci = n - c;
2160             double r = si/MAX(ci,1);
2161             node->cv_node_risk[j] = s2i - r*r*ci;
2162             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2163             node->cv_Tn[j] = INT_MAX;
2164         }
2165     }
2166 }
2167
2168
2169 void CvDTree::complete_node_dir( CvDTreeNode* node )
2170 {
2171     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2172     int nz = n - node->get_num_valid(node->split->var_idx);
2173     char* dir = (char*)data->direction->data.ptr;
2174
2175     // try to complete direction using surrogate splits
2176     if( nz && data->params.use_surrogates )
2177     {
2178         CvDTreeSplit* split = node->split->next;
2179         for( ; split != 0 && nz; split = split->next )
2180         {
2181             int inversed_mask = split->inversed ? -1 : 0;
2182             vi = split->var_idx;
2183
2184             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2185             {
2186                 const int* labels = data->get_cat_var_data(node, vi);
2187                 const int* subset = split->subset;
2188
2189                 for( i = 0; i < n; i++ )
2190                 {
2191                     int idx;
2192                     if( !dir[i] && (idx = labels[i]) >= 0 )
2193                     {
2194                         int d = DTREE_CAT_DIR(idx,subset);
2195                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2196                         if( --nz )
2197                             break;
2198                     }
2199                 }
2200             }
2201             else // split on ordered var
2202             {
2203                 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2204                 int split_point = split->ord.split_point;
2205                 int n1 = node->get_num_valid(vi);
2206
2207                 assert( 0 <= split_point && split_point < n-1 );
2208
2209                 for( i = 0; i < n1; i++ )
2210                 {
2211                     int idx = sorted[i].i;
2212                     if( !dir[idx] )
2213                     {
2214                         int d = i <= split_point ? -1 : 1;
2215                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2216                         if( --nz )
2217                             break;
2218                     }
2219                 }
2220             }
2221         }
2222     }
2223
2224     // find the default direction for the rest
2225     if( nz )
2226     {
2227         for( i = nr = 0; i < n; i++ )
2228             nr += dir[i] > 0;
2229         nl = n - nr - nz;
2230         d0 = nl > nr ? -1 : nr > nl;
2231     }
2232
2233     // make sure that every sample is directed either to the left or to the right
2234     for( i = 0; i < n; i++ )
2235     {
2236         int d = dir[i];
2237         if( !d )
2238         {
2239             d = d0;
2240             if( !d )
2241                 d = d1, d1 = -d1;
2242         }
2243         d = d > 0;
2244         dir[i] = (char)d; // remap (-1,1) to (0,1)
2245     }
2246 }
2247
2248
2249 void CvDTree::split_node_data( CvDTreeNode* node )
2250 {
2251     int vi, i, n = node->sample_count, nl, nr;
2252     char* dir = (char*)data->direction->data.ptr;
2253     CvDTreeNode *left = 0, *right = 0;
2254     int* new_idx = data->split_buf->data.i;
2255     int new_buf_idx = data->get_child_buf_idx( node );
2256
2257     complete_node_dir(node);
2258
2259     for( i = nl = nr = 0; i < n; i++ )
2260     {
2261         int d = dir[i];
2262         // initialize new indices for splitting ordered variables
2263         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2264         nr += d;
2265         nl += d^1;
2266     }
2267
2268     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2269     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2270         (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
2271
2272     // split ordered variables, keep both halves sorted.
2273     for( vi = 0; vi < data->var_count; vi++ )
2274     {
2275         int ci = data->get_var_type(vi);
2276         int n1 = node->get_num_valid(vi);
2277         CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2278         CvPair32s32f tl, tr;
2279
2280         if( ci >= 0 )
2281             continue;
2282
2283         src = data->get_ord_var_data(node, vi);
2284         ldst0 = ldst = data->get_ord_var_data(left, vi);
2285         rdst0 = rdst = data->get_ord_var_data(right, vi);
2286         tl = ldst0[nl]; tr = rdst0[nr];
2287
2288         // split sorted
2289         for( i = 0; i < n1; i++ )
2290         {
2291             int idx = src[i].i;
2292             float val = src[i].val;
2293             int d = dir[idx];
2294             idx = new_idx[idx];
2295             ldst->i = rdst->i = idx;
2296             ldst->val = rdst->val = val;
2297             ldst += d^1;
2298             rdst += d;
2299         }
2300
2301         left->set_num_valid(vi, (int)(ldst - ldst0));
2302         right->set_num_valid(vi, (int)(rdst - rdst0));
2303
2304         // split missing
2305         for( ; i < n; i++ )
2306         {
2307             int idx = src[i].i;
2308             int d = dir[idx];
2309             idx = new_idx[idx];
2310             ldst->i = rdst->i = idx;
2311             ldst->val = rdst->val = ord_nan;
2312             ldst += d^1;
2313             rdst += d;
2314         }
2315
2316         ldst0[nl] = tl; rdst0[nr] = tr;
2317     }
2318
2319     // split categorical vars, responses and cv_labels using new_idx relocation table
2320     for( vi = 0; vi <= data->var_count + data->have_cv_labels + data->have_weights; vi++ )
2321     {
2322         int ci = data->get_var_type(vi);
2323         int n1 = node->get_num_valid(vi), nr1 = 0;
2324         int *src, *ldst0, *rdst0, *ldst, *rdst;
2325         int tl, tr;
2326
2327         if( ci < 0 )
2328             continue;
2329
2330         src = data->get_cat_var_data(node, vi);
2331         ldst0 = ldst = data->get_cat_var_data(left, vi);
2332         rdst0 = rdst = data->get_cat_var_data(right, vi);
2333         tl = ldst0[nl]; tr = rdst0[nr];
2334
2335         for( i = 0; i < n; i++ )
2336         {
2337             int d = dir[i];
2338             int val = src[i];
2339             *ldst = *rdst = val;
2340             ldst += d^1;
2341             rdst += d;
2342             nr1 += (val >= 0)&d;
2343         }
2344
2345         if( vi < data->var_count )
2346         {
2347             left->set_num_valid(vi, n1 - nr1);
2348             right->set_num_valid(vi, nr1);
2349         }
2350
2351         ldst0[nl] = tl; rdst0[nr] = tr;
2352     }
2353
2354     // deallocate the parent node data that is not needed anymore
2355     data->free_node_data(node);
2356 }
2357
2358
2359 void CvDTree::prune_cv()
2360 {
2361     CvMat* ab = 0;
2362     CvMat* temp = 0;
2363     CvMat* err_jk = 0;
2364     
2365     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2366     // 2. choose the best tree index (if need, apply 1SE rule).
2367     // 3. store the best index and cut the branches.
2368
2369     CV_FUNCNAME( "CvDTree::prune_cv" );
2370
2371     __BEGIN__;
2372
2373     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2374     // currently, 1SE for regression is not implemented
2375     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2376     double* err;
2377     double min_err = 0, min_err_se = 0;
2378     int min_idx = -1;
2379     
2380     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2381
2382     // build the main tree sequence, calculate alpha's
2383     for(;;tree_count++)
2384     {
2385         double min_alpha = update_tree_rnc(tree_count, -1);
2386         if( cut_tree(tree_count, -1, min_alpha) )
2387             break;
2388
2389         if( ab->cols <= tree_count )
2390         {
2391             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2392             for( ti = 0; ti < ab->cols; ti++ )
2393                 temp->data.db[ti] = ab->data.db[ti];
2394             cvReleaseMat( &ab );
2395             ab = temp;
2396             temp = 0;
2397         }
2398
2399         ab->data.db[tree_count] = min_alpha;
2400     }
2401
2402     ab->data.db[0] = 0.;
2403     for( ti = 1; ti < tree_count-1; ti++ )
2404         ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2405     ab->data.db[tree_count-1] = DBL_MAX*0.5;
2406
2407     CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2408     err = err_jk->data.db;
2409
2410     for( j = 0; j < cv_n; j++ )
2411     {
2412         int tj = 0, tk = 0;
2413         for( ; tk < tree_count; tj++ )
2414         {
2415             double min_alpha = update_tree_rnc(tj, j);
2416             if( cut_tree(tj, j, min_alpha) )
2417                 min_alpha = DBL_MAX;
2418
2419             for( ; tk < tree_count; tk++ )
2420             {
2421                 if( ab->data.db[tk] > min_alpha )
2422                     break;
2423                 err[j*tree_count + tk] = root->tree_error;
2424             }
2425         }
2426     }
2427
2428     for( ti = 0; ti < tree_count; ti++ )
2429     {
2430         double sum_err = 0;
2431         for( j = 0; j < cv_n; j++ )
2432             sum_err += err[j*tree_count + ti];
2433         if( ti == 0 || sum_err < min_err )
2434         {
2435             min_err = sum_err;
2436             min_idx = ti;
2437             if( use_1se )
2438                 min_err_se = sqrt( sum_err*(n - sum_err) );
2439         }
2440         else if( sum_err < min_err + min_err_se )
2441             min_idx = ti;
2442     }
2443
2444     pruned_tree_idx = min_idx;
2445     free_prune_data(data->params.truncate_pruned_tree != 0);
2446
2447     __END__;
2448
2449     cvReleaseMat( &err_jk );
2450     cvReleaseMat( &ab );
2451     cvReleaseMat( &temp );
2452 }
2453
2454
2455 double CvDTree::update_tree_rnc( int T, int fold )
2456 {
2457     CvDTreeNode* node = root;
2458     double min_alpha = DBL_MAX;
2459     
2460     for(;;)
2461     {
2462         CvDTreeNode* parent;
2463         for(;;)
2464         {
2465             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2466             if( t <= T || !node->left )
2467             {
2468                 node->complexity = 1;
2469                 node->tree_risk = node->node_risk;
2470                 node->tree_error = 0.;
2471                 if( fold >= 0 )
2472                 {
2473                     node->tree_risk = node->cv_node_risk[fold];
2474                     node->tree_error = node->cv_node_error[fold];
2475                 }
2476                 break;
2477             }
2478             node = node->left;
2479         }
2480         
2481         for( parent = node->parent; parent && parent->right == node;
2482             node = parent, parent = parent->parent )
2483         {
2484             parent->complexity += node->complexity;
2485             parent->tree_risk += node->tree_risk;
2486             parent->tree_error += node->tree_error;
2487
2488             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2489                 - parent->tree_risk)/(parent->complexity - 1);
2490             min_alpha = MIN( min_alpha, parent->alpha );
2491         }
2492
2493         if( !parent )
2494             break;
2495
2496         parent->complexity = node->complexity;
2497         parent->tree_risk = node->tree_risk;
2498         parent->tree_error = node->tree_error;
2499         node = parent->right;
2500     }
2501
2502     return min_alpha;
2503 }
2504
2505
2506 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2507 {
2508     CvDTreeNode* node = root;
2509     if( !node->left )
2510         return 1;
2511
2512     for(;;)
2513     {
2514         CvDTreeNode* parent;
2515         for(;;)
2516         {
2517             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2518             if( t <= T || !node->left )
2519                 break;
2520             if( node->alpha <= min_alpha + FLT_EPSILON )
2521             {
2522                 if( fold >= 0 )
2523                     node->cv_Tn[fold] = T;
2524                 else
2525                     node->Tn = T;
2526                 if( node == root )
2527                     return 1;
2528                 break;
2529             }
2530             node = node->left;
2531         }
2532         
2533         for( parent = node->parent; parent && parent->right == node;
2534             node = parent, parent = parent->parent )
2535             ;
2536
2537         if( !parent )
2538             break;
2539
2540         node = parent->right;
2541     }
2542
2543     return 0;
2544 }
2545
2546
2547 void CvDTree::free_prune_data(bool cut_tree)
2548 {
2549     CvDTreeNode* node = root;
2550     
2551     for(;;)
2552     {
2553         CvDTreeNode* parent;
2554         for(;;)
2555         {
2556             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2557             // as we will clear the whole cross-validation heap at the end
2558             node->cv_Tn = 0;
2559             node->cv_node_error = node->cv_node_risk = 0;
2560             if( !node->left )
2561                 break;
2562             node = node->left;
2563         }
2564         
2565         for( parent = node->parent; parent && parent->right == node;
2566             node = parent, parent = parent->parent )
2567         {
2568             if( cut_tree && parent->Tn <= pruned_tree_idx )
2569             {
2570                 data->free_node( parent->left );
2571                 data->free_node( parent->right );
2572                 parent->left = parent->right = 0;
2573             }
2574         }
2575
2576         if( !parent )
2577             break;
2578
2579         node = parent->right;
2580     }
2581
2582     if( data->cv_heap )
2583         cvClearSet( data->cv_heap );
2584 }
2585
2586
2587 void CvDTree::free_tree()
2588 {
2589     if( root && data && data->shared )
2590     {
2591         pruned_tree_idx = INT_MIN;
2592         free_prune_data(true);
2593         data->free_node(root);
2594         root = 0;
2595     }
2596 }
2597
2598
2599 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2600     const CvMat* _missing, bool preprocessed_input ) const
2601 {
2602     CvDTreeNode* result = 0;
2603     int* catbuf = 0;
2604
2605     CV_FUNCNAME( "CvDTree::predict" );
2606
2607     __BEGIN__;
2608
2609     int i, step, mstep = 0;
2610     const float* sample;
2611     const uchar* m = 0;
2612     CvDTreeNode* node = root;
2613     const int* vtype;
2614     const int* vidx;
2615     const int* cmap;
2616     const int* cofs;
2617
2618     if( !node )
2619         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2620
2621     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2622         _sample->cols != 1 && _sample->rows != 1 ||
2623         _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2624         _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2625             CV_ERROR( CV_StsBadArg,
2626         "the input sample must be 1d floating-point vector with the same "
2627         "number of elements as the total number of variables used for training" );
2628
2629     sample = _sample->data.fl;
2630     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
2631
2632     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2633     {
2634         int n = data->cat_count->cols;
2635         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2636         for( i = 0; i < n; i++ )
2637             catbuf[i] = -1;
2638     }
2639
2640     if( _missing )
2641     {
2642         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2643         !CV_ARE_SIZES_EQ(_missing, _sample) )
2644             CV_ERROR( CV_StsBadArg,
2645         "the missing data mask must be 8-bit vector of the same size as input sample" );
2646         m = _missing->data.ptr;
2647         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2648     }
2649
2650     vtype = data->var_type->data.i;
2651     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2652     cmap = data->cat_map->data.i;
2653     cofs = data->cat_ofs->data.i;
2654
2655     while( node->Tn > pruned_tree_idx && node->left )
2656     {
2657         CvDTreeSplit* split = node->split;
2658         int dir = 0;
2659         for( ; !dir && split != 0; split = split->next )
2660         {
2661             int vi = split->var_idx;
2662             int ci = vtype[vi];
2663             i = vidx ? vidx[vi] : vi;
2664             float val = sample[i*step];
2665             if( m && m[i*mstep] )
2666                 continue;
2667             if( ci < 0 ) // ordered
2668                 dir = val <= split->ord.c ? -1 : 1;
2669             else // categorical
2670             {
2671                 int c;
2672                 if( preprocessed_input )
2673                     c = cvRound(val);
2674                 else
2675                 {
2676                     c = catbuf[ci];
2677                     if( c < 0 )
2678                     {
2679                         int a = c = cofs[ci];
2680                         int b = cofs[ci+1];
2681                         int ival = cvRound(val);
2682                         if( ival != val )
2683                             CV_ERROR( CV_StsBadArg,
2684                             "one of input categorical variable is not an integer" );
2685
2686                         while( a < b )
2687                         {
2688                             c = (a + b) >> 1;
2689                             if( ival < cmap[c] )
2690                                 b = c;
2691                             else if( ival > cmap[c] )
2692                                 a = c+1;
2693                             else
2694                                 break;
2695                         }
2696
2697                         if( c < 0 || ival != cmap[c] )
2698                             continue;
2699
2700                         catbuf[ci] = c -= cofs[ci];
2701                     }
2702                 }
2703                 dir = DTREE_CAT_DIR(c, split->subset);
2704             }
2705
2706             if( split->inversed )
2707                 dir = -dir;
2708         }
2709
2710         if( !dir )
2711         {
2712             double diff = node->right->sample_count - node->left->sample_count;
2713             dir = diff < 0 ? -1 : 1;
2714         }
2715         node = dir < 0 ? node->left : node->right;
2716     }
2717
2718     result = node;
2719
2720     __END__;
2721
2722     return result;
2723 }
2724
2725
2726 const CvMat* CvDTree::get_var_importance()
2727 {
2728     if( !var_importance )
2729     {
2730         CvDTreeNode* node = root;
2731         double* importance;
2732         if( !node )
2733             return 0;
2734         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
2735         cvZero( var_importance );
2736         importance = var_importance->data.db;
2737
2738         for(;;)
2739         {
2740             CvDTreeNode* parent;
2741             for( ;; node = node->left )
2742             {
2743                 CvDTreeSplit* split = node->split;
2744                 
2745                 if( !node->left || node->Tn <= pruned_tree_idx )
2746                     break;
2747                 
2748                 for( ; split != 0; split = split->next )
2749                     importance[split->var_idx] += split->quality;
2750             }
2751
2752             for( parent = node->parent; parent && parent->right == node;
2753                 node = parent, parent = parent->parent )
2754                 ;
2755
2756             if( !parent )
2757                 break;
2758
2759             node = parent->right;
2760         }
2761
2762         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
2763     }
2764
2765     return var_importance;
2766 }
2767
2768
2769 void CvDTree::write_train_data_params( CvFileStorage* fs )
2770 {
2771     CV_FUNCNAME( "CvDTree::write_train_data_params" );
2772
2773     __BEGIN__;
2774
2775     int vi, vcount = data->var_count;
2776
2777     cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
2778     cvWriteInt( fs, "var_all", data->var_all );
2779     cvWriteInt( fs, "var_count", data->var_count );
2780     cvWriteInt( fs, "ord_var_count", data->ord_var_count );
2781     cvWriteInt( fs, "cat_var_count", data->cat_var_count );
2782
2783     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
2784     cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
2785
2786     if( data->is_classifier )
2787     {
2788         cvWriteInt( fs, "max_categories", data->params.max_categories );
2789     }
2790     else
2791     {
2792         cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
2793     }
2794
2795     cvWriteInt( fs, "max_depth", data->params.max_depth );
2796     cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
2797     cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
2798     
2799     if( data->params.cv_folds > 1 )
2800     {
2801         cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
2802         cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
2803     }
2804
2805     if( data->priors )
2806         cvWrite( fs, "priors", data->priors );
2807
2808     cvEndWriteStruct( fs );
2809
2810     if( data->var_idx )
2811         cvWrite( fs, "var_idx", data->var_idx );
2812     
2813     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
2814
2815     for( vi = 0; vi < vcount; vi++ )
2816         cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
2817
2818     cvEndWriteStruct( fs );
2819
2820     if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
2821     {
2822         CV_ASSERT( data->cat_count != 0 );
2823         cvWrite( fs, "cat_count", data->cat_count );
2824         cvWrite( fs, "cat_map", data->cat_map );
2825     }
2826
2827     __END__;
2828 }
2829
2830
2831 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
2832 {
2833     int ci;
2834     
2835     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
2836     cvWriteInt( fs, "var", split->var_idx );
2837     cvWriteReal( fs, "quality", split->quality );
2838
2839     ci = data->get_var_type(split->var_idx);
2840     if( ci >= 0 ) // split on a categorical var
2841     {
2842         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
2843         for( i = 0; i < n; i++ )
2844             to_right += DTREE_CAT_DIR(i,split->subset) > 0;
2845
2846         // ad-hoc rule when to use inverse categorical split notation
2847         // to achieve more compact and clear representation
2848         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
2849         
2850         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
2851                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
2852         for( i = 0; i < n; i++ )
2853         {
2854             int dir = DTREE_CAT_DIR(i,split->subset);
2855             if( dir*default_dir < 0 )
2856                 cvWriteInt( fs, 0, i );
2857         }
2858         cvEndWriteStruct( fs );
2859     }
2860     else
2861         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
2862
2863     cvEndWriteStruct( fs );
2864 }
2865
2866
2867 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
2868 {
2869     CvDTreeSplit* split;
2870     
2871     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2872
2873     cvWriteInt( fs, "depth", node->depth );
2874     cvWriteInt( fs, "sample_count", node->sample_count );
2875     cvWriteReal( fs, "value", node->value );
2876     
2877     if( data->is_classifier )
2878         cvWriteInt( fs, "norm_class_idx", node->class_idx );
2879
2880     cvWriteInt( fs, "Tn", node->Tn );
2881     cvWriteInt( fs, "complexity", node->complexity );
2882     cvWriteReal( fs, "alpha", node->alpha );
2883     cvWriteReal( fs, "node_risk", node->node_risk );
2884     cvWriteReal( fs, "tree_risk", node->tree_risk );
2885     cvWriteReal( fs, "tree_error", node->tree_error );
2886
2887     if( node->left )
2888     {
2889         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
2890
2891         for( split = node->split; split != 0; split = split->next )
2892             write_split( fs, split );
2893
2894         cvEndWriteStruct( fs );
2895     }
2896
2897     cvEndWriteStruct( fs );
2898 }
2899
2900
2901 void CvDTree::write_tree_nodes( CvFileStorage* fs )
2902 {
2903     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
2904
2905     __BEGIN__;
2906
2907     CvDTreeNode* node = root;
2908
2909     // traverse the tree and save all the nodes in depth-first order
2910     for(;;)
2911     {
2912         CvDTreeNode* parent;
2913         for(;;)
2914         {
2915             write_node( fs, node );
2916             if( !node->left )
2917                 break;
2918             node = node->left;
2919         }
2920         
2921         for( parent = node->parent; parent && parent->right == node;
2922             node = parent, parent = parent->parent )
2923             ;
2924
2925         if( !parent )
2926             break;
2927
2928         node = parent->right;
2929     }
2930
2931     __END__;
2932 }
2933
2934
2935 void CvDTree::write( CvFileStorage* fs, const char* name )
2936 {
2937     //CV_FUNCNAME( "CvDTree::write" );
2938
2939     __BEGIN__;
2940
2941     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
2942
2943     write_train_data_params( fs );
2944
2945     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
2946     get_var_importance();
2947     cvWrite( fs, "var_importance", var_importance );
2948
2949     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
2950     write_tree_nodes( fs );
2951     cvEndWriteStruct( fs );
2952
2953     cvEndWriteStruct( fs );
2954
2955     __END__;
2956 }
2957
2958
2959 void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
2960 {
2961     CV_FUNCNAME( "CvDTree::read_train_data_params" );
2962
2963     __BEGIN__;
2964     
2965     CvDTreeParams params;
2966     CvFileNode *tparams_node, *vartype_node;
2967     CvSeqReader reader;
2968     int is_classifier, vi, cat_var_count, ord_var_count;
2969     int max_split_size, tree_block_size;
2970
2971     data = new CvDTreeTrainData;
2972
2973     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
2974     data->is_classifier = (is_classifier != 0);
2975     data->var_all = cvReadIntByName( fs, node, "var_all" );
2976     data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
2977     data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
2978     data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
2979
2980     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
2981
2982     if( tparams_node ) // training parameters are not necessary
2983     {
2984         data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
2985
2986         if( is_classifier )
2987         {
2988             data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
2989         }
2990         else
2991         {
2992             data->params.regression_accuracy =
2993                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
2994         }
2995
2996         data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
2997         data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
2998         data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
2999     
3000         if( data->params.cv_folds > 1 )
3001         {
3002             data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
3003             data->params.truncate_pruned_tree =
3004                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
3005         }
3006
3007         data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
3008         if( data->priors && !CV_IS_MAT(data->priors) )
3009             CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
3010     }
3011
3012     CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
3013     if( data->var_idx )
3014     {
3015         if( !CV_IS_MAT(data->var_idx) ||
3016             data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
3017             data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
3018             CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
3019             CV_ERROR( CV_StsParseError,
3020                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
3021
3022         for( vi = 0; vi < data->var_count; vi++ )
3023             if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
3024                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
3025     }
3026     
3027     ////// read var type
3028     CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
3029
3030     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
3031     if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
3032         vartype_node->data.seq->total != data->var_count )
3033         CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3034
3035     cvStartReadSeq( vartype_node->data.seq, &reader );
3036
3037     cat_var_count = 0;
3038     ord_var_count = -1;
3039     for( vi = 0; vi < data->var_count; vi++ )
3040     {
3041         CvFileNode* n = (CvFileNode*)reader.ptr;
3042         if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
3043             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3044         data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
3045         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3046     }
3047     
3048     ord_var_count = ~ord_var_count;
3049     if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
3050         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
3051     //////
3052
3053     if( data->cat_var_count > 0 || is_classifier )
3054     {
3055         int ccount, max_c_count = 0, total_c_count = 0;
3056         CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
3057         CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
3058
3059         if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
3060             data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
3061             CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
3062             data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
3063             data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
3064             CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
3065             CV_ERROR( CV_StsParseError,
3066             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
3067
3068         ccount = cat_var_count + is_classifier;
3069
3070         CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
3071         data->cat_ofs->data.i[0] = 0;
3072
3073         for( vi = 0; vi < ccount; vi++ )
3074         {
3075             int val = data->cat_count->data.i[vi];
3076             if( val <= 0 )
3077                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
3078             max_c_count = MAX( max_c_count, val );
3079             data->cat_ofs->data.i[vi+1] = total_c_count += val;
3080         }
3081
3082         if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
3083             CV_ERROR( CV_StsBadSize,
3084             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
3085         
3086         data->max_c_count = max_c_count;
3087     }
3088
3089     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
3090         (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
3091
3092     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
3093     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
3094     CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
3095     CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
3096             sizeof(CvDTreeNode), data->tree_storage ));
3097     CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
3098             max_split_size, data->tree_storage ));
3099
3100     __END__;
3101 }
3102
3103
3104 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3105 {
3106     CvDTreeSplit* split = 0;
3107     
3108     CV_FUNCNAME( "CvDTree::read_split" );
3109
3110     __BEGIN__;
3111
3112     int vi, ci;
3113     
3114     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3115         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3116
3117     vi = cvReadIntByName( fs, fnode, "var", -1 );
3118     if( (unsigned)vi >= (unsigned)data->var_count )
3119         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3120
3121     ci = data->get_var_type(vi);
3122     if( ci >= 0 ) // split on categorical var
3123     {
3124         int i, n = data->cat_count->data.i[ci], inversed = 0;
3125         CvSeqReader reader;
3126         CvFileNode* inseq;
3127         split = data->new_split_cat( vi, 0 );
3128         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3129         if( !inseq )
3130         {
3131             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3132             inversed = 1;
3133         }
3134         if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
3135             CV_ERROR( CV_StsParseError,
3136             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3137
3138         cvStartReadSeq( inseq->data.seq, &reader );
3139
3140         for( i = 0; i < reader.seq->total; i++ )
3141         {
3142             CvFileNode* inode = (CvFileNode*)reader.ptr;
3143             int val = inode->data.i;
3144             if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3145                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3146
3147             split->subset[val >> 5] |= 1 << (val & 31);
3148             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3149         }
3150
3151         // for categorical splits we do not use inversed splits,
3152         // instead we inverse the variable set in the split
3153         if( inversed )
3154             for( i = 0; i < (n + 31) >> 5; i++ )
3155                 split->subset[i] ^= -1;
3156     }
3157     else
3158     {
3159         CvFileNode* cmp_node;
3160         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3161
3162         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3163         if( !cmp_node )
3164         {
3165             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3166             split->inversed = 1;
3167         }
3168
3169         split->ord.c = (float)cvReadReal( cmp_node );
3170     }
3171         
3172     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3173
3174     __END__;
3175     
3176     return split;
3177 }
3178
3179
3180 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3181 {
3182     CvDTreeNode* node = 0;
3183     
3184     CV_FUNCNAME( "CvDTree::read_node" );
3185
3186     __BEGIN__;
3187
3188     CvFileNode* splits;
3189     int i, depth;
3190
3191     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3192         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3193
3194     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3195     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3196     if( depth != node->depth )
3197         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3198
3199     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3200     node->value = cvReadRealByName( fs, fnode, "value" );
3201     if( data->is_classifier )
3202         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3203
3204     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3205     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3206     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3207     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3208     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3209     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3210
3211     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3212     if( splits )
3213     {
3214         CvSeqReader reader;
3215         CvDTreeSplit* last_split = 0;
3216
3217         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3218             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3219
3220         cvStartReadSeq( splits->data.seq, &reader );
3221         for( i = 0; i < reader.seq->total; i++ )
3222         {
3223             CvDTreeSplit* split;
3224             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3225             if( !last_split )
3226                 node->split = last_split = split;
3227             else
3228                 last_split = last_split->next = split;
3229
3230             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3231         }
3232     }
3233
3234     __END__;
3235     
3236     return node;
3237 }
3238
3239
3240 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3241 {
3242     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3243
3244     __BEGIN__;
3245
3246     CvSeqReader reader;
3247     CvDTreeNode _root;
3248     CvDTreeNode* parent = &_root;
3249     int i;
3250     parent->left = parent->right = parent->parent = 0;
3251
3252     cvStartReadSeq( fnode->data.seq, &reader );
3253
3254     for( i = 0; i < reader.seq->total; i++ )
3255     {
3256         CvDTreeNode* node;
3257         
3258         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3259         if( !parent->left )
3260             parent->left = node;
3261         else
3262             parent->right = node;
3263         if( node->split )
3264             parent = node;
3265         else
3266         {
3267             while( parent && parent->right )
3268                 parent = parent->parent;
3269         }
3270
3271         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3272     }
3273
3274     root = _root.left;
3275
3276     __END__;
3277 }
3278
3279
3280 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3281 {
3282     CV_FUNCNAME( "CvDTree::read" );
3283
3284     __BEGIN__;
3285
3286     CvFileNode* tree_nodes;
3287
3288     clear();
3289     read_train_data_params( fs, fnode );
3290
3291     tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
3292     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3293         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3294
3295     pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
3296
3297     read_tree_nodes( fs, tree_nodes );
3298     get_var_importance(); // recompute variable importance
3299
3300     __END__;
3301 }
3302
3303 /* End of file. */