]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
fixed dtree
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42 #include <ctype.h>
43
44 using namespace cv;
45
46 static const float ord_nan = FLT_MAX*0.5f;
47 static const int min_block_size = 1 << 16;
48 static const int block_size_delta = 1 << 10;
49
50 CvDTreeTrainData::CvDTreeTrainData()
51 {
52     var_idx = var_type = cat_count = cat_ofs = cat_map =
53         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
54     tree_storage = temp_storage = 0;
55
56     clear();
57 }
58
59
60 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
61                       const CvMat* _responses, const CvMat* _var_idx,
62                       const CvMat* _sample_idx, const CvMat* _var_type,
63                       const CvMat* _missing_mask, const CvDTreeParams& _params,
64                       bool _shared, bool _add_labels )
65 {
66     var_idx = var_type = cat_count = cat_ofs = cat_map =
67         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
68
69     tree_storage = temp_storage = 0;
70
71     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
72               _var_type, _missing_mask, _params, _shared, _add_labels );
73 }
74
75
76 CvDTreeTrainData::~CvDTreeTrainData()
77 {
78     clear();
79 }
80
81
82 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
83 {
84     bool ok = false;
85
86     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
87
88     __BEGIN__;
89
90     // set parameters
91     params = _params;
92
93     if( params.max_categories < 2 )
94         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
95     params.max_categories = MIN( params.max_categories, 15 );
96
97     if( params.max_depth < 0 )
98         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
99     params.max_depth = MIN( params.max_depth, 25 );
100
101     params.min_sample_count = MAX(params.min_sample_count,1);
102
103     if( params.cv_folds < 0 )
104         CV_ERROR( CV_StsOutOfRange,
105         "params.cv_folds should be =0 (the tree is not pruned) "
106         "or n>0 (tree is pruned using n-fold cross-validation)" );
107
108     if( params.cv_folds == 1 )
109         params.cv_folds = 0;
110
111     if( params.regression_accuracy < 0 )
112         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
113
114     ok = true;
115
116     __END__;
117
118     return ok;
119 }
120
121 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
122 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
123 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
124
125 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
126 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
127 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )
128
129 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
130 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
131
132 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
133     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
134     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
135     bool _shared, bool _add_labels, bool _update_data )
136 {
137     CvMat* sample_indices = 0;
138     CvMat* var_type0 = 0;
139     CvMat* tmp_map = 0;
140     int** int_ptr = 0;
141     CvPair16u32s* pair16u32s_ptr = 0;
142     CvDTreeTrainData* data = 0;
143     float *_fdst = 0;
144     int *_idst = 0;
145     unsigned short* udst = 0;
146     int* idst = 0;
147
148     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
149
150     __BEGIN__;
151
152     int sample_all = 0, r_type = 0, cv_n;
153     int total_c_count = 0;
154     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
155     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
156     int vi, i, size;
157     char err[100];
158     const int *sidx = 0, *vidx = 0;
159     
160     if( _update_data && data_root )
161     {
162         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
163             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
164
165         // compare new and old train data
166         if( !(data->var_count == var_count &&
167             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
168             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
169             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
170             CV_ERROR( CV_StsBadArg,
171             "The new training data must have the same types and the input and output variables "
172             "and the same categories for categorical variables" );
173
174         cvReleaseMat( &priors );
175         cvReleaseMat( &priors_mult );
176         cvReleaseMat( &buf );
177         cvReleaseMat( &direction );
178         cvReleaseMat( &split_buf );
179         cvReleaseMemStorage( &temp_storage );
180
181         priors = data->priors; data->priors = 0;
182         priors_mult = data->priors_mult; data->priors_mult = 0;
183         buf = data->buf; data->buf = 0;
184         buf_count = data->buf_count; buf_size = data->buf_size;
185         sample_count = data->sample_count;
186
187         direction = data->direction; data->direction = 0;
188         split_buf = data->split_buf; data->split_buf = 0;
189         temp_storage = data->temp_storage; data->temp_storage = 0;
190         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
191
192         data_root = new_node( 0, sample_count, 0, 0 );
193         EXIT;
194     }
195
196     clear();
197
198     var_all = 0;
199     rng = cvRNG(-1);
200
201     CV_CALL( set_params( _params ));
202
203     // check parameter types and sizes
204     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
205
206     train_data = _train_data;
207     responses = _responses;
208
209     if( _tflag == CV_ROW_SAMPLE )
210     {
211         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
212         dv_step = 1;
213         if( _missing_mask )
214             ms_step = _missing_mask->step, mv_step = 1;
215     }
216     else
217     {
218         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
219         ds_step = 1;
220         if( _missing_mask )
221             mv_step = _missing_mask->step, ms_step = 1;
222     }
223     tflag = _tflag;
224
225     sample_count = sample_all;
226     var_count = var_all;
227     
228     if( _sample_idx )
229     {
230         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
231         sidx = sample_indices->data.i;
232         sample_count = sample_indices->rows + sample_indices->cols - 1;
233     }
234
235     if( _var_idx )
236     {
237         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
238         vidx = var_idx->data.i;
239         var_count = var_idx->rows + var_idx->cols - 1;
240     }
241
242     is_buf_16u = false;     
243     if ( sample_count < 65536 ) 
244         is_buf_16u = true;                                
245     
246     if( !CV_IS_MAT(_responses) ||
247         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
248          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
249         (_responses->rows != 1 && _responses->cols != 1) ||
250         _responses->rows + _responses->cols - 1 != sample_all )
251         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
252                   "floating-point vector containing as many elements as "
253                   "the total number of samples in the training data matrix" );
254    
255   
256     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
257
258     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
259    
260     
261     cat_var_count = 0;
262     ord_var_count = -1;
263
264     is_classifier = r_type == CV_VAR_CATEGORICAL;
265
266     // step 0. calc the number of categorical vars
267     for( vi = 0; vi < var_count; vi++ )
268     {
269         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
270             cat_var_count++ : ord_var_count--;
271     }
272
273     ord_var_count = ~ord_var_count;
274     cv_n = params.cv_folds;
275     // set the two last elements of var_type array to be able
276     // to locate responses and cross-validation labels using
277     // the corresponding get_* functions.
278     var_type->data.i[var_count] = cat_var_count;
279     var_type->data.i[var_count+1] = cat_var_count+1;
280
281     // in case of single ordered predictor we need dummy cv_labels
282     // for safe split_node_data() operation
283     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
284
285     work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
286                                + (have_labels ? 1 : 0); // for cv_labels
287                                
288     buf_size = (work_var_count + 1 /*for sample_indices*/) * sample_count;
289     shared = _shared;
290     buf_count = shared ? 2 : 1;
291     
292     if ( is_buf_16u )
293     {
294         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
295         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
296     }
297     else
298     {
299         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
300         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
301     }    
302
303     size = is_classifier ? (cat_var_count+1) : cat_var_count;
304     size = !size ? 1 : size;
305     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
306     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
307         
308     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
309     size = !size ? 1 : size;
310     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
311
312     // now calculate the maximum size of split,
313     // create memory storage that will keep nodes and splits of the decision tree
314     // allocate root node and the buffer for the whole training data
315     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
316         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
317     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
318     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
319     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
320     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
321
322     nv_size = var_count*sizeof(int);
323     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
324
325     temp_block_size = nv_size;
326
327     if( cv_n )
328     {
329         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
330             CV_ERROR( CV_StsOutOfRange,
331                 "The many folds in cross-validation for such a small dataset" );
332
333         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
334         temp_block_size = MAX(temp_block_size, cv_size);
335     }
336
337     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
338     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
339     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
340     if( cv_size )
341         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
342
343     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
344
345     max_c_count = 1;
346
347     _fdst = 0;
348     _idst = 0;
349     if (ord_var_count)
350         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
351     if (is_buf_16u && (cat_var_count || is_classifier))
352         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
353
354     // transform the training data to convenient representation
355     for( vi = 0; vi <= var_count; vi++ )
356     {
357         int ci;
358         const uchar* mask = 0;
359         int m_step = 0, step;
360         const int* idata = 0;
361         const float* fdata = 0;
362         int num_valid = 0;
363
364         if( vi < var_count ) // analyze i-th input variable
365         {
366             int vi0 = vidx ? vidx[vi] : vi;
367             ci = get_var_type(vi);
368             step = ds_step; m_step = ms_step;
369             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
370                 idata = _train_data->data.i + vi0*dv_step;
371             else
372                 fdata = _train_data->data.fl + vi0*dv_step;
373             if( _missing_mask )
374                 mask = _missing_mask->data.ptr + vi0*mv_step;
375         }
376         else // analyze _responses
377         {
378             ci = cat_var_count;
379             step = CV_IS_MAT_CONT(_responses->type) ?
380                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
381             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
382                 idata = _responses->data.i;
383             else
384                 fdata = _responses->data.fl;
385         }
386
387         if( (vi < var_count && ci>=0) ||
388             (vi == var_count && is_classifier) ) // process categorical variable or response
389         {
390             int c_count, prev_label;
391             int* c_map;
392             
393             if (is_buf_16u)
394                 udst = (unsigned short*)(buf->data.s + vi*sample_count);
395             else
396                 idst = buf->data.i + vi*sample_count;
397             
398             // copy data
399             for( i = 0; i < sample_count; i++ )
400             {
401                 int val = INT_MAX, si = sidx ? sidx[i] : i;
402                 if( !mask || !mask[si*m_step] )
403                 {
404                     if( idata )
405                         val = idata[si*step];
406                     else
407                     {
408                         float t = fdata[si*step];
409                         val = cvRound(t);
410                         if( fabs(t - val) > FLT_EPSILON )
411                         {
412                             sprintf( err, "%d-th value of %d-th (categorical) "
413                                 "variable is not an integer", i, vi );
414                             CV_ERROR( CV_StsBadArg, err );
415                         }
416                     }
417
418                     if( val == INT_MAX )
419                     {
420                         sprintf( err, "%d-th value of %d-th (categorical) "
421                             "variable is too large", i, vi );
422                         CV_ERROR( CV_StsBadArg, err );
423                     }
424                     num_valid++;
425                 }
426                 if (is_buf_16u)
427                 {
428                     _idst[i] = val;
429                     pair16u32s_ptr[i].u = udst + i;
430                     pair16u32s_ptr[i].i = _idst + i;
431                 }   
432                 else
433                 {
434                     idst[i] = val;
435                     int_ptr[i] = idst + i;
436                 }
437             }
438
439             c_count = num_valid > 0;
440             if (is_buf_16u)
441             {
442                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
443                 // count the categories
444                 for( i = 1; i < num_valid; i++ )
445                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
446                         c_count ++ ;
447             }
448             else
449             {
450                 icvSortIntPtr( int_ptr, sample_count, 0 );
451                 // count the categories
452                 for( i = 1; i < num_valid; i++ )
453                     c_count += *int_ptr[i] != *int_ptr[i-1];
454             }
455
456             if( vi > 0 )
457                 max_c_count = MAX( max_c_count, c_count );
458             cat_count->data.i[ci] = c_count;
459             cat_ofs->data.i[ci] = total_c_count;
460
461             // resize cat_map, if need
462             if( cat_map->cols < total_c_count + c_count )
463             {
464                 tmp_map = cat_map;
465                 CV_CALL( cat_map = cvCreateMat( 1,
466                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
467                 for( i = 0; i < total_c_count; i++ )
468                     cat_map->data.i[i] = tmp_map->data.i[i];
469                 cvReleaseMat( &tmp_map );
470             }
471
472             c_map = cat_map->data.i + total_c_count;
473             total_c_count += c_count;
474
475             c_count = -1;
476             if (is_buf_16u)
477             {
478                 // compact the class indices and build the map
479                 prev_label = ~*pair16u32s_ptr[0].i;
480                 for( i = 0; i < num_valid; i++ )
481                 {
482                     int cur_label = *pair16u32s_ptr[i].i;
483                     if( cur_label != prev_label )
484                         c_map[++c_count] = prev_label = cur_label;
485                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
486                 }
487                 // replace labels for missing values with -1
488                 for( ; i < sample_count; i++ )
489                     *pair16u32s_ptr[i].u = 65535;
490             }
491             else
492             {
493                 // compact the class indices and build the map
494                 prev_label = ~*int_ptr[0];
495                 for( i = 0; i < num_valid; i++ )
496                 {
497                     int cur_label = *int_ptr[i];
498                     if( cur_label != prev_label )
499                         c_map[++c_count] = prev_label = cur_label;
500                     *int_ptr[i] = c_count;
501                 }
502                 // replace labels for missing values with -1
503                 for( ; i < sample_count; i++ )
504                     *int_ptr[i] = -1;
505             }           
506         }
507         else if( ci < 0 ) // process ordered variable
508         {
509             if (is_buf_16u)
510                 udst = (unsigned short*)(buf->data.s + vi*sample_count);
511             else
512                 idst = buf->data.i + vi*sample_count;
513
514             for( i = 0; i < sample_count; i++ )
515             {
516                 float val = ord_nan;
517                 int si = sidx ? sidx[i] : i;
518                 if( !mask || !mask[si*m_step] )
519                 {
520                     if( idata )
521                         val = (float)idata[si*step];
522                     else
523                         val = fdata[si*step];
524
525                     if( fabs(val) >= ord_nan )
526                     {
527                         sprintf( err, "%d-th value of %d-th (ordered) "
528                             "variable (=%g) is too large", i, vi, val );
529                         CV_ERROR( CV_StsBadArg, err );
530                     }
531                 }
532                 num_valid++;
533                 if (is_buf_16u)
534                     udst[i] = (unsigned short)i;
535                 else
536                     idst[i] = i;
537                 _fdst[i] = val;
538                 
539             }
540             if (is_buf_16u)
541                 icvSortUShAux( udst, num_valid, _fdst);
542             else
543                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );
544         }
545        
546         if( vi < var_count )
547             data_root->set_num_valid(vi, num_valid);
548     }
549
550     // set sample labels
551     if (is_buf_16u)
552         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
553     else
554         idst = buf->data.i + work_var_count*sample_count;
555
556     for (i = 0; i < sample_count; i++)
557     {
558         if (udst)
559             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
560         else
561             idst[i] = sidx ? sidx[i] : i;
562     }
563
564     if( cv_n )
565     {
566         unsigned short* udst = 0;
567         int* idst = 0;
568         CvRNG* r = &rng;
569
570         if (is_buf_16u)
571         {
572             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
573             for( i = vi = 0; i < sample_count; i++ )
574             {
575                 udst[i] = (unsigned short)vi++;
576                 vi &= vi < cv_n ? -1 : 0;
577             }
578
579             for( i = 0; i < sample_count; i++ )
580             {
581                 int a = cvRandInt(r) % sample_count;
582                 int b = cvRandInt(r) % sample_count;
583                 unsigned short unsh = (unsigned short)vi;
584                 CV_SWAP( udst[a], udst[b], unsh );
585             }
586         }
587         else
588         {
589             idst = buf->data.i + (get_work_var_count()-1)*sample_count;
590             for( i = vi = 0; i < sample_count; i++ )
591             {
592                 idst[i] = vi++;
593                 vi &= vi < cv_n ? -1 : 0;
594             }
595
596             for( i = 0; i < sample_count; i++ )
597             {
598                 int a = cvRandInt(r) % sample_count;
599                 int b = cvRandInt(r) % sample_count;
600                 CV_SWAP( idst[a], idst[b], vi );
601             }
602         }
603     }
604
605     if ( cat_map ) 
606         cat_map->cols = MAX( total_c_count, 1 );
607
608     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
609         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
610     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
611
612     have_priors = is_classifier && params.priors;
613     if( is_classifier )
614     {
615         int m = get_num_classes();
616         double sum = 0;
617         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
618         for( i = 0; i < m; i++ )
619         {
620             double val = have_priors ? params.priors[i] : 1.;
621             if( val <= 0 )
622                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
623             priors->data.db[i] = val;
624             sum += val;
625         }
626
627         // normalize weights
628         if( have_priors )
629             cvScale( priors, priors, 1./sum );
630
631         CV_CALL( priors_mult = cvCloneMat( priors ));
632         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
633     }
634
635
636     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
637     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
638
639     __END__;
640
641     if( data )
642         delete data;
643
644     if (_fdst)
645         cvFree( &_fdst );
646     if (_idst)
647         cvFree( &_idst );
648     cvFree( &int_ptr );
649     cvFree( &pair16u32s_ptr);
650     cvReleaseMat( &var_type0 );
651     cvReleaseMat( &sample_indices );
652     cvReleaseMat( &tmp_map );
653 }
654
655 void CvDTreeTrainData::do_responses_copy()
656 {
657     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
658     cvCopy( responses, responses_copy);
659     responses = responses_copy;
660 }
661
662 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
663 {
664     CvDTreeNode* root = 0;
665     CvMat* isubsample_idx = 0;
666     CvMat* subsample_co = 0;
667
668     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
669
670     __BEGIN__;
671
672     if( !data_root )
673         CV_ERROR( CV_StsError, "No training data has been set" );
674
675     if( _subsample_idx )
676         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
677
678     if( !isubsample_idx )
679     {
680         // make a copy of the root node
681         CvDTreeNode temp;
682         int i;
683         root = new_node( 0, 1, 0, 0 );
684         temp = *root;
685         *root = *data_root;
686         root->num_valid = temp.num_valid;
687         if( root->num_valid )
688         {
689             for( i = 0; i < var_count; i++ )
690                 root->num_valid[i] = data_root->num_valid[i];
691         }
692         root->cv_Tn = temp.cv_Tn;
693         root->cv_node_risk = temp.cv_node_risk;
694         root->cv_node_error = temp.cv_node_error;
695     }
696     else
697     {
698         int* sidx = isubsample_idx->data.i;
699         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
700         int* co, cur_ofs = 0;
701         int vi, i;
702         int work_var_count = get_work_var_count();
703         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
704
705         root = new_node( 0, count, 1, 0 );
706
707         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
708         cvZero( subsample_co );
709         co = subsample_co->data.i;
710         for( i = 0; i < count; i++ )
711             co[sidx[i]*2]++;
712         for( i = 0; i < sample_count; i++ )
713         {
714             if( co[i*2] )
715             {
716                 co[i*2+1] = cur_ofs;
717                 cur_ofs += co[i*2];
718             }
719             else
720                 co[i*2+1] = -1;
721         }
722
723         cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
724         for( vi = 0; vi < work_var_count; vi++ )
725         {
726             int ci = get_var_type(vi);
727
728             if( ci >= 0 || vi >= var_count )
729             {
730                 int num_valid = 0;
731                 const int* src = get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
732
733                 if (is_buf_16u)
734                 {
735                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
736                         vi*sample_count + root->offset);
737                     for( i = 0; i < count; i++ )
738                     {
739                         int val = src[sidx[i]];
740                         udst[i] = (unsigned short)val;
741                         num_valid += val >= 0;
742                     }
743                 }
744                 else
745                 {
746                     int* idst = buf->data.i + root->buf_idx*buf->cols + 
747                         vi*sample_count + root->offset;
748                     for( i = 0; i < count; i++ )
749                     {
750                         int val = src[sidx[i]];
751                         idst[i] = val;
752                         num_valid += val >= 0;
753                     }
754                 }
755
756                 if( vi < var_count )
757                     root->set_num_valid(vi, num_valid);
758             }
759             else
760             {
761                 int *src_idx_buf = (int*)(uchar*)inn_buf;
762                 float *src_val_buf = (float*)(src_idx_buf + sample_count);
763                 int* sample_indices_buf = (int*)(src_val_buf + sample_count);
764                 const int* src_idx = 0;
765                 const float* src_val = 0;
766                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
767                 int j = 0, idx, count_i;
768                 int num_valid = data_root->get_num_valid(vi);
769
770                 if (is_buf_16u)
771                 {
772                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
773                         vi*sample_count + data_root->offset);
774                     for( i = 0; i < num_valid; i++ )
775                     {
776                         idx = src_idx[i];
777                         count_i = co[idx*2];
778                         if( count_i )
779                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
780                                 udst_idx[j] = (unsigned short)cur_ofs;
781                     }
782
783                     root->set_num_valid(vi, j);
784
785                     for( ; i < sample_count; i++ )
786                     {
787                         idx = src_idx[i];
788                         count_i = co[idx*2];
789                         if( count_i )
790                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
791                                 udst_idx[j] = (unsigned short)cur_ofs;
792                     }
793                 }
794                 else
795                 {
796                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + 
797                         vi*sample_count + root->offset;
798                     for( i = 0; i < num_valid; i++ )
799                     {
800                         idx = src_idx[i];
801                         count_i = co[idx*2];
802                         if( count_i )
803                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
804                                 idst_idx[j] = cur_ofs;
805                     }
806
807                     root->set_num_valid(vi, j);
808
809                     for( ; i < sample_count; i++ )
810                     {
811                         idx = src_idx[i];
812                         count_i = co[idx*2];
813                         if( count_i )
814                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
815                                 idst_idx[j] = cur_ofs;
816                     }
817                 }
818             }
819         }
820         // sample indices subsampling
821         const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
822         if (is_buf_16u)
823         {
824             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
825                 get_work_var_count()*sample_count + root->offset);            
826             for (i = 0; i < count; i++)
827                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
828         }
829         else
830         {
831             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + 
832                 get_work_var_count()*sample_count + root->offset;            
833             for (i = 0; i < count; i++)
834                 sample_idx_dst[i] = sample_idx_src[sidx[i]];
835         }
836     }
837
838     __END__;
839
840     cvReleaseMat( &isubsample_idx );
841     cvReleaseMat( &subsample_co );
842
843     return root;
844 }
845
846
847 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
848                                     float* values, uchar* missing,
849                                     float* responses, bool get_class_idx )
850 {
851     CvMat* subsample_idx = 0;
852     CvMat* subsample_co = 0;
853
854     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
855
856     __BEGIN__;
857
858     int i, vi, total = sample_count, count = total, cur_ofs = 0;
859     int* sidx = 0;
860     int* co = 0;
861
862     cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
863     if( _subsample_idx )
864     {
865         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
866         sidx = subsample_idx->data.i;
867         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
868         co = subsample_co->data.i;
869         cvZero( subsample_co );
870         count = subsample_idx->cols + subsample_idx->rows - 1;
871         for( i = 0; i < count; i++ )
872             co[sidx[i]*2]++;
873         for( i = 0; i < total; i++ )
874         {
875             int count_i = co[i*2];
876             if( count_i )
877             {
878                 co[i*2+1] = cur_ofs*var_count;
879                 cur_ofs += count_i;
880             }
881         }
882     }
883
884     if( missing )
885         memset( missing, 1, count*var_count );
886
887     for( vi = 0; vi < var_count; vi++ )
888     {
889         int ci = get_var_type(vi);
890         if( ci >= 0 ) // categorical
891         {
892             float* dst = values + vi;
893             uchar* m = missing ? missing + vi : 0;
894             const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);
895
896             for( i = 0; i < count; i++, dst += var_count )
897             {
898                 int idx = sidx ? sidx[i] : i;
899                 int val = src[idx];
900                 *dst = (float)val;
901                 if( m )
902                 {
903                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
904                     m += var_count;
905                 }
906             }
907         }
908         else // ordered
909         {
910             float* dst = values + vi;
911             uchar* m = missing ? missing + vi : 0;
912             int count1 = data_root->get_num_valid(vi);
913             float *src_val_buf = (float*)(uchar*)inn_buf;
914             int* src_idx_buf = (int*)(src_val_buf + sample_count);
915             int* sample_indices_buf = src_idx_buf + sample_count;
916             const float *src_val = 0;
917             const int* src_idx = 0;
918             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
919
920             for( i = 0; i < count1; i++ )
921             {
922                 int idx = src_idx[i];
923                 int count_i = 1;
924                 if( co )
925                 {
926                     count_i = co[idx*2];
927                     cur_ofs = co[idx*2+1];
928                 }
929                 else
930                     cur_ofs = idx*var_count;
931                 if( count_i )
932                 {
933                     float val = src_val[i];
934                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
935                     {
936                         dst[cur_ofs] = val;
937                         if( m )
938                             m[cur_ofs] = 0;
939                     }
940                 }
941             }
942         }
943     }
944
945     // copy responses
946     if( responses )
947     {
948         if( is_classifier )
949         {
950             const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
951             for( i = 0; i < count; i++ )
952             {
953                 int idx = sidx ? sidx[i] : i;
954                 int val = get_class_idx ? src[idx] :
955                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
956                 responses[i] = (float)val;
957             }
958         }
959         else
960         {
961             float* val_buf = (float*)(uchar*)inn_buf;
962             int* sample_idx_buf = (int*)(val_buf + sample_count);
963             const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
964             for( i = 0; i < count; i++ )
965             {
966                 int idx = sidx ? sidx[i] : i;
967                 responses[i] = _values[idx];
968             }
969         }
970     }
971
972     __END__;
973
974     cvReleaseMat( &subsample_idx );
975     cvReleaseMat( &subsample_co );
976 }
977
978
979 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
980                                          int storage_idx, int offset )
981 {
982     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
983
984     node->sample_count = count;
985     node->depth = parent ? parent->depth + 1 : 0;
986     node->parent = parent;
987     node->left = node->right = 0;
988     node->split = 0;
989     node->value = 0;
990     node->class_idx = 0;
991     node->maxlr = 0.;
992
993     node->buf_idx = storage_idx;
994     node->offset = offset;
995     if( nv_heap )
996         node->num_valid = (int*)cvSetNew( nv_heap );
997     else
998         node->num_valid = 0;
999     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
1000     node->complexity = 0;
1001
1002     if( params.cv_folds > 0 && cv_heap )
1003     {
1004         int cv_n = params.cv_folds;
1005         node->Tn = INT_MAX;
1006         node->cv_Tn = (int*)cvSetNew( cv_heap );
1007         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
1008         node->cv_node_error = node->cv_node_risk + cv_n;
1009     }
1010     else
1011     {
1012         node->Tn = 0;
1013         node->cv_Tn = 0;
1014         node->cv_node_risk = 0;
1015         node->cv_node_error = 0;
1016     }
1017
1018     return node;
1019 }
1020
1021
1022 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
1023                 int split_point, int inversed, float quality )
1024 {
1025     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1026     split->var_idx = vi;
1027     split->condensed_idx = INT_MIN;
1028     split->ord.c = cmp_val;
1029     split->ord.split_point = split_point;
1030     split->inversed = inversed;
1031     split->quality = quality;
1032     split->next = 0;
1033
1034     return split;
1035 }
1036
1037
1038 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
1039 {
1040     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1041     int i, n = (max_c_count + 31)/32;
1042
1043     split->var_idx = vi;
1044     split->condensed_idx = INT_MIN;
1045     split->inversed = 0;
1046     split->quality = quality;
1047     for( i = 0; i < n; i++ )
1048         split->subset[i] = 0;
1049     split->next = 0;
1050
1051     return split;
1052 }
1053
1054
1055 void CvDTreeTrainData::free_node( CvDTreeNode* node )
1056 {
1057     CvDTreeSplit* split = node->split;
1058     free_node_data( node );
1059     while( split )
1060     {
1061         CvDTreeSplit* next = split->next;
1062         cvSetRemoveByPtr( split_heap, split );
1063         split = next;
1064     }
1065     node->split = 0;
1066     cvSetRemoveByPtr( node_heap, node );
1067 }
1068
1069
1070 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
1071 {
1072     if( node->num_valid )
1073     {
1074         cvSetRemoveByPtr( nv_heap, node->num_valid );
1075         node->num_valid = 0;
1076     }
1077     // do not free cv_* fields, as all the cross-validation related data is released at once.
1078 }
1079
1080
1081 void CvDTreeTrainData::free_train_data()
1082 {
1083     cvReleaseMat( &counts );
1084     cvReleaseMat( &buf );
1085     cvReleaseMat( &direction );
1086     cvReleaseMat( &split_buf );
1087     cvReleaseMemStorage( &temp_storage );
1088     cvReleaseMat( &responses_copy );
1089     cv_heap = nv_heap = 0;
1090 }
1091
1092
1093 void CvDTreeTrainData::clear()
1094 {
1095     free_train_data();
1096
1097     cvReleaseMemStorage( &tree_storage );
1098
1099     cvReleaseMat( &var_idx );
1100     cvReleaseMat( &var_type );
1101     cvReleaseMat( &cat_count );
1102     cvReleaseMat( &cat_ofs );
1103     cvReleaseMat( &cat_map );
1104     cvReleaseMat( &priors );
1105     cvReleaseMat( &priors_mult );
1106     
1107     node_heap = split_heap = 0;
1108
1109     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
1110     have_labels = have_priors = is_classifier = false;
1111
1112     buf_count = buf_size = 0;
1113     shared = false;
1114     
1115     data_root = 0;
1116
1117     rng = cvRNG(-1);
1118 }
1119
1120
1121 int CvDTreeTrainData::get_num_classes() const
1122 {
1123     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
1124 }
1125
1126
1127 int CvDTreeTrainData::get_var_type(int vi) const
1128 {
1129     return var_type->data.i[vi];
1130 }
1131
1132 void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
1133                                          const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
1134 {
1135     int vidx = var_idx ? var_idx->data.i[vi] : vi;
1136     int node_sample_count = n->sample_count; 
1137     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
1138
1139     const int* sample_indices = get_sample_indices(n, sample_indices_buf);
1140
1141     if( !is_buf_16u )
1142         *sorted_indices = buf->data.i + n->buf_idx*buf->cols +
1143         vi*sample_count + n->offset;
1144     else {
1145         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
1146             vi*sample_count + n->offset );
1147         for( int i = 0; i < node_sample_count; i++ )
1148             sorted_indices_buf[i] = short_indices[i];
1149         *sorted_indices = sorted_indices_buf;
1150     }
1151     
1152     if( tflag == CV_ROW_SAMPLE )
1153     {
1154         for( int i = 0; i < node_sample_count && 
1155             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1156         {
1157             int idx = (*sorted_indices)[i];
1158             idx = sample_indices[idx];
1159             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
1160         }
1161     }
1162     else
1163         for( int i = 0; i < node_sample_count && 
1164             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1165         {
1166             int idx = (*sorted_indices)[i];
1167             idx = sample_indices[idx];
1168             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
1169         }
1170     
1171     *ord_values = ord_values_buf;
1172 }
1173
1174
1175 const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
1176 {
1177     if (is_classifier)
1178         return get_cat_var_data( n, var_count, labels_buf);
1179     return 0;
1180 }
1181
1182 const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
1183 {
1184     return get_cat_var_data( n, get_work_var_count(), indices_buf );
1185 }
1186
1187 const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf )
1188 {
1189     int sample_count = n->sample_count;
1190     int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
1191     const int* indices = get_sample_indices(n, sample_indices_buf);
1192
1193     for( int i = 0; i < sample_count && 
1194         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
1195     {
1196         int idx = indices[i];
1197         values_buf[i] = *(responses->data.fl + idx * r_step);
1198     }
1199     
1200     return values_buf;
1201 }
1202
1203
1204 const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
1205 {
1206     if (have_labels)
1207         return get_cat_var_data( n, get_work_var_count()- 1, labels_buf);
1208     return 0;
1209 }
1210
1211
1212 const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
1213 {
1214     const int* cat_values = 0;
1215     if( !is_buf_16u )
1216         cat_values = buf->data.i + n->buf_idx*buf->cols +
1217             vi*sample_count + n->offset;
1218     else {
1219         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
1220             vi*sample_count + n->offset);
1221         for( int i = 0; i < n->sample_count; i++ )
1222             cat_values_buf[i] = short_values[i];
1223         cat_values = cat_values_buf;
1224     }
1225     return cat_values;
1226 }
1227
1228
1229 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
1230 {
1231     int idx = n->buf_idx + 1;
1232     if( idx >= buf_count )
1233         idx = shared ? 1 : 0;
1234     return idx;
1235 }
1236
1237
1238 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
1239 {
1240     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1241
1242     __BEGIN__;
1243
1244     int vi, vcount = var_count;
1245
1246     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1247     cvWriteInt( fs, "var_all", var_all );
1248     cvWriteInt( fs, "var_count", var_count );
1249     cvWriteInt( fs, "ord_var_count", ord_var_count );
1250     cvWriteInt( fs, "cat_var_count", cat_var_count );
1251
1252     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1253     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1254
1255     if( is_classifier )
1256     {
1257         cvWriteInt( fs, "max_categories", params.max_categories );
1258     }
1259     else
1260     {
1261         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1262     }
1263
1264     cvWriteInt( fs, "max_depth", params.max_depth );
1265     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1266     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1267
1268     if( params.cv_folds > 1 )
1269     {
1270         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1271         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1272     }
1273
1274     if( priors )
1275         cvWrite( fs, "priors", priors );
1276
1277     cvEndWriteStruct( fs );
1278
1279     if( var_idx )
1280         cvWrite( fs, "var_idx", var_idx );
1281
1282     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
1283
1284     for( vi = 0; vi < vcount; vi++ )
1285         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1286
1287     cvEndWriteStruct( fs );
1288
1289     if( cat_count && (cat_var_count > 0 || is_classifier) )
1290     {
1291         CV_ASSERT( cat_count != 0 );
1292         cvWrite( fs, "cat_count", cat_count );
1293         cvWrite( fs, "cat_map", cat_map );
1294     }
1295
1296     __END__;
1297 }
1298
1299
1300 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1301 {
1302     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1303
1304     __BEGIN__;
1305
1306     CvFileNode *tparams_node, *vartype_node;
1307     CvSeqReader reader;
1308     int vi, max_split_size, tree_block_size;
1309
1310     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1311     var_all = cvReadIntByName( fs, node, "var_all" );
1312     var_count = cvReadIntByName( fs, node, "var_count", var_all );
1313     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1314     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1315
1316     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1317
1318     if( tparams_node ) // training parameters are not necessary
1319     {
1320         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1321
1322         if( is_classifier )
1323         {
1324             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1325         }
1326         else
1327         {
1328             params.regression_accuracy =
1329                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1330         }
1331
1332         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1333         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1334         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1335
1336         if( params.cv_folds > 1 )
1337         {
1338             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1339             params.truncate_pruned_tree =
1340                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1341         }
1342
1343         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1344         if( priors )
1345         {
1346             if( !CV_IS_MAT(priors) )
1347                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1348             priors_mult = cvCloneMat( priors );
1349         }
1350     }
1351
1352     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1353     if( var_idx )
1354     {
1355         if( !CV_IS_MAT(var_idx) ||
1356             (var_idx->cols != 1 && var_idx->rows != 1) ||
1357             var_idx->cols + var_idx->rows - 1 != var_count ||
1358             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1359             CV_ERROR( CV_StsParseError,
1360                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1361
1362         for( vi = 0; vi < var_count; vi++ )
1363             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1364                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1365     }
1366
1367     ////// read var type
1368     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1369
1370     cat_var_count = 0;
1371     ord_var_count = -1;
1372     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1373
1374     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1375         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1376     else
1377     {
1378         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1379             vartype_node->data.seq->total != var_count )
1380             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1381
1382         cvStartReadSeq( vartype_node->data.seq, &reader );
1383
1384         for( vi = 0; vi < var_count; vi++ )
1385         {
1386             CvFileNode* n = (CvFileNode*)reader.ptr;
1387             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1388                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1389             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1390             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1391         }
1392     }
1393     var_type->data.i[var_count] = cat_var_count;
1394
1395     ord_var_count = ~ord_var_count;
1396     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
1397         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
1398     //////
1399
1400     if( cat_var_count > 0 || is_classifier )
1401     {
1402         int ccount, total_c_count = 0;
1403         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1404         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1405
1406         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1407             (cat_count->cols != 1 && cat_count->rows != 1) ||
1408             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1409             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1410             (cat_map->cols != 1 && cat_map->rows != 1) ||
1411             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1412             CV_ERROR( CV_StsParseError,
1413             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1414
1415         ccount = cat_var_count + is_classifier;
1416
1417         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1418         cat_ofs->data.i[0] = 0;
1419         max_c_count = 1;
1420
1421         for( vi = 0; vi < ccount; vi++ )
1422         {
1423             int val = cat_count->data.i[vi];
1424             if( val <= 0 )
1425                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1426             max_c_count = MAX( max_c_count, val );
1427             cat_ofs->data.i[vi+1] = total_c_count += val;
1428         }
1429
1430         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1431             CV_ERROR( CV_StsBadSize,
1432             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1433     }
1434
1435     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1436         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1437
1438     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1439     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1440     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1441     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1442             sizeof(CvDTreeNode), tree_storage ));
1443     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1444             max_split_size, tree_storage ));
1445
1446     __END__;
1447 }
1448
1449 /////////////////////// Decision Tree /////////////////////////
1450
1451 CvDTree::CvDTree()
1452 {
1453     data = 0;
1454     var_importance = 0;
1455     default_model_name = "my_tree";
1456
1457     clear();
1458 }
1459
1460
1461 void CvDTree::clear()
1462 {
1463     cvReleaseMat( &var_importance );
1464     if( data )
1465     {
1466         if( !data->shared )
1467             delete data;
1468         else
1469             free_tree();
1470         data = 0;
1471     }
1472     root = 0;
1473     pruned_tree_idx = -1;
1474 }
1475
1476
1477 CvDTree::~CvDTree()
1478 {
1479     clear();
1480 }
1481
1482
1483 const CvDTreeNode* CvDTree::get_root() const
1484 {
1485     return root;
1486 }
1487
1488
1489 int CvDTree::get_pruned_tree_idx() const
1490 {
1491     return pruned_tree_idx;
1492 }
1493
1494
1495 CvDTreeTrainData* CvDTree::get_data()
1496 {
1497     return data;
1498 }
1499
1500
1501 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1502                      const CvMat* _responses, const CvMat* _var_idx,
1503                      const CvMat* _sample_idx, const CvMat* _var_type,
1504                      const CvMat* _missing_mask, CvDTreeParams _params )
1505 {
1506     bool result = false;
1507
1508     CV_FUNCNAME( "CvDTree::train" );
1509
1510     __BEGIN__;
1511
1512     clear();
1513     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1514                                  _var_idx, _sample_idx, _var_type,
1515                                  _missing_mask, _params, false );
1516     CV_CALL( result = do_train(0) );
1517
1518     __END__;
1519
1520     return result;
1521 }
1522
1523 bool CvDTree::train( const Mat& _train_data, int _tflag,
1524                     const Mat& _responses, const Mat& _var_idx,
1525                     const Mat& _sample_idx, const Mat& _var_type,
1526                     const Mat& _missing_mask, CvDTreeParams _params )
1527 {
1528     CvMat tdata = _train_data, responses = _responses, vidx=_var_idx,
1529         sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask; 
1530     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
1531                  vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
1532 }
1533
1534
1535 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
1536 {
1537    bool result = false;
1538
1539     CV_FUNCNAME( "CvDTree::train" );
1540
1541     __BEGIN__;
1542
1543     const CvMat* values = _data->get_values();
1544     const CvMat* response = _data->get_responses();
1545     const CvMat* missing = _data->get_missing();
1546     const CvMat* var_types = _data->get_var_types();
1547     const CvMat* train_sidx = _data->get_train_sample_idx();
1548     const CvMat* var_idx = _data->get_var_idx();
1549
1550     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1551         train_sidx, var_types, missing, _params ) );
1552
1553     __END__;
1554
1555     return result;
1556 }
1557
1558 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1559 {
1560     bool result = false;
1561
1562     CV_FUNCNAME( "CvDTree::train" );
1563
1564     __BEGIN__;
1565
1566     clear();
1567     data = _data;
1568     data->shared = true;
1569     CV_CALL( result = do_train(_subsample_idx));
1570
1571     __END__;
1572
1573     return result;
1574 }
1575
1576
1577 bool CvDTree::do_train( const CvMat* _subsample_idx )
1578 {
1579     bool result = false;
1580
1581     CV_FUNCNAME( "CvDTree::do_train" );
1582
1583     __BEGIN__;
1584
1585     root = data->subsample_data( _subsample_idx );
1586
1587     CV_CALL( try_split_node(root));
1588
1589     if( data->params.cv_folds > 0 )
1590         CV_CALL( prune_cv());
1591
1592     if( !data->shared )
1593         data->free_train_data();
1594
1595     result = true;
1596
1597     __END__;
1598
1599     return result;
1600 }
1601
1602
1603 void CvDTree::try_split_node( CvDTreeNode* node )
1604 {
1605     CvDTreeSplit* best_split = 0;
1606     int i, n = node->sample_count, vi;
1607     bool can_split = true;
1608     double quality_scale;
1609
1610     calc_node_value( node );
1611
1612     if( node->sample_count <= data->params.min_sample_count ||
1613         node->depth >= data->params.max_depth )
1614         can_split = false;
1615
1616     if( can_split && data->is_classifier )
1617     {
1618         // check if we have a "pure" node,
1619         // we assume that cls_count is filled by calc_node_value()
1620         int* cls_count = data->counts->data.i;
1621         int nz = 0, m = data->get_num_classes();
1622         for( i = 0; i < m; i++ )
1623             nz += cls_count[i] != 0;
1624         if( nz == 1 ) // there is only one class
1625             can_split = false;
1626     }
1627     else if( can_split )
1628     {
1629         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1630             can_split = false;
1631     }
1632
1633     if( can_split )
1634     {
1635         best_split = find_best_split(node);
1636         // TODO: check the split quality ...
1637         node->split = best_split;
1638     }
1639     if( !can_split || !best_split )
1640     {
1641         data->free_node_data(node);
1642         return;
1643     }
1644
1645     quality_scale = calc_node_dir( node );
1646     if( data->params.use_surrogates )
1647     {
1648         // find all the surrogate splits
1649         // and sort them by their similarity to the primary one
1650         for( vi = 0; vi < data->var_count; vi++ )
1651         {
1652             CvDTreeSplit* split;
1653             int ci = data->get_var_type(vi);
1654
1655             if( vi == best_split->var_idx )
1656                 continue;
1657
1658             if( ci >= 0 )
1659                 split = find_surrogate_split_cat( node, vi );
1660             else
1661                 split = find_surrogate_split_ord( node, vi );
1662
1663             if( split )
1664             {
1665                 // insert the split
1666                 CvDTreeSplit* prev_split = node->split;
1667                 split->quality = (float)(split->quality*quality_scale);
1668
1669                 while( prev_split->next &&
1670                        prev_split->next->quality > split->quality )
1671                     prev_split = prev_split->next;
1672                 split->next = prev_split->next;
1673                 prev_split->next = split;
1674             }
1675         }
1676     }
1677     split_node_data( node );
1678     try_split_node( node->left );
1679     try_split_node( node->right );
1680 }
1681
1682
1683 // calculate direction (left(-1),right(1),missing(0))
1684 // for each sample using the best split
1685 // the function returns scale coefficients for surrogate split quality factors.
1686 // the scale is applied to normalize surrogate split quality relatively to the
1687 // best (primary) split quality. That is, if a surrogate split is absolutely
1688 // identical to the primary split, its quality will be set to the maximum value =
1689 // quality of the primary split; otherwise, it will be lower.
1690 // besides, the function compute node->maxlr,
1691 // minimum possible quality (w/o considering the above mentioned scale)
1692 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1693 // are not discarded.
1694 double CvDTree::calc_node_dir( CvDTreeNode* node )
1695 {
1696     char* dir = (char*)data->direction->data.ptr;
1697     int i, n = node->sample_count, vi = node->split->var_idx;
1698     double L, R;
1699
1700     assert( !node->split->inversed );
1701
1702     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1703     {
1704         cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
1705         int* labels_buf = (int*)inn_buf;
1706         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1707         const int* subset = node->split->subset;
1708         if( !data->have_priors )
1709         {
1710             int sum = 0, sum_abs = 0;
1711
1712             for( i = 0; i < n; i++ )
1713             {
1714                 int idx = labels[i];
1715                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1716                     CV_DTREE_CAT_DIR(idx,subset) : 0;
1717                 sum += d; sum_abs += d & 1;
1718                 dir[i] = (char)d;
1719             }
1720
1721             R = (sum_abs + sum) >> 1;
1722             L = (sum_abs - sum) >> 1;
1723         }
1724         else
1725         {
1726             const double* priors = data->priors_mult->data.db;
1727             double sum = 0, sum_abs = 0;
1728             int* responses_buf = labels_buf + n;
1729             const int* responses = data->get_class_labels(node, responses_buf);
1730
1731             for( i = 0; i < n; i++ )
1732             {
1733                 int idx = labels[i];
1734                 double w = priors[responses[i]];
1735                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1736                 sum += d*w; sum_abs += (d & 1)*w;
1737                 dir[i] = (char)d;
1738             }
1739
1740             R = (sum_abs + sum) * 0.5;
1741             L = (sum_abs - sum) * 0.5;
1742         }
1743     }
1744     else // split on ordered var
1745     {
1746         int split_point = node->split->ord.split_point;
1747         int n1 = node->get_num_valid(vi);
1748         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
1749         float* val_buf = (float*)(uchar*)inn_buf;
1750         int* sorted_buf = (int*)(val_buf + n);
1751         int* sample_idx_buf = sorted_buf + n;
1752         const float* val = 0;
1753         const int* sorted = 0;
1754         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
1755         
1756         assert( 0 <= split_point && split_point < n1-1 );
1757
1758         if( !data->have_priors )
1759         {
1760             for( i = 0; i <= split_point; i++ )
1761                 dir[sorted[i]] = (char)-1;
1762             for( ; i < n1; i++ )
1763                 dir[sorted[i]] = (char)1;
1764             for( ; i < n; i++ )
1765                 dir[sorted[i]] = (char)0;
1766
1767             L = split_point-1;
1768             R = n1 - split_point + 1;
1769         }
1770         else
1771         {
1772             const double* priors = data->priors_mult->data.db;
1773             int* responses_buf = sample_idx_buf + n;
1774             const int* responses = data->get_class_labels(node, responses_buf);
1775             L = R = 0;
1776
1777             for( i = 0; i <= split_point; i++ )
1778             {
1779                 int idx = sorted[i];
1780                 double w = priors[responses[idx]];
1781                 dir[idx] = (char)-1;
1782                 L += w;
1783             }
1784
1785             for( ; i < n1; i++ )
1786             {
1787                 int idx = sorted[i];
1788                 double w = priors[responses[idx]];
1789                 dir[idx] = (char)1;
1790                 R += w;
1791             }
1792
1793             for( ; i < n; i++ )
1794                 dir[sorted[i]] = (char)0;
1795         }
1796     }
1797     node->maxlr = MAX( L, R );
1798     return node->split->quality/(L + R);
1799 }
1800
1801
1802 namespace cv
1803 {
1804
1805 DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
1806 {
1807     tree = _tree;
1808     node = _node;
1809     splitSize = tree->get_data()->split_heap->elem_size;
1810
1811     bestSplit = (CvDTreeSplit*)(new char[splitSize]);
1812     memset((CvDTreeSplit*)bestSplit, 0, splitSize);
1813     bestSplit->quality = -1;
1814     bestSplit->condensed_idx = INT_MIN;
1815     split = (CvDTreeSplit*)(new char[splitSize]);
1816     memset((CvDTreeSplit*)split, 0, splitSize);
1817     //haveSplit = false;
1818 }
1819
1820 DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
1821 {
1822     tree = finder.tree;
1823     node = finder.node;
1824     splitSize = tree->get_data()->split_heap->elem_size;
1825
1826     bestSplit = (CvDTreeSplit*)(new char[splitSize]);
1827     memcpy((CvDTreeSplit*)(bestSplit), (const CvDTreeSplit*)finder.bestSplit, splitSize);
1828     split = (CvDTreeSplit*)(new char[splitSize]);
1829     memset((CvDTreeSplit*)split, 0, splitSize);
1830 }
1831
1832 void DTreeBestSplitFinder::operator()(const BlockedRange& range)
1833 {
1834     int vi, vi1 = range.begin(), vi2 = range.end();
1835     int n = node->sample_count;
1836     CvDTreeTrainData* data = tree->get_data();
1837     AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
1838
1839     for( vi = vi1; vi < vi2; vi++ )
1840     {
1841         CvDTreeSplit *res;
1842         int ci = data->get_var_type(vi);
1843         if( node->get_num_valid(vi) <= 1 )
1844             continue;
1845
1846         if( data->is_classifier )
1847         {
1848             if( ci >= 0 )
1849                 res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1850             else
1851                 res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1852         }
1853         else
1854         {
1855             if( ci >= 0 )
1856                 res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1857             else
1858                 res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1859         }
1860
1861         if( res && bestSplit->quality < split->quality )
1862                 memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );
1863     }
1864 }
1865
1866 void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
1867 {
1868     if( bestSplit->quality < rhs.bestSplit->quality )
1869         memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)rhs.bestSplit, splitSize );
1870 }
1871 }
1872
1873
1874 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1875 {
1876     DTreeBestSplitFinder finder( this, node );
1877
1878     cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
1879
1880     CvDTreeSplit *bestSplit = data->new_split_cat( 0, -1.0f );
1881     memcpy( bestSplit, finder.bestSplit, finder.splitSize );
1882
1883     return bestSplit;
1884 }
1885
1886 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
1887                                              float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
1888 {
1889     const float epsilon = FLT_EPSILON*2;
1890     int n = node->sample_count;
1891     int n1 = node->get_num_valid(vi);
1892     int m = data->get_num_classes();
1893
1894     int base_size = 2*m*sizeof(int);
1895     cv::AutoBuffer<uchar> inn_buf(base_size);
1896     if( !_ext_buf )
1897       inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
1898     uchar* base_buf = (uchar*)inn_buf;
1899     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1900     float* values_buf = (float*)ext_buf;
1901     int* sorted_indices_buf = (int*)(values_buf + n);
1902     int* sample_indices_buf = sorted_indices_buf + n;
1903     const float* values = 0;
1904     const int* sorted_indices = 0;
1905     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
1906                             &sorted_indices, sample_indices_buf );
1907     int* responses_buf =  sample_indices_buf + n;
1908     const int* responses = data->get_class_labels( node, responses_buf );
1909
1910     const int* rc0 = data->counts->data.i;
1911     int* lc = (int*)base_buf;
1912     int* rc = lc + m;
1913     int i, best_i = -1;
1914     double lsum2 = 0, rsum2 = 0, best_val = init_quality;
1915     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1916
1917     // init arrays of class instance counters on both sides of the split
1918     for( i = 0; i < m; i++ )
1919     {
1920         lc[i] = 0;
1921         rc[i] = rc0[i];
1922     }
1923
1924     // compensate for missing values
1925     for( i = n1; i < n; i++ )
1926     {
1927         rc[responses[sorted_indices[i]]]--;
1928     }
1929
1930     if( !priors )
1931     {
1932         int L = 0, R = n1;
1933
1934         for( i = 0; i < m; i++ )
1935             rsum2 += (double)rc[i]*rc[i];
1936
1937         for( i = 0; i < n1 - 1; i++ )
1938         {
1939             int idx = responses[sorted_indices[i]];
1940             int lv, rv;
1941             L++; R--;
1942             lv = lc[idx]; rv = rc[idx];
1943             lsum2 += lv*2 + 1;
1944             rsum2 -= rv*2 - 1;
1945             lc[idx] = lv + 1; rc[idx] = rv - 1;
1946
1947             if( values[i] + epsilon < values[i+1] )
1948             {
1949                 double val = (lsum2*R + rsum2*L)/((double)L*R);
1950                 if( best_val < val )
1951                 {
1952                     best_val = val;
1953                     best_i = i;
1954                 }
1955             }
1956         }
1957     }
1958     else
1959     {
1960         double L = 0, R = 0;
1961         for( i = 0; i < m; i++ )
1962         {
1963             double wv = rc[i]*priors[i];
1964             R += wv;
1965             rsum2 += wv*wv;
1966         }
1967
1968         for( i = 0; i < n1 - 1; i++ )
1969         {
1970             int idx = responses[sorted_indices[i]];
1971             int lv, rv;
1972             double p = priors[idx], p2 = p*p;
1973             L += p; R -= p;
1974             lv = lc[idx]; rv = rc[idx];
1975             lsum2 += p2*(lv*2 + 1);
1976             rsum2 -= p2*(rv*2 - 1);
1977             lc[idx] = lv + 1; rc[idx] = rv - 1;
1978
1979             if( values[i] + epsilon < values[i+1] )
1980             {
1981                 double val = (lsum2*R + rsum2*L)/((double)L*R);
1982                 if( best_val < val )
1983                 {
1984                     best_val = val;
1985                     best_i = i;
1986                 }
1987             }
1988         }
1989     }
1990
1991     CvDTreeSplit* split = 0;
1992     if( best_i >= 0 )
1993     {
1994         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1995         split->var_idx = vi;
1996         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
1997         split->ord.split_point = best_i;
1998         split->inversed = 0;
1999         split->quality = (float)best_val;
2000     }
2001     return split;
2002 }
2003
2004
2005 void CvDTree::cluster_categories( const int* vectors, int n, int m,
2006                                 int* csums, int k, int* labels )
2007 {
2008     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2009     int iters = 0, max_iters = 100;
2010     int i, j, idx;
2011     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
2012     double *v_weights = buf, *c_weights = buf + n;
2013     bool modified = true;
2014     CvRNG* r = &data->rng;
2015
2016     // assign labels randomly
2017     for( i = 0; i < n; i++ )
2018     {
2019         int sum = 0;
2020         const int* v = vectors + i*m;
2021         labels[i] = i < k ? i : (cvRandInt(r) % k);
2022
2023         // compute weight of each vector
2024         for( j = 0; j < m; j++ )
2025             sum += v[j];
2026         v_weights[i] = sum ? 1./sum : 0.;
2027     }
2028
2029     for( i = 0; i < n; i++ )
2030     {
2031         int i1 = cvRandInt(r) % n;
2032         int i2 = cvRandInt(r) % n;
2033         CV_SWAP( labels[i1], labels[i2], j );
2034     }
2035
2036     for( iters = 0; iters <= max_iters; iters++ )
2037     {
2038         // calculate csums
2039         for( i = 0; i < k; i++ )
2040         {
2041             for( j = 0; j < m; j++ )
2042                 csums[i*m + j] = 0;
2043         }
2044
2045         for( i = 0; i < n; i++ )
2046         {
2047             const int* v = vectors + i*m;
2048             int* s = csums + labels[i]*m;
2049             for( j = 0; j < m; j++ )
2050                 s[j] += v[j];
2051         }
2052
2053         // exit the loop here, when we have up-to-date csums
2054         if( iters == max_iters || !modified )
2055             break;
2056
2057         modified = false;
2058
2059         // calculate weight of each cluster
2060         for( i = 0; i < k; i++ )
2061         {
2062             const int* s = csums + i*m;
2063             int sum = 0;
2064             for( j = 0; j < m; j++ )
2065                 sum += s[j];
2066             c_weights[i] = sum ? 1./sum : 0;
2067         }
2068
2069         // now for each vector determine the closest cluster
2070         for( i = 0; i < n; i++ )
2071         {
2072             const int* v = vectors + i*m;
2073             double alpha = v_weights[i];
2074             double min_dist2 = DBL_MAX;
2075             int min_idx = -1;
2076
2077             for( idx = 0; idx < k; idx++ )
2078             {
2079                 const int* s = csums + idx*m;
2080                 double dist2 = 0., beta = c_weights[idx];
2081                 for( j = 0; j < m; j++ )
2082                 {
2083                     double t = v[j]*alpha - s[j]*beta;
2084                     dist2 += t*t;
2085                 }
2086                 if( min_dist2 > dist2 )
2087                 {
2088                     min_dist2 = dist2;
2089                     min_idx = idx;
2090                 }
2091             }
2092
2093             if( min_idx != labels[i] )
2094                 modified = true;
2095             labels[i] = min_idx;
2096         }
2097     }
2098 }
2099
2100
2101 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
2102                                              CvDTreeSplit* _split, uchar* _ext_buf )
2103 {
2104     int ci = data->get_var_type(vi);
2105     int n = node->sample_count;
2106     int m = data->get_num_classes();
2107     int _mi = data->cat_count->data.i[ci], mi = _mi;
2108
2109     int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
2110     if( m > 2 && mi > data->params.max_categories )
2111         base_size += (m*min(data->params.max_categories, n) + mi)*sizeof(int);
2112     else
2113         base_size += mi*sizeof(int*);
2114     cv::AutoBuffer<uchar> inn_buf(base_size);
2115     if( !_ext_buf )
2116         inn_buf.allocate(base_size + 2*n*sizeof(int));
2117     uchar* base_buf = (uchar*)inn_buf;
2118     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2119
2120     int* lc = (int*)base_buf;
2121     int* rc = lc + m;
2122     int* _cjk = rc + m*2, *cjk = _cjk;
2123     double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
2124
2125     int* labels_buf = (int*)ext_buf;
2126     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2127     int* responses_buf = labels_buf + n;
2128     const int* responses = data->get_class_labels(node, responses_buf);
2129
2130     int* cluster_labels = 0;
2131     int** int_ptr = 0;
2132     int i, j, k, idx;
2133     double L = 0, R = 0;
2134     double best_val = init_quality;
2135     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
2136     const double* priors = data->priors_mult->data.db;
2137
2138     // init array of counters:
2139     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
2140     for( j = -1; j < mi; j++ )
2141         for( k = 0; k < m; k++ )
2142             cjk[j*m + k] = 0;
2143
2144     for( i = 0; i < n; i++ )
2145     {
2146        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
2147        k = responses[i];
2148        cjk[j*m + k]++;
2149     }
2150
2151     if( m > 2 )
2152     {
2153         if( mi > data->params.max_categories )
2154         {
2155             mi = MIN(data->params.max_categories, n);
2156             cjk = (int*)(c_weights + _mi);
2157             cluster_labels = cjk + m*mi;
2158             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
2159         }
2160         subset_i = 1;
2161         subset_n = 1 << mi;
2162     }
2163     else
2164     {
2165         assert( m == 2 );
2166         int_ptr = (int**)(c_weights + _mi);
2167         for( j = 0; j < mi; j++ )
2168             int_ptr[j] = cjk + j*2 + 1;
2169         icvSortIntPtr( int_ptr, mi, 0 );
2170         subset_i = 0;
2171         subset_n = mi;
2172     }
2173
2174     for( k = 0; k < m; k++ )
2175     {
2176         int sum = 0;
2177         for( j = 0; j < mi; j++ )
2178             sum += cjk[j*m + k];
2179         rc[k] = sum;
2180         lc[k] = 0;
2181     }
2182
2183     for( j = 0; j < mi; j++ )
2184     {
2185         double sum = 0;
2186         for( k = 0; k < m; k++ )
2187             sum += cjk[j*m + k]*priors[k];
2188         c_weights[j] = sum;
2189         R += c_weights[j];
2190     }
2191
2192     for( ; subset_i < subset_n; subset_i++ )
2193     {
2194         double weight;
2195         int* crow;
2196         double lsum2 = 0, rsum2 = 0;
2197
2198         if( m == 2 )
2199             idx = (int)(int_ptr[subset_i] - cjk)/2;
2200         else
2201         {
2202             int graycode = (subset_i>>1)^subset_i;
2203             int diff = graycode ^ prevcode;
2204
2205             // determine index of the changed bit.
2206             Cv32suf u;
2207             idx = diff >= (1 << 16) ? 16 : 0;
2208             u.f = (float)(((diff >> 16) | diff) & 65535);
2209             idx += (u.i >> 23) - 127;
2210             subtract = graycode < prevcode;
2211             prevcode = graycode;
2212         }
2213
2214         crow = cjk + idx*m;
2215         weight = c_weights[idx];
2216         if( weight < FLT_EPSILON )
2217             continue;
2218
2219         if( !subtract )
2220         {
2221             for( k = 0; k < m; k++ )
2222             {
2223                 int t = crow[k];
2224                 int lval = lc[k] + t;
2225                 int rval = rc[k] - t;
2226                 double p = priors[k], p2 = p*p;
2227                 lsum2 += p2*lval*lval;
2228                 rsum2 += p2*rval*rval;
2229                 lc[k] = lval; rc[k] = rval;
2230             }
2231             L += weight;
2232             R -= weight;
2233         }
2234         else
2235         {
2236             for( k = 0; k < m; k++ )
2237             {
2238                 int t = crow[k];
2239                 int lval = lc[k] - t;
2240                 int rval = rc[k] + t;
2241                 double p = priors[k], p2 = p*p;
2242                 lsum2 += p2*lval*lval;
2243                 rsum2 += p2*rval*rval;
2244                 lc[k] = lval; rc[k] = rval;
2245             }
2246             L -= weight;
2247             R += weight;
2248         }
2249
2250         if( L > FLT_EPSILON && R > FLT_EPSILON )
2251         {
2252             double val = (lsum2*R + rsum2*L)/((double)L*R);
2253             if( best_val < val )
2254             {
2255                 best_val = val;
2256                 best_subset = subset_i;
2257             }
2258         }
2259     }
2260
2261     CvDTreeSplit* split = 0;
2262     if( best_subset >= 0 ) 
2263     {
2264         split = _split ? _split : data->new_split_cat( 0, -1.0f );
2265         split->var_idx = vi;
2266         split->quality = (float)best_val;
2267         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2268         if( m == 2 )
2269         {
2270             for( i = 0; i <= best_subset; i++ )
2271             {
2272                 idx = (int)(int_ptr[i] - cjk) >> 1;
2273                 split->subset[idx >> 5] |= 1 << (idx & 31);
2274             }
2275         }
2276         else
2277         {
2278             for( i = 0; i < _mi; i++ )
2279             {
2280                 idx = cluster_labels ? cluster_labels[i] : i;
2281                 if( best_subset & (1 << idx) )
2282                     split->subset[i >> 5] |= 1 << (i & 31);
2283             }
2284         }
2285     }
2286     return split;
2287 }
2288
2289
2290 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2291 {
2292     const float epsilon = FLT_EPSILON*2;
2293     int n = node->sample_count;
2294     int n1 = node->get_num_valid(vi);
2295
2296     cv::AutoBuffer<uchar> inn_buf;
2297     if( !_ext_buf )
2298         inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
2299     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2300     float* values_buf = (float*)ext_buf;
2301     int* sorted_indices_buf = (int*)(values_buf + n);
2302     int* sample_indices_buf = sorted_indices_buf + n;
2303     const float* values = 0;
2304     const int* sorted_indices = 0;
2305     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2306     float* responses_buf =  (float*)(sample_indices_buf + n);
2307     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
2308
2309     int i, best_i = -1;
2310     double best_val = init_quality, lsum = 0, rsum = node->value*n;
2311     int L = 0, R = n1;
2312
2313     // compensate for missing values
2314     for( i = n1; i < n; i++ )
2315         rsum -= responses[sorted_indices[i]];
2316
2317     // find the optimal split
2318     for( i = 0; i < n1 - 1; i++ )
2319     {
2320         float t = responses[sorted_indices[i]];
2321         L++; R--;
2322         lsum += t;
2323         rsum -= t;
2324
2325         if( values[i] + epsilon < values[i+1] )
2326         {
2327             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2328             if( best_val < val )
2329             {
2330                 best_val = val;
2331                 best_i = i;
2332             }
2333         }
2334     }
2335
2336     CvDTreeSplit* split = 0;
2337     if( best_i >= 0 )
2338     {
2339         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2340         split->var_idx = vi;
2341         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2342         split->ord.split_point = best_i;
2343         split->inversed = 0;
2344         split->quality = (float)best_val;
2345     }
2346     return split;
2347 }
2348
2349 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2350 {
2351     int ci = data->get_var_type(vi);
2352     int n = node->sample_count;
2353     int mi = data->cat_count->data.i[ci];
2354
2355     int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
2356     cv::AutoBuffer<uchar> inn_buf(base_size);
2357     if( !_ext_buf )
2358         inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
2359     uchar* base_buf = (uchar*)inn_buf;
2360     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2361     int* labels_buf = (int*)ext_buf;
2362     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2363     float* responses_buf = (float*)(labels_buf + n);
2364     int* sample_indices_buf = (int*)(responses_buf + n);
2365     const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
2366
2367     double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2368     int* counts = (int*)(sum + mi) + 1;
2369     double** sum_ptr = (double**)(counts + mi);
2370     int i, L = 0, R = 0;
2371     double best_val = init_quality, lsum = 0, rsum = 0;
2372     int best_subset = -1, subset_i;
2373
2374     for( i = -1; i < mi; i++ )
2375         sum[i] = counts[i] = 0;
2376
2377     // calculate sum response and weight of each category of the input var
2378     for( i = 0; i < n; i++ )
2379     {
2380         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
2381         double s = sum[idx] + responses[i];
2382         int nc = counts[idx] + 1;
2383         sum[idx] = s;
2384         counts[idx] = nc;
2385     }
2386
2387     // calculate average response in each category
2388     for( i = 0; i < mi; i++ )
2389     {
2390         R += counts[i];
2391         rsum += sum[i];
2392         sum[i] /= MAX(counts[i],1);
2393         sum_ptr[i] = sum + i;
2394     }
2395
2396     icvSortDblPtr( sum_ptr, mi, 0 );
2397
2398     // revert back to unnormalized sums
2399     // (there should be a very little loss of accuracy)
2400     for( i = 0; i < mi; i++ )
2401         sum[i] *= counts[i];
2402
2403     for( subset_i = 0; subset_i < mi-1; subset_i++ )
2404     {
2405         int idx = (int)(sum_ptr[subset_i] - sum);
2406         int ni = counts[idx];
2407
2408         if( ni )
2409         {
2410             double s = sum[idx];
2411             lsum += s; L += ni;
2412             rsum -= s; R -= ni;
2413
2414             if( L && R )
2415             {
2416                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2417                 if( best_val < val )
2418                 {
2419                     best_val = val;
2420                     best_subset = subset_i;
2421                 }
2422             }
2423         }
2424     }
2425
2426     CvDTreeSplit* split = 0;
2427     if( best_subset >= 0 )
2428     {
2429         split = _split ? _split : data->new_split_cat( 0, -1.0f);
2430         split->var_idx = vi;
2431         split->quality = (float)best_val;
2432         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2433         for( i = 0; i <= best_subset; i++ )
2434         {
2435             int idx = (int)(sum_ptr[i] - sum);
2436             split->subset[idx >> 5] |= 1 << (idx & 31);
2437         }
2438     }
2439     return split;
2440 }
2441
2442 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
2443 {
2444     const float epsilon = FLT_EPSILON*2;
2445     const char* dir = (char*)data->direction->data.ptr;
2446     int n = node->sample_count, n1 = node->get_num_valid(vi);
2447     cv::AutoBuffer<uchar> inn_buf;
2448     if( !_ext_buf )
2449         inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
2450     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2451     float* values_buf = (float*)ext_buf;
2452     int* sorted_indices_buf = (int*)(values_buf + n);
2453     int* sample_indices_buf = sorted_indices_buf + n;
2454     const float* values = 0;
2455     const int* sorted_indices = 0;
2456     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2457     // LL - number of samples that both the primary and the surrogate splits send to the left
2458     // LR - ... primary split sends to the left and the surrogate split sends to the right
2459     // RL - ... primary split sends to the right and the surrogate split sends to the left
2460     // RR - ... both send to the right
2461     int i, best_i = -1, best_inversed = 0;
2462     double best_val;
2463
2464     if( !data->have_priors )
2465     {
2466         int LL = 0, RL = 0, LR, RR;
2467         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2468         int sum = 0, sum_abs = 0;
2469
2470         for( i = 0; i < n1; i++ )
2471         {
2472             int d = dir[sorted_indices[i]];
2473             sum += d; sum_abs += d & 1;
2474         }
2475
2476         // sum_abs = R + L; sum = R - L
2477         RR = (sum_abs + sum) >> 1;
2478         LR = (sum_abs - sum) >> 1;
2479
2480         // initially all the samples are sent to the right by the surrogate split,
2481         // LR of them are sent to the left by primary split, and RR - to the right.
2482         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2483         for( i = 0; i < n1 - 1; i++ )
2484         {
2485             int d = dir[sorted_indices[i]];
2486
2487             if( d < 0 )
2488             {
2489                 LL++; LR--;
2490                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
2491                 {
2492                     best_val = LL + RR;
2493                     best_i = i; best_inversed = 0;
2494                 }
2495             }
2496             else if( d > 0 )
2497             {
2498                 RL++; RR--;
2499                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
2500                 {
2501                     best_val = RL + LR;
2502                     best_i = i; best_inversed = 1;
2503                 }
2504             }
2505         }
2506         best_val = _best_val;
2507     }
2508     else
2509     {
2510         double LL = 0, RL = 0, LR, RR;
2511         double worst_val = node->maxlr;
2512         double sum = 0, sum_abs = 0;
2513         const double* priors = data->priors_mult->data.db;
2514         int* responses_buf = sample_indices_buf + n;
2515         const int* responses = data->get_class_labels(node, responses_buf);
2516         best_val = worst_val;
2517
2518         for( i = 0; i < n1; i++ )
2519         {
2520             int idx = sorted_indices[i];
2521             double w = priors[responses[idx]];
2522             int d = dir[idx];
2523             sum += d*w; sum_abs += (d & 1)*w;
2524         }
2525
2526         // sum_abs = R + L; sum = R - L
2527         RR = (sum_abs + sum)*0.5;
2528         LR = (sum_abs - sum)*0.5;
2529
2530         // initially all the samples are sent to the right by the surrogate split,
2531         // LR of them are sent to the left by primary split, and RR - to the right.
2532         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2533         for( i = 0; i < n1 - 1; i++ )
2534         {
2535             int idx = sorted_indices[i];
2536             double w = priors[responses[idx]];
2537             int d = dir[idx];
2538
2539             if( d < 0 )
2540             {
2541                 LL += w; LR -= w;
2542                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
2543                 {
2544                     best_val = LL + RR;
2545                     best_i = i; best_inversed = 0;
2546                 }
2547             }
2548             else if( d > 0 )
2549             {
2550                 RL += w; RR -= w;
2551                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
2552                 {
2553                     best_val = RL + LR;
2554                     best_i = i; best_inversed = 1;
2555                 }
2556             }
2557         }
2558     }
2559     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2560         (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
2561 }
2562
2563
2564 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
2565 {
2566     const char* dir = (char*)data->direction->data.ptr;
2567     int n = node->sample_count;
2568     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2569
2570     int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
2571     cv::AutoBuffer<uchar> inn_buf(base_size);
2572     if( !_ext_buf )
2573         inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
2574     uchar* base_buf = (uchar*)inn_buf;
2575     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2576
2577     int* labels_buf = (int*)ext_buf;
2578     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2579     // LL - number of samples that both the primary and the surrogate splits send to the left
2580     // LR - ... primary split sends to the left and the surrogate split sends to the right
2581     // RL - ... primary split sends to the right and the surrogate split sends to the left
2582     // RR - ... both send to the right
2583     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2584     double best_val = 0;
2585     double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2586     double* rc = lc + mi + 1;
2587
2588     for( i = -1; i < mi; i++ )
2589         lc[i] = rc[i] = 0;
2590
2591     // for each category calculate the weight of samples
2592     // sent to the left (lc) and to the right (rc) by the primary split
2593     if( !data->have_priors )
2594     {
2595         int* _lc = (int*)rc + 1;
2596         int* _rc = _lc + mi + 1;
2597
2598         for( i = -1; i < mi; i++ )
2599             _lc[i] = _rc[i] = 0;
2600
2601         for( i = 0; i < n; i++ )
2602         {
2603             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2604             int d = dir[i];
2605             int sum = _lc[idx] + d;
2606             int sum_abs = _rc[idx] + (d & 1);
2607             _lc[idx] = sum; _rc[idx] = sum_abs;
2608         }
2609
2610         for( i = 0; i < mi; i++ )
2611         {
2612             int sum = _lc[i];
2613             int sum_abs = _rc[i];
2614             lc[i] = (sum_abs - sum) >> 1;
2615             rc[i] = (sum_abs + sum) >> 1;
2616         }
2617     }
2618     else
2619     {
2620         const double* priors = data->priors_mult->data.db;
2621         int* responses_buf = labels_buf + n;
2622         const int* responses = data->get_class_labels(node, responses_buf);
2623
2624         for( i = 0; i < n; i++ )
2625         {
2626             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2627             double w = priors[responses[i]];
2628             int d = dir[i];
2629             double sum = lc[idx] + d*w;
2630             double sum_abs = rc[idx] + (d & 1)*w;
2631             lc[idx] = sum; rc[idx] = sum_abs;
2632         }
2633
2634         for( i = 0; i < mi; i++ )
2635         {
2636             double sum = lc[i];
2637             double sum_abs = rc[i];
2638             lc[i] = (sum_abs - sum) * 0.5;
2639             rc[i] = (sum_abs + sum) * 0.5;
2640         }
2641     }
2642
2643     // 2. now form the split.
2644     // in each category send all the samples to the same direction as majority
2645     for( i = 0; i < mi; i++ )
2646     {
2647         double lval = lc[i], rval = rc[i];
2648         if( lval > rval )
2649         {
2650             split->subset[i >> 5] |= 1 << (i & 31);
2651             best_val += lval;
2652             l_win++;
2653         }
2654         else
2655             best_val += rval;
2656     }
2657
2658     split->quality = (float)best_val;
2659     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2660         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2661
2662     return split;
2663 }
2664
2665
2666 void CvDTree::calc_node_value( CvDTreeNode* node )
2667 {
2668     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2669     int m = data->get_num_classes();
2670
2671     int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
2672     int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
2673     cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
2674     uchar* base_buf = (uchar*)inn_buf;
2675     uchar* ext_buf = base_buf + base_size;
2676
2677     int* cv_labels_buf = (int*)ext_buf;
2678     const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
2679
2680     if( data->is_classifier )
2681     {
2682         // in case of classification tree:
2683         //  * node value is the label of the class that has the largest weight in the node.
2684         //  * node risk is the weighted number of misclassified samples,
2685         //  * j-th cross-validation fold value and risk are calculated as above,
2686         //    but using the samples with cv_labels(*)!=j.
2687         //  * j-th cross-validation fold error is calculated as the weighted number of
2688         //    misclassified samples with cv_labels(*)==j.
2689
2690         // compute the number of instances of each class
2691         int* cls_count = data->counts->data.i;
2692         int* responses_buf = cv_labels_buf + n;
2693         const int* responses = data->get_class_labels(node, responses_buf);
2694         int* cv_cls_count = (int*)base_buf;
2695         double max_val = -1, total_weight = 0;
2696         int max_k = -1;
2697         double* priors = data->priors_mult->data.db;
2698
2699         for( k = 0; k < m; k++ )
2700             cls_count[k] = 0;
2701
2702         if( cv_n == 0 )
2703         {
2704             for( i = 0; i < n; i++ )
2705                 cls_count[responses[i]]++;
2706         }
2707         else
2708         {
2709             for( j = 0; j < cv_n; j++ )
2710                 for( k = 0; k < m; k++ )
2711                     cv_cls_count[j*m + k] = 0;
2712
2713             for( i = 0; i < n; i++ )
2714             {
2715                 j = cv_labels[i]; k = responses[i];
2716                 cv_cls_count[j*m + k]++;
2717             }
2718
2719             for( j = 0; j < cv_n; j++ )
2720                 for( k = 0; k < m; k++ )
2721                     cls_count[k] += cv_cls_count[j*m + k];
2722         }
2723
2724         if( data->have_priors && node->parent == 0 )
2725         {
2726             // compute priors_mult from priors, take the sample ratio into account.
2727             double sum = 0;
2728             for( k = 0; k < m; k++ )
2729             {
2730                 int n_k = cls_count[k];
2731                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2732                 sum += priors[k];
2733             }
2734             sum = 1./sum;
2735             for( k = 0; k < m; k++ )
2736                 priors[k] *= sum;
2737         }
2738
2739         for( k = 0; k < m; k++ )
2740         {
2741             double val = cls_count[k]*priors[k];
2742             total_weight += val;
2743             if( max_val < val )
2744             {
2745                 max_val = val;
2746                 max_k = k;
2747             }
2748         }
2749
2750         node->class_idx = max_k;
2751         node->value = data->cat_map->data.i[
2752             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2753         node->node_risk = total_weight - max_val;
2754
2755         for( j = 0; j < cv_n; j++ )
2756         {
2757             double sum_k = 0, sum = 0, max_val_k = 0;
2758             max_val = -1; max_k = -1;
2759
2760             for( k = 0; k < m; k++ )
2761             {
2762                 double w = priors[k];
2763                 double val_k = cv_cls_count[j*m + k]*w;
2764                 double val = cls_count[k]*w - val_k;
2765                 sum_k += val_k;
2766                 sum += val;
2767                 if( max_val < val )
2768                 {
2769                     max_val = val;
2770                     max_val_k = val_k;
2771                     max_k = k;
2772                 }
2773             }
2774
2775             node->cv_Tn[j] = INT_MAX;
2776             node->cv_node_risk[j] = sum - max_val;
2777             node->cv_node_error[j] = sum_k - max_val_k;
2778         }
2779     }
2780     else
2781     {
2782         // in case of regression tree:
2783         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2784         //    n is the number of samples in the node.
2785         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2786         //  * j-th cross-validation fold value and risk are calculated as above,
2787         //    but using the samples with cv_labels(*)!=j.
2788         //  * j-th cross-validation fold error is calculated
2789         //    using samples with cv_labels(*)==j as the test subset:
2790         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2791         //    where node_value_j is the node value calculated
2792         //    as described in the previous bullet, and summation is done
2793         //    over the samples with cv_labels(*)==j.
2794
2795         double sum = 0, sum2 = 0;
2796         float* values_buf = (float*)(cv_labels_buf + n);
2797         int* sample_indices_buf = (int*)(values_buf + n);
2798         const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
2799         double *cv_sum = 0, *cv_sum2 = 0;
2800         int* cv_count = 0;
2801
2802         if( cv_n == 0 )
2803         {
2804             for( i = 0; i < n; i++ )
2805             {
2806                 double t = values[i];
2807                 sum += t;
2808                 sum2 += t*t;
2809             }
2810         }
2811         else
2812         {
2813             cv_sum = (double*)base_buf;
2814             cv_sum2 = cv_sum + cv_n;
2815             cv_count = (int*)(cv_sum2 + cv_n);
2816
2817             for( j = 0; j < cv_n; j++ )
2818             {
2819                 cv_sum[j] = cv_sum2[j] = 0.;
2820                 cv_count[j] = 0;
2821             }
2822
2823             for( i = 0; i < n; i++ )
2824             {
2825                 j = cv_labels[i];
2826                 double t = values[i];
2827                 double s = cv_sum[j] + t;
2828                 double s2 = cv_sum2[j] + t*t;
2829                 int nc = cv_count[j] + 1;
2830                 cv_sum[j] = s;
2831                 cv_sum2[j] = s2;
2832                 cv_count[j] = nc;
2833             }
2834
2835             for( j = 0; j < cv_n; j++ )
2836             {
2837                 sum += cv_sum[j];
2838                 sum2 += cv_sum2[j];
2839             }
2840         }
2841
2842         node->node_risk = sum2 - (sum/n)*sum;
2843         node->value = sum/n;
2844
2845         for( j = 0; j < cv_n; j++ )
2846         {
2847             double s = cv_sum[j], si = sum - s;
2848             double s2 = cv_sum2[j], s2i = sum2 - s2;
2849             int c = cv_count[j], ci = n - c;
2850             double r = si/MAX(ci,1);
2851             node->cv_node_risk[j] = s2i - r*r*ci;
2852             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2853             node->cv_Tn[j] = INT_MAX;
2854         }
2855     }
2856 }
2857
2858
2859 void CvDTree::complete_node_dir( CvDTreeNode* node )
2860 {
2861     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2862     int nz = n - node->get_num_valid(node->split->var_idx);
2863     char* dir = (char*)data->direction->data.ptr;
2864
2865     // try to complete direction using surrogate splits
2866     if( nz && data->params.use_surrogates )
2867     {
2868         cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
2869         CvDTreeSplit* split = node->split->next;
2870         for( ; split != 0 && nz; split = split->next )
2871         {
2872             int inversed_mask = split->inversed ? -1 : 0;
2873             vi = split->var_idx;
2874
2875             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2876             {
2877                 int* labels_buf = (int*)(uchar*)inn_buf;
2878                 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2879                 const int* subset = split->subset;
2880
2881                 for( i = 0; i < n; i++ )
2882                 {
2883                     int idx = labels[i];
2884                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
2885                         
2886                     {
2887                         int d = CV_DTREE_CAT_DIR(idx,subset);
2888                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2889                         if( --nz )
2890                             break;
2891                     }
2892                 }
2893             }
2894             else // split on ordered var
2895             {
2896                 float* values_buf = (float*)(uchar*)inn_buf;
2897                 int* sorted_indices_buf = (int*)(values_buf + n);
2898                 int* sample_indices_buf = sorted_indices_buf + n;
2899                 const float* values = 0;
2900                 const int* sorted_indices = 0;
2901                 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2902                 int split_point = split->ord.split_point;
2903                 int n1 = node->get_num_valid(vi);
2904
2905                 assert( 0 <= split_point && split_point < n-1 );
2906
2907                 for( i = 0; i < n1; i++ )
2908                 {
2909                     int idx = sorted_indices[i];
2910                     if( !dir[idx] )
2911                     {
2912                         int d = i <= split_point ? -1 : 1;
2913                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2914                         if( --nz )
2915                             break;
2916                     }
2917                 }
2918             }
2919         }
2920     }
2921
2922     // find the default direction for the rest
2923     if( nz )
2924     {
2925         for( i = nr = 0; i < n; i++ )
2926             nr += dir[i] > 0;
2927         nl = n - nr - nz;
2928         d0 = nl > nr ? -1 : nr > nl;
2929     }
2930
2931     // make sure that every sample is directed either to the left or to the right
2932     for( i = 0; i < n; i++ )
2933     {
2934         int d = dir[i];
2935         if( !d )
2936         {
2937             d = d0;
2938             if( !d )
2939                 d = d1, d1 = -d1;
2940         }
2941         d = d > 0;
2942         dir[i] = (char)d; // remap (-1,1) to (0,1)
2943     }
2944 }
2945
2946
2947 void CvDTree::split_node_data( CvDTreeNode* node )
2948 {
2949     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
2950     char* dir = (char*)data->direction->data.ptr;
2951     CvDTreeNode *left = 0, *right = 0;
2952     int* new_idx = data->split_buf->data.i;
2953     int new_buf_idx = data->get_child_buf_idx( node );
2954     int work_var_count = data->get_work_var_count();
2955     CvMat* buf = data->buf;
2956     cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
2957     int* temp_buf = (int*)(uchar*)inn_buf;
2958
2959     complete_node_dir(node);
2960
2961     for( i = nl = nr = 0; i < n; i++ )
2962     {
2963         int d = dir[i];
2964         // initialize new indices for splitting ordered variables
2965         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2966         nr += d;
2967         nl += d^1;
2968     }
2969
2970     bool split_input_data;
2971     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2972     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
2973
2974     split_input_data = node->depth + 1 < data->params.max_depth &&
2975         (node->left->sample_count > data->params.min_sample_count ||
2976         node->right->sample_count > data->params.min_sample_count);
2977
2978     // split ordered variables, keep both halves sorted.
2979     for( vi = 0; vi < data->var_count; vi++ )
2980     {
2981         int ci = data->get_var_type(vi);
2982
2983         if( ci >= 0 || !split_input_data )
2984             continue;
2985
2986         int n1 = node->get_num_valid(vi);
2987         float* src_val_buf = (float*)(uchar*)(temp_buf + n);
2988         int* src_sorted_idx_buf = (int*)(src_val_buf + n);
2989         int* src_sample_idx_buf = src_sorted_idx_buf + n;
2990         const float* src_val = 0;
2991         const int* src_sorted_idx = 0;
2992         data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
2993
2994         for(i = 0; i < n; i++)
2995             temp_buf[i] = src_sorted_idx[i];
2996
2997         if (data->is_buf_16u)
2998         {
2999             unsigned short *ldst, *rdst, *ldst0, *rdst0;
3000             //unsigned short tl, tr;
3001             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
3002                 vi*scount + left->offset);
3003             rdst0 = rdst = (unsigned short*)(ldst + nl);
3004
3005             // split sorted
3006             for( i = 0; i < n1; i++ )
3007             {
3008                 int idx = temp_buf[i];
3009                 int d = dir[idx];
3010                 idx = new_idx[idx];
3011                 if (d)
3012                 {
3013                     *rdst = (unsigned short)idx;
3014                     rdst++;
3015                 }
3016                 else
3017                 {
3018                     *ldst = (unsigned short)idx;
3019                     ldst++;
3020                 }
3021             }
3022
3023             left->set_num_valid(vi, (int)(ldst - ldst0));
3024             right->set_num_valid(vi, (int)(rdst - rdst0));
3025
3026             // split missing
3027             for( ; i < n; i++ )
3028             {
3029                 int idx = temp_buf[i];
3030                 int d = dir[idx];
3031                 idx = new_idx[idx];
3032                 if (d)
3033                 {
3034                     *rdst = (unsigned short)idx;
3035                     rdst++;
3036                 }
3037                 else
3038                 {
3039                     *ldst = (unsigned short)idx;
3040                     ldst++;
3041                 }
3042             }
3043         }
3044         else
3045         {
3046             int *ldst0, *ldst, *rdst0, *rdst;
3047             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + 
3048                 vi*scount + left->offset;
3049             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + 
3050                 vi*scount + right->offset;
3051
3052             // split sorted
3053             for( i = 0; i < n1; i++ )
3054             {
3055                 int idx = temp_buf[i];
3056                 int d = dir[idx];
3057                 idx = new_idx[idx];
3058                 if (d)
3059                 {
3060                     *rdst = idx;
3061                     rdst++;
3062                 }
3063                 else
3064                 {
3065                     *ldst = idx;
3066                     ldst++;
3067                 }
3068             }
3069
3070             left->set_num_valid(vi, (int)(ldst - ldst0));
3071             right->set_num_valid(vi, (int)(rdst - rdst0));
3072
3073             // split missing
3074             for( ; i < n; i++ )
3075             {
3076                 int idx = temp_buf[i];
3077                 int d = dir[idx];
3078                 idx = new_idx[idx];
3079                 if (d)
3080                 {
3081                     *rdst = idx;
3082                     rdst++;
3083                 }
3084                 else
3085                 {
3086                     *ldst = idx;
3087                     ldst++;
3088                 }
3089             }
3090         }
3091     }
3092
3093     // split categorical vars, responses and cv_labels using new_idx relocation table
3094     for( vi = 0; vi < work_var_count; vi++ )
3095     {
3096         int ci = data->get_var_type(vi);
3097         int n1 = node->get_num_valid(vi), nr1 = 0;
3098         
3099         if( ci < 0 || (vi < data->var_count && !split_input_data) )
3100             continue;
3101
3102         int *src_lbls_buf = temp_buf + n;
3103         const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
3104
3105         for(i = 0; i < n; i++)
3106             temp_buf[i] = src_lbls[i];
3107
3108         if (data->is_buf_16u)
3109         {
3110             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + 
3111                 vi*scount + left->offset);
3112             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + 
3113                 vi*scount + right->offset);
3114             
3115             for( i = 0; i < n; i++ )
3116             {
3117                 int d = dir[i];
3118                 int idx = temp_buf[i];
3119                 if (d)
3120                 {
3121                     *rdst = (unsigned short)idx;
3122                     rdst++;
3123                     nr1 += (idx != 65535 )&d;
3124                 }
3125                 else
3126                 {
3127                     *ldst = (unsigned short)idx;
3128                     ldst++;
3129                 }
3130             }
3131
3132             if( vi < data->var_count )
3133             {
3134                 left->set_num_valid(vi, n1 - nr1);
3135                 right->set_num_valid(vi, nr1);
3136             }
3137         }
3138         else
3139         {
3140             int *ldst = buf->data.i + left->buf_idx*buf->cols + 
3141                 vi*scount + left->offset;
3142             int *rdst = buf->data.i + right->buf_idx*buf->cols + 
3143                 vi*scount + right->offset;
3144             
3145             for( i = 0; i < n; i++ )
3146             {
3147                 int d = dir[i];
3148                 int idx = temp_buf[i];
3149                 if (d)
3150                 {
3151                     *rdst = idx;
3152                     rdst++;
3153                     nr1 += (idx >= 0)&d;
3154                 }
3155                 else
3156                 {
3157                     *ldst = idx;
3158                     ldst++;
3159                 }
3160                 
3161             }
3162
3163             if( vi < data->var_count )
3164             {
3165                 left->set_num_valid(vi, n1 - nr1);
3166                 right->set_num_valid(vi, nr1);
3167             }
3168         }        
3169     }
3170
3171
3172     // split sample indices
3173     int *sample_idx_src_buf = temp_buf + n;
3174     const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
3175
3176     for(i = 0; i < n; i++)
3177         temp_buf[i] = sample_idx_src[i];
3178
3179     int pos = data->get_work_var_count();
3180     if (data->is_buf_16u)
3181     {
3182         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
3183             pos*scount + left->offset);
3184         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + 
3185             pos*scount + right->offset);
3186         for (i = 0; i < n; i++)
3187         {
3188             int d = dir[i];
3189             unsigned short idx = (unsigned short)temp_buf[i];
3190             if (d)
3191             {
3192                 *rdst = idx;
3193                 rdst++;
3194             }
3195             else
3196             {
3197                 *ldst = idx;
3198                 ldst++;
3199             }
3200         }
3201     }
3202     else
3203     {
3204         int* ldst = buf->data.i + left->buf_idx*buf->cols + 
3205             pos*scount + left->offset;
3206         int* rdst = buf->data.i + right->buf_idx*buf->cols + 
3207             pos*scount + right->offset;
3208         for (i = 0; i < n; i++)
3209         {
3210             int d = dir[i];
3211             int idx = temp_buf[i];
3212             if (d)
3213             {
3214                 *rdst = idx;
3215                 rdst++;
3216             }
3217             else
3218             {
3219                 *ldst = idx;
3220                 ldst++;
3221             }
3222         }
3223     }
3224     
3225     // deallocate the parent node data that is not needed anymore
3226     data->free_node_data(node);    
3227 }
3228
3229 float CvDTree::calc_error( CvMLData* _data, int type, vector<float> *resp )
3230 {
3231     float err = 0;
3232     const CvMat* values = _data->get_values();
3233     const CvMat* response = _data->get_responses();
3234     const CvMat* missing = _data->get_missing();
3235     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
3236     const CvMat* var_types = _data->get_var_types();
3237     int* sidx = sample_idx ? sample_idx->data.i : 0;
3238     int r_step = CV_IS_MAT_CONT(response->type) ?
3239                 1 : response->step / CV_ELEM_SIZE(response->type);
3240     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
3241     int sample_count = sample_idx ? sample_idx->cols : 0;
3242     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
3243     float* pred_resp = 0;
3244     if( resp && (sample_count > 0) )
3245     {
3246         resp->resize( sample_count );
3247         pred_resp = &((*resp)[0]);
3248     }
3249
3250     if ( is_classifier )
3251     {
3252         for( int i = 0; i < sample_count; i++ )
3253         {
3254             CvMat sample, miss;
3255             int si = sidx ? sidx[i] : i;
3256             cvGetRow( values, &sample, si ); 
3257             if( missing ) 
3258                 cvGetRow( missing, &miss, si );             
3259             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3260             if( pred_resp )
3261                 pred_resp[i] = r;
3262             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
3263             err += d;
3264         }
3265         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
3266     }
3267     else
3268     {
3269         for( int i = 0; i < sample_count; i++ )
3270         {
3271             CvMat sample, miss;
3272             int si = sidx ? sidx[i] : i;
3273             cvGetRow( values, &sample, si ); 
3274             if( missing ) 
3275                 cvGetRow( missing, &miss, si );             
3276             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3277             if( pred_resp )
3278                 pred_resp[i] = r;
3279             float d = r - response->data.fl[si*r_step];
3280             err += d*d;
3281         }
3282         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
3283     }
3284     return err;
3285 }
3286
3287 void CvDTree::prune_cv()
3288 {
3289     CvMat* ab = 0;
3290     CvMat* temp = 0;
3291     CvMat* err_jk = 0;
3292
3293     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3294     // 2. choose the best tree index (if need, apply 1SE rule).
3295     // 3. store the best index and cut the branches.
3296
3297     CV_FUNCNAME( "CvDTree::prune_cv" );
3298
3299     __BEGIN__;
3300
3301     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3302     // currently, 1SE for regression is not implemented
3303     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3304     double* err;
3305     double min_err = 0, min_err_se = 0;
3306     int min_idx = -1;
3307
3308     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3309
3310     // build the main tree sequence, calculate alpha's
3311     for(;;tree_count++)
3312     {
3313         double min_alpha = update_tree_rnc(tree_count, -1);
3314         if( cut_tree(tree_count, -1, min_alpha) )
3315             break;
3316
3317         if( ab->cols <= tree_count )
3318         {
3319             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3320             for( ti = 0; ti < ab->cols; ti++ )
3321                 temp->data.db[ti] = ab->data.db[ti];
3322             cvReleaseMat( &ab );
3323             ab = temp;
3324             temp = 0;
3325         }
3326
3327         ab->data.db[tree_count] = min_alpha;
3328     }
3329
3330     ab->data.db[0] = 0.;
3331
3332     if( tree_count > 0 )
3333     {
3334         for( ti = 1; ti < tree_count-1; ti++ )
3335             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3336         ab->data.db[tree_count-1] = DBL_MAX*0.5;
3337
3338         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3339         err = err_jk->data.db;
3340
3341         for( j = 0; j < cv_n; j++ )
3342         {
3343             int tj = 0, tk = 0;
3344             for( ; tk < tree_count; tj++ )
3345             {
3346                 double min_alpha = update_tree_rnc(tj, j);
3347                 if( cut_tree(tj, j, min_alpha) )
3348                     min_alpha = DBL_MAX;
3349
3350                 for( ; tk < tree_count; tk++ )
3351                 {
3352                     if( ab->data.db[tk] > min_alpha )
3353                         break;
3354                     err[j*tree_count + tk] = root->tree_error;
3355                 }
3356             }
3357         }
3358
3359         for( ti = 0; ti < tree_count; ti++ )
3360         {
3361             double sum_err = 0;
3362             for( j = 0; j < cv_n; j++ )
3363                 sum_err += err[j*tree_count + ti];
3364             if( ti == 0 || sum_err < min_err )
3365             {
3366                 min_err = sum_err;
3367                 min_idx = ti;
3368                 if( use_1se )
3369                     min_err_se = sqrt( sum_err*(n - sum_err) );
3370             }
3371             else if( sum_err < min_err + min_err_se )
3372                 min_idx = ti;
3373         }
3374     }
3375
3376     pruned_tree_idx = min_idx;
3377     free_prune_data(data->params.truncate_pruned_tree != 0);
3378
3379     __END__;
3380
3381     cvReleaseMat( &err_jk );
3382     cvReleaseMat( &ab );
3383     cvReleaseMat( &temp );
3384 }
3385
3386
3387 double CvDTree::update_tree_rnc( int T, int fold )
3388 {
3389     CvDTreeNode* node = root;
3390     double min_alpha = DBL_MAX;
3391
3392     for(;;)
3393     {
3394         CvDTreeNode* parent;
3395         for(;;)
3396         {
3397             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3398             if( t <= T || !node->left )
3399             {
3400                 node->complexity = 1;
3401                 node->tree_risk = node->node_risk;
3402                 node->tree_error = 0.;
3403                 if( fold >= 0 )
3404                 {
3405                     node->tree_risk = node->cv_node_risk[fold];
3406                     node->tree_error = node->cv_node_error[fold];
3407                 }
3408                 break;
3409             }
3410             node = node->left;
3411         }
3412
3413         for( parent = node->parent; parent && parent->right == node;
3414             node = parent, parent = parent->parent )
3415         {
3416             parent->complexity += node->complexity;
3417             parent->tree_risk += node->tree_risk;
3418             parent->tree_error += node->tree_error;
3419
3420             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3421                 - parent->tree_risk)/(parent->complexity - 1);
3422             min_alpha = MIN( min_alpha, parent->alpha );
3423         }
3424
3425         if( !parent )
3426             break;
3427
3428         parent->complexity = node->complexity;
3429         parent->tree_risk = node->tree_risk;
3430         parent->tree_error = node->tree_error;
3431         node = parent->right;
3432     }
3433
3434     return min_alpha;
3435 }
3436
3437
3438 int CvDTree::cut_tree( int T, int fold, double min_alpha )
3439 {
3440     CvDTreeNode* node = root;
3441     if( !node->left )
3442         return 1;
3443
3444     for(;;)
3445     {
3446         CvDTreeNode* parent;
3447         for(;;)
3448         {
3449             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3450             if( t <= T || !node->left )
3451                 break;
3452             if( node->alpha <= min_alpha + FLT_EPSILON )
3453             {
3454                 if( fold >= 0 )
3455                     node->cv_Tn[fold] = T;
3456                 else
3457                     node->Tn = T;
3458                 if( node == root )
3459                     return 1;
3460                 break;
3461             }
3462             node = node->left;
3463         }
3464
3465         for( parent = node->parent; parent && parent->right == node;
3466             node = parent, parent = parent->parent )
3467             ;
3468
3469         if( !parent )
3470             break;
3471
3472         node = parent->right;
3473     }
3474
3475     return 0;
3476 }
3477
3478
3479 void CvDTree::free_prune_data(bool cut_tree)
3480 {
3481     CvDTreeNode* node = root;
3482
3483     for(;;)
3484     {
3485         CvDTreeNode* parent;
3486         for(;;)
3487         {
3488             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3489             // as we will clear the whole cross-validation heap at the end
3490             node->cv_Tn = 0;
3491             node->cv_node_error = node->cv_node_risk = 0;
3492             if( !node->left )
3493                 break;
3494             node = node->left;
3495         }
3496
3497         for( parent = node->parent; parent && parent->right == node;
3498             node = parent, parent = parent->parent )
3499         {
3500             if( cut_tree && parent->Tn <= pruned_tree_idx )
3501             {
3502                 data->free_node( parent->left );
3503                 data->free_node( parent->right );
3504                 parent->left = parent->right = 0;
3505             }
3506         }
3507
3508         if( !parent )
3509             break;
3510
3511         node = parent->right;
3512     }
3513
3514     if( data->cv_heap )
3515         cvClearSet( data->cv_heap );
3516 }
3517
3518
3519 void CvDTree::free_tree()
3520 {
3521     if( root && data && data->shared )
3522     {
3523         pruned_tree_idx = INT_MIN;
3524         free_prune_data(true);
3525         data->free_node(root);
3526         root = 0;
3527     }
3528 }
3529
3530 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
3531     const CvMat* _missing, bool preprocessed_input ) const
3532 {
3533     CvDTreeNode* result = 0;
3534     int* catbuf = 0;
3535
3536     CV_FUNCNAME( "CvDTree::predict" );
3537
3538     __BEGIN__;
3539
3540     int i, step, mstep = 0;
3541     const float* sample;
3542     const uchar* m = 0;
3543     CvDTreeNode* node = root;
3544     const int* vtype;
3545     const int* vidx;
3546     const int* cmap;
3547     const int* cofs;
3548
3549     if( !node )
3550         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
3551
3552     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
3553         (_sample->cols != 1 && _sample->rows != 1) ||
3554         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
3555         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
3556             CV_ERROR( CV_StsBadArg,
3557         "the input sample must be 1d floating-point vector with the same "
3558         "number of elements as the total number of variables used for training" );
3559
3560     sample = _sample->data.fl;
3561     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
3562
3563     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
3564     {
3565         int n = data->cat_count->cols;
3566         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
3567         for( i = 0; i < n; i++ )
3568             catbuf[i] = -1;
3569     }
3570
3571     if( _missing )
3572     {
3573         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
3574         !CV_ARE_SIZES_EQ(_missing, _sample) )
3575             CV_ERROR( CV_StsBadArg,
3576         "the missing data mask must be 8-bit vector of the same size as input sample" );
3577         m = _missing->data.ptr;
3578         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
3579     }
3580
3581     vtype = data->var_type->data.i;
3582     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
3583     cmap = data->cat_map ? data->cat_map->data.i : 0;
3584     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
3585
3586     while( node->Tn > pruned_tree_idx && node->left )
3587     {
3588         CvDTreeSplit* split = node->split;
3589         int dir = 0;
3590         for( ; !dir && split != 0; split = split->next )
3591         {
3592             int vi = split->var_idx;
3593             int ci = vtype[vi];
3594             i = vidx ? vidx[vi] : vi;
3595             float val = sample[i*step];
3596             if( m && m[i*mstep] )
3597                 continue;
3598             if( ci < 0 ) // ordered
3599                 dir = val <= split->ord.c ? -1 : 1;
3600             else // categorical
3601             {
3602                 int c;
3603                 if( preprocessed_input )
3604                     c = cvRound(val);
3605                 else
3606                 {
3607                     c = catbuf[ci];
3608                     if( c < 0 )
3609                     {
3610                         int a = c = cofs[ci];
3611                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
3612                         
3613                         int ival = cvRound(val);
3614                         if( ival != val )
3615                             CV_ERROR( CV_StsBadArg,
3616                             "one of input categorical variable is not an integer" );
3617                         
3618                         int sh = 0;
3619                         while( a < b )
3620                         {
3621                             sh++;
3622                             c = (a + b) >> 1;
3623                             if( ival < cmap[c] )
3624                                 b = c;
3625                             else if( ival > cmap[c] )
3626                                 a = c+1;
3627                             else
3628                                 break;
3629                         }
3630
3631                         if( c < 0 || ival != cmap[c] )
3632                             continue;
3633
3634                         catbuf[ci] = c -= cofs[ci];
3635                     }
3636                 }
3637                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
3638                 dir = CV_DTREE_CAT_DIR(c, split->subset);
3639             }
3640
3641             if( split->inversed )
3642                 dir = -dir;
3643         }
3644
3645         if( !dir )
3646         {
3647             double diff = node->right->sample_count - node->left->sample_count;
3648             dir = diff < 0 ? -1 : 1;
3649         }
3650         node = dir < 0 ? node->left : node->right;
3651     }
3652
3653     result = node;
3654
3655     __END__;
3656
3657     return result;
3658 }
3659
3660
3661 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
3662 {
3663     CvMat sample = _sample, mmask = _missing;
3664     return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
3665 }
3666
3667
3668 const CvMat* CvDTree::get_var_importance()
3669 {
3670     if( !var_importance )
3671     {
3672         CvDTreeNode* node = root;
3673         double* importance;
3674         if( !node )
3675             return 0;
3676         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3677         cvZero( var_importance );
3678         importance = var_importance->data.db;
3679
3680         for(;;)
3681         {
3682             CvDTreeNode* parent;
3683             for( ;; node = node->left )
3684             {
3685                 CvDTreeSplit* split = node->split;
3686
3687                 if( !node->left || node->Tn <= pruned_tree_idx )
3688                     break;
3689
3690                 for( ; split != 0; split = split->next )
3691                     importance[split->var_idx] += split->quality;
3692             }
3693
3694             for( parent = node->parent; parent && parent->right == node;
3695                 node = parent, parent = parent->parent )
3696                 ;
3697
3698             if( !parent )
3699                 break;
3700
3701             node = parent->right;
3702         }
3703
3704         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3705     }
3706
3707     return var_importance;
3708 }
3709
3710
3711 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
3712 {
3713     int ci;
3714
3715     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3716     cvWriteInt( fs, "var", split->var_idx );
3717     cvWriteReal( fs, "quality", split->quality );
3718
3719     ci = data->get_var_type(split->var_idx);
3720     if( ci >= 0 ) // split on a categorical var
3721     {
3722         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3723         for( i = 0; i < n; i++ )
3724             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3725
3726         // ad-hoc rule when to use inverse categorical split notation
3727         // to achieve more compact and clear representation
3728         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3729
3730         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3731                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3732
3733         for( i = 0; i < n; i++ )
3734         {
3735             int dir = CV_DTREE_CAT_DIR(i,split->subset);
3736             if( dir*default_dir < 0 )
3737                 cvWriteInt( fs, 0, i );
3738         }
3739         cvEndWriteStruct( fs );
3740     }
3741     else
3742         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3743
3744     cvEndWriteStruct( fs );
3745 }
3746
3747
3748 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
3749 {
3750     CvDTreeSplit* split;
3751
3752     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3753
3754     cvWriteInt( fs, "depth", node->depth );
3755     cvWriteInt( fs, "sample_count", node->sample_count );
3756     cvWriteReal( fs, "value", node->value );
3757
3758     if( data->is_classifier )
3759         cvWriteInt( fs, "norm_class_idx", node->class_idx );
3760
3761     cvWriteInt( fs, "Tn", node->Tn );
3762     cvWriteInt( fs, "complexity", node->complexity );
3763     cvWriteReal( fs, "alpha", node->alpha );
3764     cvWriteReal( fs, "node_risk", node->node_risk );
3765     cvWriteReal( fs, "tree_risk", node->tree_risk );
3766     cvWriteReal( fs, "tree_error", node->tree_error );
3767
3768     if( node->left )
3769     {
3770         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3771
3772         for( split = node->split; split != 0; split = split->next )
3773             write_split( fs, split );
3774
3775         cvEndWriteStruct( fs );
3776     }
3777
3778     cvEndWriteStruct( fs );
3779 }
3780
3781
3782 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
3783 {
3784     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3785
3786     __BEGIN__;
3787
3788     CvDTreeNode* node = root;
3789
3790     // traverse the tree and save all the nodes in depth-first order
3791     for(;;)
3792     {
3793         CvDTreeNode* parent;
3794         for(;;)
3795         {
3796             write_node( fs, node );
3797             if( !node->left )
3798                 break;
3799             node = node->left;
3800         }
3801
3802         for( parent = node->parent; parent && parent->right == node;
3803             node = parent, parent = parent->parent )
3804             ;
3805
3806         if( !parent )
3807             break;
3808
3809         node = parent->right;
3810     }
3811
3812     __END__;
3813 }
3814
3815
3816 void CvDTree::write( CvFileStorage* fs, const char* name ) const
3817 {
3818     //CV_FUNCNAME( "CvDTree::write" );
3819
3820     __BEGIN__;
3821
3822     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3823
3824     //get_var_importance();
3825     data->write_params( fs );
3826     //if( var_importance )
3827     //cvWrite( fs, "var_importance", var_importance );
3828     write( fs );
3829
3830     cvEndWriteStruct( fs );
3831
3832     __END__;
3833 }
3834
3835
3836 void CvDTree::write( CvFileStorage* fs ) const
3837 {
3838     //CV_FUNCNAME( "CvDTree::write" );
3839
3840     __BEGIN__;
3841
3842     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3843
3844     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3845     write_tree_nodes( fs );
3846     cvEndWriteStruct( fs );
3847
3848     __END__;
3849 }
3850
3851
3852 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3853 {
3854     CvDTreeSplit* split = 0;
3855
3856     CV_FUNCNAME( "CvDTree::read_split" );
3857
3858     __BEGIN__;
3859
3860     int vi, ci;
3861
3862     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3863         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3864
3865     vi = cvReadIntByName( fs, fnode, "var", -1 );
3866     if( (unsigned)vi >= (unsigned)data->var_count )
3867         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3868
3869     ci = data->get_var_type(vi);
3870     if( ci >= 0 ) // split on categorical var
3871     {
3872         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3873         CvSeqReader reader;
3874         CvFileNode* inseq;
3875         split = data->new_split_cat( vi, 0 );
3876         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3877         if( !inseq )
3878         {
3879             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3880             inversed = 1;
3881         }
3882         if( !inseq ||
3883             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3884             CV_ERROR( CV_StsParseError,
3885             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3886
3887         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3888         {
3889             val = inseq->data.i;
3890             if( (unsigned)val >= (unsigned)n )
3891                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3892
3893             split->subset[val >> 5] |= 1 << (val & 31);
3894         }
3895         else
3896         {
3897             cvStartReadSeq( inseq->data.seq, &reader );
3898
3899             for( i = 0; i < reader.seq->total; i++ )
3900             {
3901                 CvFileNode* inode = (CvFileNode*)reader.ptr;
3902                 val = inode->data.i;
3903                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3904                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3905
3906                 split->subset[val >> 5] |= 1 << (val & 31);
3907                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3908             }
3909         }
3910
3911         // for categorical splits we do not use inversed splits,
3912         // instead we inverse the variable set in the split
3913         if( inversed )
3914             for( i = 0; i < (n + 31) >> 5; i++ )
3915                 split->subset[i] ^= -1;
3916     }
3917     else
3918     {
3919         CvFileNode* cmp_node;
3920         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3921
3922         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3923         if( !cmp_node )
3924         {
3925             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3926             split->inversed = 1;
3927         }
3928
3929         split->ord.c = (float)cvReadReal( cmp_node );
3930     }
3931
3932     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3933
3934     __END__;
3935
3936     return split;
3937 }
3938
3939
3940 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3941 {
3942     CvDTreeNode* node = 0;
3943
3944     CV_FUNCNAME( "CvDTree::read_node" );
3945
3946     __BEGIN__;
3947
3948     CvFileNode* splits;
3949     int i, depth;
3950
3951     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3952         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3953
3954     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3955     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3956     if( depth != node->depth )
3957         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3958
3959     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3960     node->value = cvReadRealByName( fs, fnode, "value" );
3961     if( data->is_classifier )
3962         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3963
3964     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3965     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3966     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3967     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3968     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3969     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3970
3971     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3972     if( splits )
3973     {
3974         CvSeqReader reader;
3975         CvDTreeSplit* last_split = 0;
3976
3977         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3978             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3979
3980         cvStartReadSeq( splits->data.seq, &reader );
3981         for( i = 0; i < reader.seq->total; i++ )
3982         {
3983             CvDTreeSplit* split;
3984             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3985             if( !last_split )
3986                 node->split = last_split = split;
3987             else
3988                 last_split = last_split->next = split;
3989
3990             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3991         }
3992     }
3993
3994     __END__;
3995
3996     return node;
3997 }
3998
3999
4000 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
4001 {
4002     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
4003
4004     __BEGIN__;
4005
4006     CvSeqReader reader;
4007     CvDTreeNode _root;
4008     CvDTreeNode* parent = &_root;
4009     int i;
4010     parent->left = parent->right = parent->parent = 0;
4011
4012     cvStartReadSeq( fnode->data.seq, &reader );
4013
4014     for( i = 0; i < reader.seq->total; i++ )
4015     {
4016         CvDTreeNode* node;
4017
4018         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
4019         if( !parent->left )
4020             parent->left = node;
4021         else
4022             parent->right = node;
4023         if( node->split )
4024             parent = node;
4025         else
4026         {
4027             while( parent && parent->right )
4028                 parent = parent->parent;
4029         }
4030
4031         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4032     }
4033
4034     root = _root.left;
4035
4036     __END__;
4037 }
4038
4039
4040 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
4041 {
4042     CvDTreeTrainData* _data = new CvDTreeTrainData();
4043     _data->read_params( fs, fnode );
4044
4045     read( fs, fnode, _data );
4046     get_var_importance();
4047 }
4048
4049
4050 // a special entry point for reading weak decision trees from the tree ensembles
4051 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
4052 {
4053     CV_FUNCNAME( "CvDTree::read" );
4054
4055     __BEGIN__;
4056
4057     CvFileNode* tree_nodes;
4058
4059     clear();
4060     data = _data;
4061
4062     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
4063     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
4064         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
4065
4066     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
4067     read_tree_nodes( fs, tree_nodes );
4068
4069     __END__;
4070 }
4071
4072 /* End of file. */