1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
43 static const float ord_var_epsilon = FLT_EPSILON*2;
44 static const float ord_nan = FLT_MAX*0.5f;
45 static const int min_block_size = 1 << 16;
46 static const int block_size_delta = 1 << 10;
48 CvDTreeTrainData::CvDTreeTrainData()
50 var_idx = var_type = cat_count = cat_ofs = cat_map =
51 priors = counts = buf = direction = split_buf = 0;
52 tree_storage = temp_storage = 0;
58 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
59 const CvMat* _responses, const CvMat* _var_idx,
60 const CvMat* _sample_idx, const CvMat* _var_type,
61 const CvMat* _missing_mask,
62 CvDTreeParams _params, bool _shared )
64 var_idx = var_type = cat_count = cat_ofs = cat_map =
65 priors = counts = buf = direction = split_buf = 0;
66 tree_storage = temp_storage = 0;
68 set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
69 _var_type, _missing_mask, _params, _shared );
73 CvDTreeTrainData::~CvDTreeTrainData()
79 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
83 CV_FUNCNAME( "CvDTreeTrainData::set_params" );
90 if( params.max_categories < 2 )
91 CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
92 params.max_categories = MIN( params.max_categories, 15 );
94 if( params.max_depth < 0 )
95 CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
96 params.max_depth = MIN( params.max_depth, 25 );
98 params.min_sample_count = MAX(params.min_sample_count,1);
100 if( params.cv_folds < 0 )
101 CV_ERROR( CV_StsOutOfRange,
102 "params.cv_folds should be =0 (the tree is not pruned) "
103 "or n>0 (tree is pruned using n-fold cross-validation)" );
105 if( params.cv_folds == 1 )
108 if( params.regression_accuracy < 0 )
109 CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
119 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
120 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
121 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
123 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
124 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
126 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
127 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
128 const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
131 CvMat* sample_idx = 0;
132 CvMat* var_type0 = 0;
136 CV_FUNCNAME( "CvDTreeTrainData::set_data" );
140 int sample_all = 0, r_type = 0, cv_n;
141 int total_c_count = 0;
142 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
146 const int *sidx = 0, *vidx = 0;
153 CV_CALL( set_params( _params ));
155 // check parameter types and sizes
156 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
157 if( _tflag == CV_ROW_SAMPLE )
159 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
162 ms_step = _missing_mask->step, mv_step = 1;
166 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
169 mv_step = _missing_mask->step, ms_step = 1;
172 sample_count = sample_all;
177 CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
178 sidx = sample_idx->data.i;
179 sample_count = sample_idx->rows + sample_idx->cols - 1;
184 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
185 vidx = var_idx->data.i;
186 var_count = var_idx->rows + var_idx->cols - 1;
189 if( !CV_IS_MAT(_responses) ||
190 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
191 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
192 _responses->rows != 1 && _responses->cols != 1 ||
193 _responses->rows + _responses->cols - 1 != sample_all )
194 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
195 "floating-point vector containing as many elements as "
196 "the total number of samples in the training data matrix" );
198 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
199 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
204 is_classifier = r_type == CV_VAR_CATEGORICAL;
206 // step 0. calc the number of categorical vars
207 for( vi = 0; vi < var_count; vi++ )
209 var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
210 cat_var_count++ : ord_var_count--;
213 ord_var_count = ~ord_var_count;
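// NOTE (illustrative summary, not in the original file): after the loop above,
// var_type->data.i[vi] holds, for a categorical variable, its index among the
// categorical variables (>= 0) and, for an ordered variable, the bitwise
// complement of its index among the ordered variables (a negative value);
// get_var_type(), get_cat_var_data() and get_ord_var_data() rely on this encoding.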
214 cv_n = params.cv_folds;
215 // set the last two elements of the var_type array to be able
216 // to locate responses and cross-validation labels using
217 // the corresponding get_* functions.
218 var_type->data.i[var_count] = cat_var_count;
219 var_type->data.i[var_count+1] = cat_var_count+1;
221 // in case of a single ordered predictor we need dummy cv_labels
222 // for safe split_node_data() operation
223 have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
225 buf_size = (ord_var_count*2 + cat_var_count + 1 +
226 (have_cv_labels ? 1 : 0))*sample_count + 2;
228 buf_count = shared ? 3 : 2;
229 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
230 CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
231 CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
232 CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
234 // now calculate the maximum size of a split,
235 // create the memory storage that will keep the nodes and splits of the decision tree,
236 // and allocate the root node and the buffer for the whole training data
237 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
238 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
239 tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
240 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
241 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
242 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
244 temp_block_size = nv_size = var_count*sizeof(int);
247 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
248 CV_ERROR( CV_StsOutOfRange,
249 "The many folds in cross-validation for such a small dataset" );
251 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
252 temp_block_size = MAX(temp_block_size, cv_size);
255 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
256 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
257 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
259 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
261 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
262 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
266 // transform the training data to a convenient representation
267 for( vi = 0; vi <= var_count; vi++ )
270 const uchar* mask = 0;
271 int m_step = 0, step;
272 const int* idata = 0;
273 const float* fdata = 0;
276 if( vi < var_count ) // analyze i-th input variable
278 int vi0 = vidx ? vidx[vi] : vi;
279 ci = get_var_type(vi);
280 step = ds_step; m_step = ms_step;
281 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
282 idata = _train_data->data.i + vi0*dv_step;
284 fdata = _train_data->data.fl + vi0*dv_step;
286 mask = _missing_mask->data.ptr + vi0*mv_step;
288 else // analyze _responses
291 step = CV_IS_MAT_CONT(_responses->type) ?
292 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
293 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
294 idata = _responses->data.i;
296 fdata = _responses->data.fl;
299 if( vi < var_count && ci >= 0 ||
300 vi == var_count && is_classifier ) // process categorical variable or response
302 int c_count, prev_label, prev_i;
303 int* c_map, *dst = get_cat_var_data( data_root, vi );
306 for( i = 0; i < sample_count; i++ )
308 int val = INT_MAX, si = sidx ? sidx[i] : i;
309 if( !mask || !mask[si*m_step] )
312 val = idata[si*step];
315 float t = fdata[si*step];
319 sprintf( err, "%d-th value of %d-th (categorical) "
320 "variable is not an integer", i, vi );
321 CV_ERROR( CV_StsBadArg, err );
327 sprintf( err, "%d-th value of %d-th (categorical) "
328 "variable is too large", i, vi );
329 CV_ERROR( CV_StsBadArg, err );
334 int_ptr[i] = dst + i;
337 // sort all the values, including the missing measurements
338 // that should all move to the end
339 icvSortIntPtr( int_ptr, sample_count, 0 );
340 //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
342 c_count = num_valid > 0;
344 // count the categories
345 for( i = 1; i < num_valid; i++ )
346 c_count += *int_ptr[i] != *int_ptr[i-1];
349 max_c_count = MAX( max_c_count, c_count );
350 cat_count->data.i[ci] = c_count;
351 cat_ofs->data.i[ci] = total_c_count;
353 // resize cat_map, if needed
354 if( cat_map->cols < total_c_count + c_count )
357 CV_CALL( cat_map = cvCreateMat( 1,
358 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
359 for( i = 0; i < total_c_count; i++ )
360 cat_map->data.i[i] = tmp_map->data.i[i];
361 cvReleaseMat( &tmp_map );
364 c_map = cat_map->data.i + total_c_count;
365 total_c_count += c_count;
367 // compact the class indices and build the map
368 prev_label = ~*int_ptr[0];
371 for( i = 0, prev_i = -1; i < num_valid; i++ )
373 int cur_label = *int_ptr[i];
374 if( cur_label != prev_label )
376 c_map[++c_count] = prev_label = cur_label;
379 *int_ptr[i] = c_count;
382 // replace labels for missing values with -1
383 for( ; i < sample_count; i++ )
386 else if( ci < 0 ) // process ordered variable
388 CvPair32s32f* dst = get_ord_var_data( data_root, vi );
390 for( i = 0; i < sample_count; i++ )
393 int si = sidx ? sidx[i] : i;
394 if( !mask || !mask[si*m_step] )
397 val = (float)idata[si*step];
399 val = fdata[si*step];
401 if( fabs(val) >= ord_nan )
403 sprintf( err, "%d-th value of %d-th (ordered) "
404 "variable (=%g) is too large", i, vi, val );
405 CV_ERROR( CV_StsBadArg, err );
413 icvSortPairs( dst, sample_count, 0 );
415 else // special case: process ordered response,
416 // it will be stored similarly to categorical vars (i.e. as a plain array, not as (index,value) pairs)
418 float* dst = get_ord_responses( data_root );
420 for( i = 0; i < sample_count; i++ )
423 int si = sidx ? sidx[i] : i;
425 val = (float)idata[si*step];
427 val = fdata[si*step];
429 if( fabs(val) >= ord_nan )
431 sprintf( err, "%d-th value of %d-th (ordered) "
432 "variable (=%g) is out of range", i, vi, val );
433 CV_ERROR( CV_StsBadArg, err );
438 cat_count->data.i[cat_var_count] = 0;
439 cat_ofs->data.i[cat_var_count] = total_c_count;
440 num_valid = sample_count;
444 data_root->num_valid[vi] = num_valid;
449 int* dst = get_cv_labels(data_root);
452 for( i = vi = 0; i < sample_count; i++ )
455 vi &= vi < cv_n ? -1 : 0;
458 for( i = 0; i < sample_count; i++ )
460 int a = cvRandInt(r) % sample_count;
461 int b = cvRandInt(r) % sample_count;
462 CV_SWAP( dst[a], dst[b], vi );
466 cat_map->cols = MAX( total_c_count, 1 );
468 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
469 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
470 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
472 have_priors = is_classifier && params.priors;
475 int m = get_num_classes(), rows = 4;
477 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
478 for( i = 0; i < m; i++ )
480 double val = have_priors ? params.priors[i] : 1.;
482 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
483 priors->data.db[i] = val;
488 cvScale( priors, priors, 1./sum );
490 if( cat_var_count > 0 || params.cv_folds > 0 )
492 // need storage for cjk (see find_split_cat_gini) and risks/errors
493 rows += MAX( max_c_count, params.cv_folds ) + 1;
494 // add buffer for k-means clustering
495 if( m > 2 && max_c_count > params.max_categories )
496 rows += params.max_categories + (max_c_count+m-1)/m;
499 CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
502 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
503 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
508 cvReleaseMat( &sample_idx );
509 cvReleaseMat( &var_type0 );
510 cvReleaseMat( &tmp_map );
514 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
516 CvDTreeNode* root = 0;
517 CvMat* isubsample_idx = 0;
518 CvMat* subsample_co = 0;
520 CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
525 CV_ERROR( CV_StsError, "No training data has been set" );
528 CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
530 if( !isubsample_idx )
532 // make a copy of the root node
535 root = new_node( 0, 1, 0, 0 );
538 root->num_valid = temp.num_valid;
539 for( i = 0; i < var_count; i++ )
540 root->num_valid[i] = data_root->num_valid[i];
541 root->cv_Tn = temp.cv_Tn;
542 root->cv_node_risk = temp.cv_node_risk;
543 root->cv_node_error = temp.cv_node_error;
547 int* sidx = isubsample_idx->data.i;
548 // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
549 int* co, cur_ofs = 0;
550 int vi, i, total = data_root->sample_count;
551 int count = isubsample_idx->rows + isubsample_idx->cols - 1;
552 root = new_node( 0, count, 1, 0 );
554 CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
555 cvZero( subsample_co );
556 co = subsample_co->data.i;
557 for( i = 0; i < count; i++ )
559 for( i = 0; i < total; i++ )
570 for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
572 int ci = get_var_type(vi);
574 if( ci >= 0 || vi >= var_count )
576 const int* src = get_cat_var_data( data_root, vi );
577 int* dst = get_cat_var_data( root, vi );
580 for( i = 0; i < count; i++ )
582 int val = src[sidx[i]];
584 num_valid += val >= 0;
588 root->num_valid[vi] = num_valid;
592 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
593 CvPair32s32f* dst = get_ord_var_data( root, vi );
594 int j = 0, num_valid = data_root->num_valid[vi], idx, count_i;
596 for( i = 0; i < num_valid; i++ )
602 float val = src[i].val;
603 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
611 root->num_valid[vi] = j;
613 for( ; i < total; i++ )
619 float val = src[i].val;
620 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
633 cvReleaseMat( &isubsample_idx );
634 cvReleaseMat( &subsample_co );
640 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
641 float* values, uchar* missing,
642 float* responses, bool get_class_idx )
644 CvMat* subsample_idx = 0;
645 CvMat* subsample_co = 0;
647 CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
651 int i, vi, total = sample_count, count = total, cur_ofs = 0;
657 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
658 sidx = subsample_idx->data.i;
659 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
660 co = subsample_co->data.i;
661 cvZero( subsample_co );
662 count = subsample_idx->cols + subsample_idx->rows - 1;
663 for( i = 0; i < count; i++ )
665 for( i = 0; i < total; i++ )
667 int count_i = co[i*2];
670 co[i*2+1] = cur_ofs*var_count;
676 memset( missing, 1, count*var_count );
678 for( vi = 0; vi < var_count; vi++ )
680 int ci = get_var_type(vi);
681 if( ci >= 0 ) // categorical
683 float* dst = values + vi;
684 uchar* m = missing + vi;
685 const int* src = get_cat_var_data(data_root, vi);
687 for( i = 0; i < count; i++, dst += var_count, m += var_count )
689 int idx = sidx ? sidx[i] : i;
697 float* dst = values + vi;
698 uchar* m = missing + vi;
699 const CvPair32s32f* src = get_ord_var_data(data_root, vi);
700 int count1 = data_root->num_valid[vi];
702 for( i = 0; i < count1; i++ )
709 cur_ofs = co[idx*2+1];
712 cur_ofs = idx*var_count;
715 float val = src[i].val;
716 for( ; count_i > 0; count_i--, cur_ofs += var_count )
729 const int* src = get_class_labels(data_root);
730 for( i = 0; i < count; i++ )
732 int idx = sidx ? sidx[i] : i;
733 int val = get_class_idx ? src[idx] :
734 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
735 responses[i] = (float)val;
740 const float* src = get_ord_responses(data_root);
741 for( i = 0; i < count; i++ )
743 int idx = sidx ? sidx[i] : i;
744 responses[i] = src[idx];
750 cvReleaseMat( &subsample_idx );
751 cvReleaseMat( &subsample_co );
755 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
756 int storage_idx, int offset )
758 CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
760 node->sample_count = count;
761 node->depth = parent ? parent->depth + 1 : 0;
762 node->parent = parent;
763 node->left = node->right = 0;
766 node->class_idx = -1;
769 node->buf_idx = storage_idx;
770 node->offset = offset;
772 node->num_valid = (int*)cvSetNew( nv_heap );
775 node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
777 node->complexity = 0;
779 if( params.cv_folds > 0 && cv_heap )
781 int cv_n = params.cv_folds;
782 node->cv_Tn = (int*)cvSetNew( cv_heap );
783 node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
784 node->cv_node_error = node->cv_node_risk + cv_n;
789 node->cv_node_risk = 0;
790 node->cv_node_error = 0;
797 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
798 int split_point, int inversed, float quality )
800 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
802 split->ord.c = cmp_val;
803 split->ord.split_point = split_point;
804 split->inversed = inversed;
805 split->quality = quality;
812 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
814 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
815 int i, n = (max_c_count + 31)/32;
819 split->quality = quality;
820 for( i = 0; i < n; i++ )
821 split->subset[i] = 0;
828 void CvDTreeTrainData::free_node( CvDTreeNode* node )
830 CvDTreeSplit* split = node->split;
831 free_node_data( node );
834 CvDTreeSplit* next = split->next;
835 cvSetRemoveByPtr( split_heap, split );
839 cvSetRemoveByPtr( node_heap, node );
843 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
845 if( node->num_valid )
847 cvSetRemoveByPtr( nv_heap, node->num_valid );
850 // do not free cv_* fields, as all the cross-validation related data is released at once.
854 void CvDTreeTrainData::free_train_data()
856 cvReleaseMat( &counts );
857 cvReleaseMat( &buf );
858 cvReleaseMat( &direction );
859 cvReleaseMat( &split_buf );
860 cvReleaseMemStorage( &temp_storage );
861 cv_heap = nv_heap = 0;
865 void CvDTreeTrainData::clear()
869 cvReleaseMemStorage( &tree_storage );
871 cvReleaseMat( &var_idx );
872 cvReleaseMat( &var_type );
873 cvReleaseMat( &cat_count );
874 cvReleaseMat( &cat_ofs );
875 cvReleaseMat( &cat_map );
876 cvReleaseMat( &priors );
878 node_heap = split_heap = 0;
880 sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
881 have_cv_labels = have_priors = is_classifier = false;
883 buf_count = buf_size = 0;
892 int CvDTreeTrainData::get_num_classes() const
894 return is_classifier ? cat_count->data.i[cat_var_count] : 0;
898 int CvDTreeTrainData::get_var_type(int vi) const
900 return var_type->data.i[vi];
904 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
906 int oi = ~get_var_type(vi);
907 assert( 0 <= oi && oi < ord_var_count );
908 return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
909 n->offset + oi*n->sample_count*2);
913 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
915 return get_cat_var_data( n, var_count );
919 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
921 return (float*)get_cat_var_data( n, var_count );
925 int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
927 return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
931 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
933 int ci = get_var_type(vi);
934 assert( 0 <= ci && ci <= cat_var_count + 1 );
935 return buf->data.i + n->buf_idx*buf->cols + n->offset +
936 (ord_var_count*2 + ci)*n->sample_count;
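// Layout of a node's slice of 'buf' implied by the two accessors above
// (illustrative summary, not in the original file): starting at
// buf->data.i + buf_idx*buf->cols + offset, the node stores ord_var_count blocks
// of sample_count (index,value) pairs (2 ints per sample), followed by
// cat_var_count blocks of sample_count ints for the categorical variables, one
// block for the responses/class labels and, if have_cv_labels, one more block
// for the cross-validation labels.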
940 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
942 int idx = n->buf_idx + 1;
943 if( idx >= buf_count )
944 idx = shared ? 1 : 0;
948 /////////////////////// Decision Tree /////////////////////////
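// Illustrative usage sketch (not part of the original file). It assumes the caller
// has already built CvMat's train_data (CV_32FC1, one sample per row), responses,
// var_type, missing_mask, sample and sample_missing; parameter values are arbitrary:
//
//   CvDTreeParams params;
//   params.max_depth = 8;            // see set_params(): clamped to 25
//   params.min_sample_count = 10;
//   params.cv_folds = 10;            // > 1 enables cost-complexity pruning
//
//   CvDTree tree;
//   if( tree.train( train_data, CV_ROW_SAMPLE, responses,
//                   0, 0, var_type, missing_mask, params ))
//   {
//       CvDTreeNode* leaf = tree.predict( sample, sample_missing, false );
//       double prediction = leaf->value;   // class label or regression value
//   }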
958 void CvDTree::clear()
960 cvReleaseMat( &var_importance );
970 pruned_tree_idx = -1;
980 bool CvDTree::train( const CvMat* _train_data, int _tflag,
981 const CvMat* _responses, const CvMat* _var_idx,
982 const CvMat* _sample_idx, const CvMat* _var_type,
983 const CvMat* _missing_mask, CvDTreeParams _params )
987 CV_FUNCNAME( "CvDTree::train" );
992 data = new CvDTreeTrainData( _train_data, _tflag, _responses,
993 _var_idx, _sample_idx, _var_type,
994 _missing_mask, _params, false );
995 CV_CALL( result = do_train(0));
1003 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1005 bool result = false;
1007 CV_FUNCNAME( "CvDTree::train" );
1013 data->shared = true;
1014 CV_CALL( result = do_train(_subsample_idx));
1022 bool CvDTree::do_train( const CvMat* _subsample_idx )
1024 bool result = false;
1026 CV_FUNCNAME( "CvDTree::train" );
1030 root = data->subsample_data( _subsample_idx );
1032 try_split_node(root);
1034 if( data->params.cv_folds > 0 )
1038 data->free_train_data();
1048 #define DTREE_CAT_DIR(idx,subset) \
1049 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
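// DTREE_CAT_DIR(idx,subset) evaluates to -1 when bit 'idx' is set in the subset
// bit mask (the category belongs to the split's subset, so the sample is sent to
// the left branch) and to +1 otherwise (the sample is sent to the right branch).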
1051 void CvDTree::try_split_node( CvDTreeNode* node )
1053 CvDTreeSplit* best_split = 0;
1054 int i, n = node->sample_count, vi;
1055 bool can_split = true;
1056 double quality_scale;
1058 calc_node_value( node );
1060 if( node->sample_count <= data->params.min_sample_count ||
1061 node->depth >= data->params.max_depth )
1064 if( can_split && data->is_classifier )
1066 // check if we have a "pure" node;
1067 // we assume that cls_count is filled by calc_node_value()
1068 int* cls_count = data->counts->data.i;
1069 int nz = 0, m = data->get_num_classes();
1070 for( i = 0; i < m; i++ )
1071 nz += cls_count[i] != 0;
1072 if( nz == 1 ) // there is only one class
1075 else if( can_split )
1077 const float* responses = data->get_ord_responses( node );
1078 float diff = responses[n-1] - responses[0];
1079 if( diff < data->params.regression_accuracy )
1085 best_split = find_best_split(node);
1086 // TODO: check the split quality ...
1087 node->split = best_split;
1090 if( !can_split || !best_split )
1092 data->free_node_data(node);
1096 quality_scale = calc_node_dir( node );
1098 if( data->params.use_surrogates )
1100 // find all the surrogate splits
1101 // and sort them by their similarity to the primary one
1102 for( vi = 0; vi < data->var_count; vi++ )
1104 CvDTreeSplit* split;
1105 int ci = data->get_var_type(vi);
1107 if( vi == best_split->var_idx )
1111 split = find_surrogate_split_cat( node, vi );
1113 split = find_surrogate_split_ord( node, vi );
1118 CvDTreeSplit* prev_split = node->split;
1119 split->quality = (float)(split->quality*quality_scale);
1121 while( prev_split->next &&
1122 prev_split->next->quality > split->quality )
1123 prev_split = prev_split->next;
1124 split->next = prev_split->next;
1125 prev_split->next = split;
1130 split_node_data( node );
1131 try_split_node( node->left );
1132 try_split_node( node->right );
1136 // calculate direction (left(-1),right(1),missing(0))
1137 // for each sample using the best split
1138 // the function returns scale coefficients for surrogate split quality factors.
1139 // the scale is applied to normalize surrogate split quality relatively to the
1140 // best (primary) split quality. That is, if a surrogate split is absolutely
1141 // identical to the primary split, its quality will be set to the maximum value =
1142 // quality of the primary split; otherwise, it will be lower.
1143 // besides, the function computes node->maxlr,
1144 // the minimum quality (w/o considering the above mentioned scale)
1145 // that a surrogate split must exceed; surrogate splits with quality
1146 // less than or equal to node->maxlr are discarded.
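// Example: if the primary split has quality q and directs weight L to the left
// and R to the right, the returned scale is q/(L+R); a surrogate split that
// agrees with the primary one on all samples has agreement measure L+R and
// therefore gets exactly the primary quality q after scaling.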
1147 double CvDTree::calc_node_dir( CvDTreeNode* node )
1149 char* dir = (char*)data->direction->data.ptr;
1150 int i, n = node->sample_count, vi = node->split->var_idx;
1153 assert( !node->split->inversed );
1155 if( data->get_var_type(vi) >= 0 ) // split on categorical var
1157 const int* labels = data->get_cat_var_data(node,vi);
1158 const int* subset = node->split->subset;
1160 if( !data->have_priors )
1162 int sum = 0, sum_abs = 0;
1164 for( i = 0; i < n; i++ )
1166 int idx = labels[i];
1167 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1168 sum += d; sum_abs += d & 1;
1172 R = (sum_abs + sum) >> 1;
1173 L = (sum_abs - sum) >> 1;
1177 const int* responses = data->get_class_labels(node);
1178 const double* priors = data->priors->data.db;
1179 double sum = 0, sum_abs = 0;
1181 for( i = 0; i < n; i++ )
1183 int idx = labels[i];
1184 double w = priors[responses[i]];
1185 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1186 sum += d*w; sum_abs += (d & 1)*w;
1190 R = (sum_abs + sum) * 0.5;
1191 L = (sum_abs - sum) * 0.5;
1194 else // split on ordered var
1196 const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1197 int split_point = node->split->ord.split_point;
1198 int n1 = node->num_valid[vi];
1200 assert( 0 <= split_point && split_point < n1-1 );
1202 if( !data->have_priors )
1204 for( i = 0; i <= split_point; i++ )
1205 dir[sorted[i].i] = (char)-1;
1206 for( ; i < n1; i++ )
1207 dir[sorted[i].i] = (char)1;
1209 dir[sorted[i].i] = (char)0;
1212 R = n1 - split_point + 1;
1216 const int* responses = data->get_class_labels(node);
1217 const double* priors = data->priors->data.db;
1220 for( i = 0; i <= split_point; i++ )
1222 int idx = sorted[i].i;
1223 double w = priors[responses[idx]];
1224 dir[idx] = (char)-1;
1228 for( ; i < n1; i++ )
1230 int idx = sorted[i].i;
1231 double w = priors[responses[idx]];
1237 dir[sorted[i].i] = (char)0;
1241 node->maxlr = MAX( L, R );
1242 return node->split->quality/(L + R);
1246 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1249 CvDTreeSplit *best_split = 0, *split = 0, *t;
1251 for( vi = 0; vi < data->var_count; vi++ )
1253 int ci = data->get_var_type(vi);
1254 if( node->num_valid[vi] <= 1 )
1257 if( data->is_classifier )
1260 split = find_split_cat_gini( node, vi );
1262 split = find_split_ord_gini( node, vi );
1267 split = find_split_cat_reg( node, vi );
1269 split = find_split_ord_reg( node, vi );
1274 if( !best_split || best_split->quality < split->quality )
1275 CV_SWAP( best_split, split, t );
1277 cvSetRemoveByPtr( data->split_heap, split );
1285 CvDTreeSplit* CvDTree::find_split_ord_gini( CvDTreeNode* node, int vi )
1287 const float epsilon = FLT_EPSILON*2;
1288 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1289 const int* responses = data->get_class_labels(node);
1290 int n = node->sample_count;
1291 int n1 = node->num_valid[vi];
1292 int m = data->get_num_classes();
1293 const int* rc0 = data->counts->data.i;
1294 int* lc = (int*)(rc0 + m);
1297 double lsum2 = 0, rsum2 = 0, best_val = 0;
1298 const double* priors = data->have_priors ? data->priors->data.db : 0;
1300 // init arrays of class instance counters on both sides of the split
1301 for( i = 0; i < m; i++ )
1307 // compensate for missing values
1308 for( i = n1; i < n; i++ )
1309 rc[responses[sorted[i].i]]--;
1315 for( i = 0; i < m; i++ )
1316 rsum2 += (double)rc[i]*rc[i];
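// Why maximizing lsum2/L + rsum2/R minimizes the Gini impurity (illustrative
// derivation): the weighted impurity of a candidate split is
//   L*(1 - sum_k (l_k/L)^2) + R*(1 - sum_k (r_k/R)^2)
//     = (L + R) - (sum_k l_k^2)/L - (sum_k r_k^2)/R,
// and L + R is fixed for the node, so minimizing the impurity is equivalent to
// maximizing the 'val' computed in the loop below (the have_priors branch does
// the same with weighted counts).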
1318 for( i = 0; i < n1 - 1; i++ )
1320 int idx = responses[sorted[i].i];
1323 lv = lc[idx]; rv = rc[idx];
1326 lc[idx] = lv + 1; rc[idx] = rv - 1;
1328 if( sorted[i].val + epsilon < sorted[i+1].val )
1330 double val = lsum2/L + rsum2/R;
1331 if( best_val < val )
1341 double L = 0, R = 0;
1342 for( i = 0; i < m; i++ )
1344 double wv = rc[i]*priors[i];
1349 for( i = 0; i < n1 - 1; i++ )
1351 int idx = responses[sorted[i].i];
1353 double p = priors[idx], p2 = p*p;
1355 lv = lc[idx]; rv = rc[idx];
1356 lsum2 += p2*(lv*2 + 1);
1357 rsum2 -= p2*(rv*2 - 1);
1358 lc[idx] = lv + 1; rc[idx] = rv - 1;
1360 if( sorted[i].val + epsilon < sorted[i+1].val )
1362 double val = lsum2/L + rsum2/R;
1363 if( best_val < val )
1372 return best_i >= 0 ? data->new_split_ord( vi,
1373 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1374 0, (float)best_val ) : 0;
1378 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1379 int* csums, int k, int* labels )
1381 // TODO: consider adding priors (class weights) to the algorithm
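// (illustrative summary, not in the original file) This routine groups the n
// original categories into k clusters so that find_split_cat_gini() can search
// subsets of at most params.max_categories clusters instead of 2^n subsets.
// It runs a k-means-style loop on the rows of 'vectors' (per-category class
// histograms), comparing the L1-normalized distributions v[j]*v_weights[i] and
// s[j]*c_weights[idx] computed below.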
1382 int iters = 0, max_iters = 100;
1384 double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1385 double *v_weights = buf, *c_weights = buf + k;
1386 bool modified = true;
1387 CvRNG* r = &data->rng;
1389 // assign labels randomly
1390 for( i = idx = 0; i < n; i++ )
1393 const int* v = vectors + i*m;
1395 idx &= idx < k ? -1 : 0;
1397 // compute weight of each vector
1398 for( j = 0; j < m; j++ )
1400 v_weights[i] = sum ? 1./sum : 0.;
1403 for( i = 0; i < n; i++ )
1405 int i1 = cvRandInt(r) % n;
1406 int i2 = cvRandInt(r) % n;
1407 CV_SWAP( labels[i1], labels[i2], j );
1410 for( iters = 0; iters <= max_iters; iters++ )
1413 for( i = 0; i < k; i++ )
1415 for( j = 0; j < m; j++ )
1419 for( i = 0; i < n; i++ )
1421 const int* v = vectors + i*m;
1422 int* s = csums + labels[i]*m;
1423 for( j = 0; j < m; j++ )
1427 // exit the loop here, when we have up-to-date csums
1428 if( iters == max_iters || !modified )
1433 // calculate weight of each cluster
1434 for( i = 0; i < k; i++ )
1436 const int* s = csums + i*m;
1438 for( j = 0; j < m; j++ )
1440 c_weights[i] = sum ? 1./sum : 0;
1443 // now for each vector determine the closest cluster
1444 for( i = 0; i < n; i++ )
1446 const int* v = vectors + i*m;
1447 double alpha = v_weights[i];
1448 double min_dist2 = DBL_MAX;
1451 for( idx = 0; idx < k; idx++ )
1453 const int* s = csums + idx*m;
1454 double dist2 = 0., beta = c_weights[idx];
1455 for( j = 0; j < m; j++ )
1457 double t = v[j]*alpha - s[j]*beta;
1460 if( min_dist2 > dist2 )
1467 if( min_idx != labels[i] )
1469 labels[i] = min_idx;
1475 CvDTreeSplit* CvDTree::find_split_cat_gini( CvDTreeNode* node, int vi )
1477 CvDTreeSplit* split;
1478 const int* labels = data->get_cat_var_data(node, vi);
1479 const int* responses = data->get_class_labels(node);
1480 int ci = data->get_var_type(vi);
1481 int n = node->sample_count;
1482 int m = data->get_num_classes();
1483 int _mi = data->cat_count->data.i[ci], mi = _mi;
1484 const int* rc0 = data->counts->data.i;
1485 int* lc = (int*)(rc0 + m);
1487 int* _cjk = rc + m*2, *cjk = _cjk;
1488 double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1489 int* cluster_labels = 0;
1492 double L = 0, R = 0;
1493 double best_val = 0;
1494 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1495 const double* priors = data->priors->data.db;
1497 // init array of counters:
1498 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
1499 for( j = -1; j < mi; j++ )
1500 for( k = 0; k < m; k++ )
1503 for( i = 0; i < n; i++ )
1506 int k = responses[i];
1512 if( mi > data->params.max_categories )
1514 mi = MIN(data->params.max_categories, n);
1516 cluster_labels = cjk + mi*m;
1517 cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
1525 int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1526 for( j = 0; j < mi; j++ )
1527 int_ptr[j] = cjk + j*2 + 1;
1528 icvSortIntPtr( int_ptr, mi, 0 );
1533 for( k = 0; k < m; k++ )
1536 for( j = 0; j < mi; j++ )
1537 sum += cjk[j*m + k];
1542 for( j = 0; j < mi; j++ )
1545 for( k = 0; k < m; k++ )
1546 sum += cjk[j*m + k]*priors[k];
1551 for( ; subset_i < subset_n; subset_i++ )
1555 double lsum2 = 0, rsum2 = 0;
1558 idx = (int)(int_ptr[subset_i] - cjk)/2;
1561 int graycode = (subset_i>>1)^subset_i;
1562 int diff = graycode ^ prevcode;
1564 // determine index of the changed bit.
1566 idx = diff >= (1 << 16) ? 16 : 0;
1567 u.f = (float)(((diff >> 16) | diff) & 65535);
1568 idx += (u.i >> 23) - 127;
1569 subtract = graycode < prevcode;
1570 prevcode = graycode;
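// The bit trick above recovers the index of the single bit in which two
// consecutive Gray codes differ: 'diff' is a power of two; it is folded into the
// lower 16 bits and converted to float, whose exponent field ((u.i >> 23) - 127)
// yields the bit position, to which 16 is added back for the upper half.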
1574 weight = c_weights[idx];
1575 if( weight < FLT_EPSILON )
1580 for( k = 0; k < m; k++ )
1583 int lval = lc[k] + t;
1584 int rval = rc[k] - t;
1585 double p = priors[k], p2 = p*p;
1586 lsum2 += p2*lval*lval;
1587 rsum2 += p2*rval*rval;
1588 lc[k] = lval; rc[k] = rval;
1595 for( k = 0; k < m; k++ )
1598 int lval = lc[k] - t;
1599 int rval = rc[k] + t;
1600 double p = priors[k], p2 = p*p;
1601 lsum2 += p2*lval*lval;
1602 rsum2 += p2*rval*rval;
1603 lc[k] = lval; rc[k] = rval;
1609 if( L > FLT_EPSILON && R > FLT_EPSILON )
1611 double val = lsum2/L + rsum2/R;
1612 if( best_val < val )
1615 best_subset = subset_i;
1620 if( best_subset < 0 )
1623 split = data->new_split_cat( vi, (float)best_val );
1627 for( i = 0; i <= best_subset; i++ )
1629 idx = (int)(int_ptr[i] - cjk) >> 1;
1630 split->subset[idx >> 5] |= 1 << (idx & 31);
1635 for( i = 0; i < _mi; i++ )
1637 idx = cluster_labels ? cluster_labels[i] : i;
1638 if( best_subset & (1 << idx) )
1639 split->subset[i >> 5] |= 1 << (i & 31);
1647 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1649 const float epsilon = FLT_EPSILON*2;
1650 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1651 const float* responses = data->get_ord_responses(node);
1652 int n = node->sample_count;
1653 int n1 = node->num_valid[vi];
1654 int i, best_i = -1, L = 0, R = n1;
1655 double lsum = 0, rsum, best_val = 0;
1657 rsum = node->value*n;
1658 // compensate for missing values
1659 for( i = n1; i < n; i++ )
1660 rsum -= responses[sorted[i].i];
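// Why maximizing lsum^2/L + rsum^2/R minimizes the sum of squared errors
// (illustrative derivation): for any split, the total within-child SSE equals
//   sum_i y_i^2 - (lsum^2)/L - (rsum^2)/R,
// and the first term does not depend on the split, so the best split maximizes
// the 'val' computed in the loop below.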
1662 // find the optimal split
1663 for( i = 0; i < n1 - 1; i++ )
1665 float val = responses[sorted[i].i];
1670 if( sorted[i].val + epsilon < sorted[i+1].val )
1672 double val = lsum*lsum/L + rsum*rsum/R;
1673 if( best_val < val )
1681 return best_i >= 0 ? data->new_split_ord( vi,
1682 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1683 0, (float)best_val ) : 0;
1687 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1689 CvDTreeSplit* split;
1690 const int* labels = data->get_cat_var_data(node, vi);
1691 const float* responses = data->get_ord_responses(node);
1692 int ci = data->get_var_type(vi);
1693 int n = node->sample_count;
1694 int mi = data->cat_count->data.i[ci];
1695 double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1696 int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1697 double** sum_ptr = 0;
1698 int i, L = 0, R = 0;
1699 double best_val = 0, lsum = 0, rsum = 0;
1700 int best_subset = -1, subset_i;
1702 for( i = -1; i < mi; i++ )
1703 sum[i] = counts[i] = 0;
1705 // calculate the sum of responses and the sample count for each category of the input var
1706 for( i = 0; i < n; i++ )
1708 int idx = labels[i];
1709 double s = sum[idx] + responses[i];
1710 int nc = counts[idx] + 1;
1715 // calculate average response in each category
1716 for( i = 0; i < mi; i++ )
1720 sum[i] /= MAX(counts[i],1);
1721 sum_ptr[i] = sum + i;
1724 icvSortDblPtr( sum_ptr, mi, 0 );
1726 // revert back to the unnormalized sums
1727 // (there should be very little loss of accuracy)
1728 for( i = 0; i < mi; i++ )
1729 sum[i] *= counts[i];
1731 for( subset_i = 0; subset_i < mi-1; subset_i++ )
1733 int idx = (int)(sum_ptr[subset_i] - sum);
1734 int ni = counts[idx];
1738 double s = sum[idx];
1744 double val = lsum*lsum/L + rsum*rsum/R;
1745 if( best_val < val )
1748 best_subset = subset_i;
1754 if( best_subset < 0 )
1757 split = data->new_split_cat( vi, (float)best_val );
1758 for( i = 0; i <= best_subset; i++ )
1760 int idx = (int)(sum_ptr[i] - sum);
1761 split->subset[idx >> 5] |= 1 << (idx & 31);
1768 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
1770 const float epsilon = FLT_EPSILON*2;
1771 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1772 const char* dir = (char*)data->direction->data.ptr;
1773 int n1 = node->num_valid[vi];
1774 // LL - number of samples that both the primary and the surrogate splits send to the left
1775 // LR - ... primary split sends to the left and the surrogate split sends to the right
1776 // RL - ... primary split sends to the right and the surrogate split sends to the left
1777 // RR - ... both send to the right
1778 int i, best_i = -1, best_inversed = 0;
1781 if( !data->have_priors )
1783 int LL = 0, RL = 0, LR, RR;
1784 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
1785 int sum = 0, sum_abs = 0;
1787 for( i = 0; i < n1; i++ )
1789 int d = dir[sorted[i].i];
1790 sum += d; sum_abs += d & 1;
1793 // sum_abs = R + L; sum = R - L
1794 RR = (sum_abs + sum) >> 1;
1795 LR = (sum_abs - sum) >> 1;
1797 // initially all the samples are sent to the right by the surrogate split,
1798 // LR of them are sent to the left by the primary split, and RR to the right.
1799 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1800 for( i = 0; i < n1 - 1; i++ )
1802 int d = dir[sorted[i].i];
1807 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1810 best_i = i; best_inversed = 0;
1816 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1819 best_i = i; best_inversed = 1;
1823 best_val = _best_val;
1827 double LL = 0, RL = 0, LR, RR;
1828 double worst_val = node->maxlr;
1829 double sum = 0, sum_abs = 0;
1830 const double* priors = data->priors->data.db;
1831 const int* responses = data->get_class_labels(node);
1832 best_val = worst_val;
1834 for( i = 0; i < n1; i++ )
1836 int idx = sorted[i].i;
1837 double w = priors[responses[idx]];
1839 sum += d*w; sum_abs += (d & 1)*w;
1842 // sum_abs = R + L; sum = R - L
1843 RR = (sum_abs + sum)*0.5;
1844 LR = (sum_abs - sum)*0.5;
1846 // initially all the samples are sent to the right by the surrogate split,
1847 // LR of them are sent to the left by the primary split, and RR to the right.
1848 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1849 for( i = 0; i < n1 - 1; i++ )
1851 int idx = sorted[i].i;
1852 double w = priors[responses[idx]];
1858 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1861 best_i = i; best_inversed = 0;
1867 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1870 best_i = i; best_inversed = 1;
1876 return best_i >= 0 ? data->new_split_ord( vi,
1877 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1878 best_inversed, (float)best_val ) : 0;
1882 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
1884 const int* labels = data->get_cat_var_data(node, vi);
1885 const char* dir = (char*)data->direction->data.ptr;
1886 int n = node->sample_count;
1887 // LL - number of samples that both the primary and the surrogate splits send to the left
1888 // LR - ... primary split sends to the left and the surrogate split sends to the right
1889 // RL - ... primary split sends to the right and the surrogate split sends to the left
1890 // RR - ... both send to the right
1891 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
1892 int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
1893 double best_val = 0;
1894 double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
1895 double* rc = lc + mi + 1;
1897 for( i = -1; i < mi; i++ )
1900 // for each category calculate the weight of samples
1901 // sent to the left (lc) and to the right (rc) by the primary split
1902 if( !data->have_priors )
1904 int* _lc = data->counts->data.i + 1;
1905 int* _rc = _lc + mi + 1;
1907 for( i = -1; i < mi; i++ )
1908 _lc[i] = _rc[i] = 0;
1910 for( i = 0; i < n; i++ )
1912 int idx = labels[i];
1914 int sum = _lc[idx] + d;
1915 int sum_abs = _rc[idx] + (d & 1);
1916 _lc[idx] = sum; _rc[idx] = sum_abs;
1919 for( i = 0; i < mi; i++ )
1922 int sum_abs = _rc[i];
1923 lc[i] = (sum_abs - sum) >> 1;
1924 rc[i] = (sum_abs + sum) >> 1;
1929 const double* priors = data->priors->data.db;
1930 const int* responses = data->get_class_labels(node);
1932 for( i = 0; i < n; i++ )
1934 int idx = labels[i];
1935 double w = priors[responses[i]];
1937 double sum = lc[idx] + d*w;
1938 double sum_abs = rc[idx] + (d & 1)*w;
1939 lc[idx] = sum; rc[idx] = sum_abs;
1942 for( i = 0; i < mi; i++ )
1945 double sum_abs = rc[i];
1946 lc[i] = (sum_abs - sum) * 0.5;
1947 rc[i] = (sum_abs + sum) * 0.5;
1951 // 2. now form the split.
1953 // in each category, send all the samples in the same direction as the majority
1953 for( i = 0; i < mi; i++ )
1955 double lval = lc[i], rval = rc[i];
1958 split->subset[i >> 5] |= 1 << (i & 31);
1965 split->quality = (float)best_val;
1966 if( split->quality <= node->maxlr )
1967 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
1973 void CvDTree::calc_node_value( CvDTreeNode* node )
1975 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
1976 const int* cv_labels = data->get_cv_labels(node);
1978 if( data->is_classifier )
1980 // in case of classification tree:
1981 // * node value is the label of the class that has the largest weight in the node.
1982 // * node risk is the weighted number of misclassified samples,
1983 // * j-th cross-validation fold value and risk are calculated as above,
1984 // but using the samples with cv_labels(*)!=j.
1985 // * j-th cross-validation fold error is calculated as the weighted number of
1986 // misclassified samples with cv_labels(*)==j.
1988 // compute the number of instances of each class
1989 int* cls_count = data->counts->data.i;
1990 const int* responses = data->get_class_labels(node);
1991 int m = data->get_num_classes();
1992 int* cv_cls_count = cls_count + m;
1993 double max_val = -1, total_weight = 0;
1995 double* priors = data->priors->data.db;
1997 for( k = 0; k < m; k++ )
2002 for( i = 0; i < n; i++ )
2003 cls_count[responses[i]]++;
2007 for( j = 0; j < cv_n; j++ )
2008 for( k = 0; k < m; k++ )
2009 cv_cls_count[j*m + k] = 0;
2011 for( i = 0; i < n; i++ )
2013 j = cv_labels[i]; k = responses[i];
2014 cv_cls_count[j*m + k]++;
2017 for( j = 0; j < cv_n; j++ )
2018 for( k = 0; k < m; k++ )
2019 cls_count[k] += cv_cls_count[j*m + k];
2022 for( k = 0; k < m; k++ )
2024 double val = cls_count[k]*priors[k];
2025 total_weight += val;
2033 node->class_idx = max_k;
2034 node->value = data->cat_map->data.i[
2035 data->cat_ofs->data.i[data->cat_var_count] + max_k];
2036 node->node_risk = total_weight - max_val;
2038 for( j = 0; j < cv_n; j++ )
2040 double sum_k = 0, sum = 0, max_val_k = 0;
2041 max_val = -1; max_k = -1;
2043 for( k = 0; k < m; k++ )
2045 double w = priors[k];
2046 double val_k = cv_cls_count[j*m + k]*w;
2047 double val = cls_count[k]*w - val_k;
2058 node->cv_Tn[j] = INT_MAX;
2059 node->cv_node_risk[j] = sum - max_val;
2060 node->cv_node_error[j] = sum_k - max_val_k;
2065 // in case of regression tree:
2066 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2067 // n is the number of samples in the node.
2068 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2069 // * j-th cross-validation fold value and risk are calculated as above,
2070 // but using the samples with cv_labels(*)!=j.
2071 // * j-th cross-validation fold error is calculated
2072 // using samples with cv_labels(*)==j as the test subset:
2073 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2074 // where node_value_j is the node value calculated
2075 // as described in the previous bullet, and summation is done
2076 // over the samples with cv_labels(*)==j.
2078 double sum = 0, sum2 = 0;
2079 const float* values = data->get_ord_responses(node);
2080 double *cv_sum = 0, *cv_sum2 = 0;
2085 // if cross-validation is not used, we do not even compute node_risk
2086 // (so the tree sequence T1>...>{root} may not be built).
2087 for( i = 0; i < n; i++ )
2092 cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2093 cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2094 cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2096 for( j = 0; j < cv_n; j++ )
2098 cv_sum[j] = cv_sum2[j] = 0.;
2102 for( i = 0; i < n; i++ )
2105 double t = values[i];
2106 double s = cv_sum[j] + t;
2107 double s2 = cv_sum2[j] + t*t;
2108 int nc = cv_count[j] + 1;
2114 for( j = 0; j < cv_n; j++ )
2120 node->node_risk = sum2 - (sum/n)*sum;
2123 node->value = sum/n;
2125 for( j = 0; j < cv_n; j++ )
2127 double s = cv_sum[j], si = sum - s;
2128 double s2 = cv_sum2[j], s2i = sum2 - s2;
2129 int c = cv_count[j], ci = n - c;
2130 double r = si/MAX(ci,1);
2131 node->cv_node_risk[j] = s2i - r*r*ci;
2132 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2133 node->cv_Tn[j] = INT_MAX;
2139 void CvDTree::split_node_data( CvDTreeNode* node )
2141 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2142 int nz = n - node->num_valid[node->split->var_idx];
2143 char* dir = (char*)data->direction->data.ptr;
2144 CvDTreeNode *left = 0, *right = 0;
2145 int* new_idx = data->split_buf->data.i;
2146 int new_buf_idx = data->get_child_buf_idx( node );
2148 // try to complete direction using surrogate splits
2149 if( nz && data->params.use_surrogates )
2151 CvDTreeSplit* split = node->split->next;
2152 for( ; split != 0 && nz; split = split->next )
2154 int inversed_mask = split->inversed ? -1 : 0;
2155 vi = split->var_idx;
2157 if( data->get_var_type(vi) >= 0 ) // split on categorical var
2159 const int* labels = data->get_cat_var_data(node, vi);
2160 const int* subset = split->subset;
2162 for( i = 0; i < n; i++ )
2165 if( !dir[i] && (idx = labels[i]) >= 0 )
2167 int d = DTREE_CAT_DIR(idx,subset);
2168 dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2174 else // split on ordered var
2176 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2177 int split_point = split->ord.split_point;
2178 int n1 = node->num_valid[vi];
2180 assert( 0 <= split_point && split_point < n-1 );
2182 for( i = 0; i < n1; i++ )
2184 int idx = sorted[i].i;
2187 int d = i <= split_point ? -1 : 1;
2188 dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2197 // find the default direction for the rest
2200 for( i = nr = 0; i < n; i++ )
2203 d0 = nl > nr ? -1 : nr > nl;
2206 // make sure that every sample is directed either to the left or to the right
2207 for( i = nl = nr = 0; i < n; i++ )
2217 dir[i] = (char)d; // remap (-1,1) to (0,1)
2219 // initialize new indices for splitting ordered variables
2220 new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2225 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2226 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2227 (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
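// (illustrative note) Each sample occupies ord_var_count*2 + cat_var_count + 1 +
// have_cv_labels ints in a node's buffer slice, so the right child's data starts
// immediately after the left child's nl samples within the same buffer row.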
2229 // split ordered variables, keep both halves sorted.
2230 for( vi = 0; vi < data->var_count; vi++ )
2232 int ci = data->get_var_type(vi);
2233 int n1 = node->num_valid[vi];
2234 CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2235 CvPair32s32f tl, tr;
2240 src = data->get_ord_var_data(node, vi);
2241 ldst0 = ldst = data->get_ord_var_data(left, vi);
2242 rdst0 = rdst = data->get_ord_var_data(right, vi);
2243 tl = ldst0[nl]; tr = rdst0[nr];
2246 for( i = 0; i < n1; i++ )
2249 float val = src[i].val;
2252 ldst->i = rdst->i = idx;
2253 ldst->val = rdst->val = val;
2258 left->num_valid[vi] = (int)(ldst - ldst0);
2259 right->num_valid[vi] = (int)(rdst - rdst0);
2267 ldst->i = rdst->i = idx;
2268 ldst->val = rdst->val = ord_nan;
2273 ldst0[nl] = tl; rdst0[nr] = tr;
2276 // split categorical vars, responses and cv_labels using new_idx relocation table
2277 for( vi = 0; vi <= data->var_count + data->have_cv_labels; vi++ )
2279 int ci = data->get_var_type(vi);
2280 int n1 = node->num_valid[vi], nr1 = 0;
2281 int *src, *ldst0, *rdst0, *ldst, *rdst;
2287 src = data->get_cat_var_data(node, vi);
2288 ldst0 = ldst = data->get_cat_var_data(left, vi);
2289 rdst0 = rdst = data->get_cat_var_data(right, vi);
2290 tl = ldst0[nl]; tr = rdst0[nr];
2292 for( i = 0; i < n; i++ )
2296 *ldst = *rdst = val;
2299 nr1 += (val >= 0)&d;
2302 if( vi < data->var_count )
2304 left->num_valid[vi] = n1 - nr1;
2305 right->num_valid[vi] = nr1;
2308 ldst0[nl] = tl; rdst0[nr] = tr;
2311 // deallocate the parent node data that is not needed anymore
2312 data->free_node_data(node);
2316 void CvDTree::prune_cv()
2322 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2323 // 2. choose the best tree index (if needed, apply the 1SE rule).
2324 // 3. store the best index and cut the branches.
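// (sketch of the cost-complexity criterion used below, not in the original file)
// For an internal node t with subtree T_t, update_tree_rnc() computes
//   alpha(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1),
// the increase in (cross-validated) risk per removed leaf when T_t is collapsed
// into t. cut_tree() then marks every node whose alpha does not exceed the current
// minimal alpha, yielding the nested sequence T1 > T2 > ... > {root} from which
// the best pruned tree is selected.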
2326 CV_FUNCNAME( "CvDTree::prune_cv" );
2330 int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2331 // currently, 1SE for regression is not implemented
2332 bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2334 double min_err = 0, min_err_se = 0;
2337 CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2339 // build the main tree sequence, calculate alpha's
2342 double min_alpha = update_tree_rnc(tree_count, -1);
2343 if( cut_tree(tree_count, -1, min_alpha) )
2346 if( ab->cols <= tree_count )
2348 CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2349 for( ti = 0; ti < ab->cols; ti++ )
2350 temp->data.db[ti] = ab->data.db[ti];
2351 cvReleaseMat( &ab );
2356 ab->data.db[tree_count] = min_alpha;
2359 ab->data.db[0] = 0.;
2360 for( ti = 1; ti < tree_count-1; ti++ )
2361 ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2362 ab->data.db[tree_count-1] = DBL_MAX*0.5;
2364 CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2365 err = err_jk->data.db;
2367 for( j = 0; j < cv_n; j++ )
2370 for( ; tk < tree_count; tj++ )
2372 double min_alpha = update_tree_rnc(tj, j);
2373 if( cut_tree(tj, j, min_alpha) )
2374 min_alpha = DBL_MAX;
2376 for( ; tk < tree_count; tk++ )
2378 if( ab->data.db[tk] > min_alpha )
2380 err[j*tree_count + tk] = root->tree_error;
2385 for( ti = 0; ti < tree_count; ti++ )
2388 for( j = 0; j < cv_n; j++ )
2389 sum_err += err[j*tree_count + ti];
2390 if( ti == 0 || sum_err < min_err )
2395 min_err_se = sqrt( sum_err*(n - sum_err) );
2397 else if( sum_err < min_err + min_err_se )
2401 pruned_tree_idx = min_idx;
2402 free_prune_data(data->params.truncate_pruned_tree != 0);
2406 cvReleaseMat( &err_jk );
2407 cvReleaseMat( &ab );
2408 cvReleaseMat( &temp );
2412 double CvDTree::update_tree_rnc( int T, int fold )
2414 CvDTreeNode* node = root;
2415 double min_alpha = DBL_MAX;
2419 CvDTreeNode* parent;
2422 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2423 if( t <= T || !node->left )
2425 node->complexity = 1;
2426 node->tree_risk = node->node_risk;
2427 node->tree_error = 0.;
2430 node->tree_risk = node->cv_node_risk[fold];
2431 node->tree_error = node->cv_node_error[fold];
2438 for( parent = node->parent; parent && parent->right == node;
2439 node = parent, parent = parent->parent )
2441 parent->complexity += node->complexity;
2442 parent->tree_risk += node->tree_risk;
2443 parent->tree_error += node->tree_error;
2445 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2446 - parent->tree_risk)/(parent->complexity - 1);
2447 min_alpha = MIN( min_alpha, parent->alpha );
2453 parent->complexity = node->complexity;
2454 parent->tree_risk = node->tree_risk;
2455 parent->tree_error = node->tree_error;
2456 node = parent->right;
2463 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2465 CvDTreeNode* node = root;
2471 CvDTreeNode* parent;
2474 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2475 if( t <= T || !node->left )
2477 if( node->alpha <= min_alpha + FLT_EPSILON )
2480 node->cv_Tn[fold] = T;
2490 for( parent = node->parent; parent && parent->right == node;
2491 node = parent, parent = parent->parent )
2497 node = parent->right;
2504 void CvDTree::free_prune_data(bool cut_tree)
2506 CvDTreeNode* node = root;
2510 CvDTreeNode* parent;
2513 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2514 // as we will clear the whole cross-validation heap at the end
2516 node->cv_node_error = node->cv_node_risk = 0;
2522 for( parent = node->parent; parent && parent->right == node;
2523 node = parent, parent = parent->parent )
2525 if( cut_tree && parent->Tn <= pruned_tree_idx )
2527 data->free_node( parent->left );
2528 data->free_node( parent->right );
2529 parent->left = parent->right = 0;
2536 node = parent->right;
2540 cvClearSet( data->cv_heap );
2544 void CvDTree::free_tree()
2546 if( root && data && data->shared )
2548 pruned_tree_idx = INT_MIN;
2549 free_prune_data(true);
2550 data->free_node(root);
2556 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2557 const CvMat* _missing, bool preprocessed_input ) const
2559 CvDTreeNode* result = 0;
2562 CV_FUNCNAME( "CvDTree::predict" );
2566 int i, step, mstep = 0;
2567 const float* sample;
2569 CvDTreeNode* node = root;
2576 CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2578 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2579 _sample->cols != 1 && _sample->rows != 1 ||
2580 _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2581 _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2582 CV_ERROR( CV_StsBadArg,
2583 "the input sample must be 1d floating-point vector with the same "
2584 "number of elements as the total number of variables used for training" );
2586 sample = _sample->data.fl;
2587 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
2589 if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2591 int n = data->cat_count->cols;
2592 catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2593 for( i = 0; i < n; i++ )
2599 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2600 !CV_ARE_SIZES_EQ(_missing, _sample) )
2601 CV_ERROR( CV_StsBadArg,
2602 "the missing data mask must be 8-bit vector of the same size as input sample" );
2603 m = _missing->data.ptr;
2604 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2607 vtype = data->var_type->data.i;
2608 vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2609 cmap = data->cat_map->data.i;
2610 cofs = data->cat_ofs->data.i;
2612 while( node->Tn > pruned_tree_idx && node->left )
2614 CvDTreeSplit* split = node->split;
2616 for( ; !dir && split != 0; split = split->next )
2618 int vi = split->var_idx;
2620 int i = vidx ? vidx[vi] : vi;
2621 float val = sample[i*step];
2622 if( m && m[i*mstep] )
2624 if( ci < 0 ) // ordered
2625 dir = val <= split->ord.c ? -1 : 1;
2629 if( preprocessed_input )
2636 int a = c = cofs[ci];
2638 int ival = cvRound(val);
2640 CV_ERROR( CV_StsBadArg,
2641 "one of input categorical variable is not an integer" );
2646 if( ival < cmap[c] )
2648 else if( ival > cmap[c] )
2654 if( c < 0 || ival != cmap[c] )
2657 catbuf[ci] = c -= cofs[ci];
2660 dir = DTREE_CAT_DIR(c, split->subset);
2663 if( split->inversed )
2669 double diff = node->right->sample_count - node->left->sample_count;
2670 dir = diff < 0 ? -1 : 1;
2672 node = dir < 0 ? node->left : node->right;
2683 const CvMat* CvDTree::get_var_importance()
2685 if( !var_importance )
2687 CvDTreeNode* node = root;
2691 var_importance = cvCreateMat( 1, data->var_count, CV_64F );
2692 cvZero( var_importance );
2693 importance = var_importance->data.db;
2697 CvDTreeNode* parent;
2698 for( ;; node = node->left )
2700 CvDTreeSplit* split = node->split;
2702 if( !node->left || node->Tn <= pruned_tree_idx )
2705 for( ; split != 0; split = split->next )
2706 importance[split->var_idx] += split->quality;
2709 for( parent = node->parent; parent && parent->right == node;
2710 node = parent, parent = parent->parent )
2716 node = parent->right;
2720 return var_importance;
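// A usage sketch for the accessor above (icvExamplePrintImportance is a
// hypothetical helper; assumes <cstdio> is available): dumping the accumulated
// per-variable split qualities. The returned matrix is 1 x var_count, CV_64F;
// when training used a var_idx subset, index i refers to the i-th active
// variable rather than the original data column.
static void icvExamplePrintImportance( CvDTree& tree )
{
    const CvMat* imp = tree.get_var_importance();
    if( imp )
        for( int i = 0; i < imp->cols; i++ )
            printf( "active var #%d: importance = %g\n", i, imp->data.db[i] );
}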
2724 void CvDTree::save( const char* filename, const char* name )
2726 CvFileStorage* fs = 0;
2728 CV_FUNCNAME( "CvDTree::save" );
2732 CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_WRITE ));
2734 CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2736 write( fs, name ? name : "my_dtree" );
2740 cvReleaseFileStorage( &fs );
2744 void CvDTree::write_train_data_params( CvFileStorage* fs )
2746 CV_FUNCNAME( "CvDTree::write_train_data_params" );
2750 int vi, vcount = data->var_count;
2752 cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
2753 cvWriteInt( fs, "var_all", data->var_all );
2754 cvWriteInt( fs, "var_count", data->var_count );
2755 cvWriteInt( fs, "ord_var_count", data->ord_var_count );
2756 cvWriteInt( fs, "cat_var_count", data->cat_var_count );
2758 cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
2759 cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
2761 if( data->is_classifier )
2763 cvWriteInt( fs, "max_categories", data->params.max_categories );
2767 cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
2770 cvWriteInt( fs, "max_depth", data->params.max_depth );
2771 cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
2772 cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
2774 if( data->params.cv_folds > 1 )
2776 cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
2777 cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
2781 cvWrite( fs, "priors", data->priors );
2783 cvEndWriteStruct( fs );
2786 cvWrite( fs, "var_idx", data->var_idx );
2788 cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
2790 for( vi = 0; vi < vcount; vi++ )
2791 cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
2793 cvEndWriteStruct( fs );
2795 if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
2797 CV_ASSERT( data->cat_count != 0 );
2798 cvWrite( fs, "cat_count", data->cat_count );
2799 cvWrite( fs, "cat_map", data->cat_map );
2806 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
2810 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
2811 cvWriteInt( fs, "var", split->var_idx );
2812 cvWriteReal( fs, "quality", split->quality );
2814 ci = data->get_var_type(split->var_idx);
2815 if( ci >= 0 ) // split on a categorical var
2817 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
2818 for( i = 0; i < n; i++ )
2819 to_right += DTREE_CAT_DIR(i,split->subset) > 0;
2821 // ad-hoc rule when to use inverse categorical split notation
2822 // to achieve a more compact and clearer representation
2823 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
2825 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
2826 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
2827 for( i = 0; i < n; i++ )
2829 int dir = DTREE_CAT_DIR(i,split->subset);
2830 if( dir*default_dir < 0 )
2831 cvWriteInt( fs, 0, i );
2833 cvEndWriteStruct( fs );
2836 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
2838 cvEndWriteStruct( fs );
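// Worked example of the default_dir rule above: with n = 10 categories of
// which only two have DTREE_CAT_DIR(i,subset) > 0, to_right = 2 <= MIN(3,10/2),
// so default_dir = -1; for a non-inversed split the product default_dir*1 is
// negative, hence the subset is written under the "not_in" tag and only the
// two "minority" categories are listed instead of the other eight.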
2842 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
2844 CvDTreeSplit* split;
2846 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2848 cvWriteInt( fs, "depth", node->depth );
2849 cvWriteInt( fs, "sample_count", node->sample_count );
2850 cvWriteReal( fs, "value", node->value );
2852 if( data->is_classifier )
2853 cvWriteInt( fs, "norm_class_idx", node->class_idx );
2855 cvWriteInt( fs, "Tn", node->Tn );
2856 cvWriteInt( fs, "complexity", node->complexity );
2857 cvWriteReal( fs, "alpha", node->alpha );
2858 cvWriteReal( fs, "node_risk", node->node_risk );
2859 cvWriteReal( fs, "tree_risk", node->tree_risk );
2860 cvWriteReal( fs, "tree_error", node->tree_error );
2864 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
2866 for( split = node->split; split != 0; split = split->next )
2867 write_split( fs, split );
2869 cvEndWriteStruct( fs );
2872 cvEndWriteStruct( fs );
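// For reference, a single node record written by write_node() looks roughly
// like this in a YAML storage (the values are invented; norm_class_idx is
// written for classifiers only, and "splits" only for non-leaf nodes):
//
//   depth: 0
//   sample_count: 150
//   value: 1.
//   norm_class_idx: 1
//   Tn: 5
//   complexity: 3
//   alpha: 0.02
//   node_risk: 50.
//   tree_risk: 6.
//   tree_error: 4.
//   splits: [ { var: 2, quality: 85.3, le: 2.45 } ]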
2876 void CvDTree::write_tree_nodes( CvFileStorage* fs )
2878 CV_FUNCNAME( "CvDTree::write_tree_nodes" );
2882 CvDTreeNode* node = root;
2884 // traverse the tree and save all the nodes in depth-first order
2887 CvDTreeNode* parent;
2890 write_node( fs, node );
2896 for( parent = node->parent; parent && parent->right == node;
2897 node = parent, parent = parent->parent )
2903 node = parent->right;
2910 void CvDTree::write( CvFileStorage* fs, const char* name )
2912 CV_FUNCNAME( "CvDTree::write" );
2916 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
2918 write_train_data_params( fs );
2920 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
2921 get_var_importance();
2922 cvWrite( fs, "var_importance", var_importance );
2924 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
2925 write_tree_nodes( fs );
2926 cvEndWriteStruct( fs );
2928 cvEndWriteStruct( fs );
2934 void CvDTree::load( const char* filename, const char* name )
2936 CvFileStorage* fs = 0;
2938 CV_FUNCNAME( "CvDTree::load" );
2942 CvFileNode* tree = 0;
2944 CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_READ ));
2946 CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2949 tree = cvGetFileNodeByName( fs, 0, name );
2952 CvFileNode* root = cvGetRootFileNode( fs );
2953 if( root->data.seq->total > 0 )
2954 tree = (CvFileNode*)cvGetSeqElem( root->data.seq, 0 );
2961 cvReleaseFileStorage( &fs );
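// A usage sketch for the two methods above (icvExampleSaveLoadRoundTrip, the
// file name and the model name are all hypothetical): a save/load round trip.
static void icvExampleSaveLoadRoundTrip( CvDTree& trained_tree )
{
    trained_tree.save( "dtree.yml", "my_dtree" );   // train-data params + nodes

    CvDTree restored;
    restored.load( "dtree.yml", "my_dtree" );       // rebuilds data, nodes and splits
    // restored.predict(...) can now be used exactly like the original tree
}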
2965 void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
2967 CV_FUNCNAME( "CvDTree::read_train_data_params" );
2971 CvDTreeParams params;
2972 CvFileNode *tparams_node, *vartype_node;
2974 int is_classifier, vi, cat_var_count, ord_var_count;
2975 int max_split_size, tree_block_size;
2977 data = new CvDTreeTrainData;
2979 data->is_classifier = is_classifier = cvReadIntByName( fs, node, "is_classifier" ) != 0;
2980 data->var_all = cvReadIntByName( fs, node, "var_all" );
2981 data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
2982 data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
2983 data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
2985 tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
2987 if( tparams_node ) // training parameters are optional and may be absent
2989 data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
2993 data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
2997 data->params.regression_accuracy =
2998 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
3001 data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
3002 data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
3003 data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
3005 if( data->params.cv_folds > 1 )
3007 data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
3008 data->params.truncate_pruned_tree =
3009 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
3012 data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
3013 if( data->priors && !CV_IS_MAT(data->priors) )
3014 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
3017 CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
3020 if( !CV_IS_MAT(data->var_idx) ||
3021 (data->var_idx->cols != 1 && data->var_idx->rows != 1) ||
3022 data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
3023 CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
3024 CV_ERROR( CV_StsParseError,
3025 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
3027 for( vi = 0; vi < data->var_count; vi++ )
3028 if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
3029 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
3032 ////// read var type
3033 CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
3035 vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
3036 if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
3037 vartype_node->data.seq->total != data->var_count )
3038 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3040 cvStartReadSeq( vartype_node->data.seq, &reader );
3044 for( vi = 0; vi < data->var_count; vi++ )
3046 CvFileNode* n = (CvFileNode*)reader.ptr;
3047 if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
3048 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3049 data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
3050 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3053 ord_var_count = ~ord_var_count;
3054 if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
3055 CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
3058 if( data->cat_var_count > 0 || is_classifier )
3060 int ccount, max_c_count = 0, total_c_count = 0;
3061 CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
3062 CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
3064 if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
3065 (data->cat_count->cols != 1 && data->cat_count->rows != 1) ||
3066 CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
3067 data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
3068 (data->cat_map->cols != 1 && data->cat_map->rows != 1) ||
3069 CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
3070 CV_ERROR( CV_StsParseError,
3071 "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
3073 ccount = cat_var_count + is_classifier;
3075 CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
3076 data->cat_ofs->data.i[0] = 0;
3078 for( vi = 0; vi < ccount; vi++ )
3080 int val = data->cat_count->data.i[vi];
3082 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
3083 max_c_count = MAX( max_c_count, val );
3084 data->cat_ofs->data.i[vi+1] = total_c_count += val;
3087 if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
3088 CV_ERROR( CV_StsBadSize,
3089 "cat_map vector length is not equal to the total number of categories in all categorical vars" );
3091 data->max_c_count = max_c_count;
3094 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
3095 (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
3097 tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
3098 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
3099 CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
3100 CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
3101 sizeof(CvDTreeNode), data->tree_storage ));
3102 CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
3103 max_split_size, data->tree_storage ));
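// Worked example of the cat_ofs construction above: with cat_count = [ 3, 2, 4 ]
// (two categorical predictors plus the response of a classifier, so ccount = 3),
// the running prefix sum gives cat_ofs = [ 0, 3, 5, 9 ], total_c_count = 9
// (the required cat_map length) and max_c_count = 4.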
3109 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3111 CvDTreeSplit* split = 0;
3113 CV_FUNCNAME( "CvDTree::read_split" );
3119 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3120 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3122 vi = cvReadIntByName( fs, fnode, "var", -1 );
3123 if( (unsigned)vi >= (unsigned)data->var_count )
3124 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3126 ci = data->get_var_type(vi);
3127 if( ci >= 0 ) // split on categorical var
3129 int i, n = data->cat_count->data.i[ci], inversed = 0;
3132 split = data->new_split_cat( vi, 0 );
3133 inseq = cvGetFileNodeByName( fs, fnode, "in" );
3136 inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3139 if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
3140 CV_ERROR( CV_StsParseError,
3141 "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3143 cvStartReadSeq( inseq->data.seq, &reader );
3145 for( i = 0; i < reader.seq->total; i++ )
3147 CvFileNode* inode = (CvFileNode*)reader.ptr;
3148 int val = inode->data.i;
3149 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3150 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3152 split->subset[val >> 5] |= 1 << (val & 31);
3153 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3156 // for categorical splits we do not use the inversed flag;
3157 // instead we invert the category subset stored in the split
3159 for( i = 0; i < (n + 31) >> 5; i++ )
3160 split->subset[i] ^= -1;
3164 CvFileNode* cmp_node;
3165 split = data->new_split_ord( vi, 0, 0, 0, 0 );
3167 cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3170 cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3171 split->inversed = 1;
3174 split->ord.c = (float)cvReadReal( cmp_node );
3177 split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
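// Continuing the categorical example from write_split(): reading
// "not_in: [ 2, 5 ]" for a variable with n = 10 categories first sets bits 2
// and 5 in split->subset, then, because the tag was "not_in" (inversed), flips
// every bit of the affected 32-bit word, so the remaining eight categories end
// up marked; bits at positions >= n are flipped too, but DTREE_CAT_DIR is only
// ever queried with indices below n for this variable.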
3185 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3187 CvDTreeNode* node = 0;
3189 CV_FUNCNAME( "CvDTree::read_node" );
3196 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3197 CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3199 CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3200 depth = cvReadIntByName( fs, fnode, "depth", -1 );
3201 if( depth != node->depth )
3202 CV_ERROR( CV_StsParseError, "incorrect node depth" );
3204 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3205 node->value = cvReadRealByName( fs, fnode, "value" );
3206 if( data->is_classifier )
3207 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3209 node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3210 node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3211 node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3212 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3213 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3214 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3216 splits = cvGetFileNodeByName( fs, fnode, "splits" );
3220 CvDTreeSplit* last_split = 0;
3222 if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3223 CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3225 cvStartReadSeq( splits->data.seq, &reader );
3226 for( i = 0; i < reader.seq->total; i++ )
3228 CvDTreeSplit* split;
3229 CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3231 node->split = last_split = split;
3233 last_split = last_split->next = split;
3235 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3245 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3247 CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3253 CvDTreeNode* parent = &_root;
3255 parent->left = parent->right = parent->parent = 0;
3257 cvStartReadSeq( fnode->data.seq, &reader );
3259 for( i = 0; i < reader.seq->total; i++ )
3263 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3265 parent->left = node;
3267 parent->right = node;
3272 while( parent && parent->right )
3273 parent = parent->parent;
3276 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3285 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3287 CV_FUNCNAME( "CvDTree::read" );
3291 CvFileNode* tree_nodes;
3294 read_train_data_params( fs, fnode );
3296 tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
3297 if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3298 CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3300 pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
3302 read_tree_nodes( fs, tree_nodes );
3303 get_var_importance(); // recompute variable importance