1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// File-local tuning constants for decision-tree training.
43 static const float ord_var_epsilon = FLT_EPSILON*2; // tolerance used when comparing ordered-variable values
44 static const float ord_nan = FLT_MAX*0.5f; // magnitude limit: ordered values with fabs(v) >= this are rejected in set_data()
45 static const int min_block_size = 1 << 16; // minimum CvMemStorage block size (64 KB)
46 static const int block_size_delta = 1 << 10; // extra slack added on top of a computed block size
// Default constructor: null out every owned matrix/storage pointer so that
// clear()/set_data() can safely test and release them; no data is loaded yet.
48 CvDTreeTrainData::CvDTreeTrainData()
50 var_idx = var_type = cat_count = cat_ofs = cat_map =
51 priors = counts = buf = direction = split_buf = 0;
52 tree_storage = temp_storage = 0;
// Convenience constructor: zero-initializes all owned pointers (same as the
// default ctor) and then immediately loads/preprocesses the training set by
// forwarding every argument to set_data().
58 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
59 const CvMat* _responses, const CvMat* _var_idx,
60 const CvMat* _sample_idx, const CvMat* _var_type,
61 const CvMat* _missing_mask,
62 CvDTreeParams _params, bool _shared, bool _add_weights )
64 var_idx = var_type = cat_count = cat_ofs = cat_map =
65 priors = counts = buf = direction = split_buf = 0;
66 tree_storage = temp_storage = 0;
68 set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
69 _var_type, _missing_mask, _params, _shared, _add_weights );
// Destructor (body elided in this excerpt; presumably releases everything via
// clear() — TODO confirm against the full file).
73 CvDTreeTrainData::~CvDTreeTrainData()
// Validates the user-supplied training parameters and clamps them to the
// ranges this implementation supports, storing the result in this->params.
// Raises CV_StsOutOfRange via CV_ERROR for invalid values.
79 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
83 CV_FUNCNAME( "CvDTreeTrainData::set_params" );
90 if( params.max_categories < 2 )
91 CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
// hard cap: categorical-split subsets are enumerated, so at most 15 categories
92 params.max_categories = MIN( params.max_categories, 15 );
94 if( params.max_depth < 0 )
95 CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
// depth is capped at 25 regardless of the requested value
96 params.max_depth = MIN( params.max_depth, 25 );
98 params.min_sample_count = MAX(params.min_sample_count,1);
100 if( params.cv_folds < 0 )
101 CV_ERROR( CV_StsOutOfRange,
102 "params.cv_folds should be =0 (the tree is not pruned) "
103 "or n>0 (tree is pruned using n-fold cross-validation)" );
// 1-fold CV is meaningless; (elided branch presumably normalizes it — TODO confirm)
105 if( params.cv_folds == 1 )
108 if( params.regression_accuracy < 0 )
109 CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
// Quicksort instantiations (CV_IMPLEMENT_QSORT_EX generates a static sort
// function with the given name/element type/comparator).
// CV_CMP_NUM_PTR orders pointers by the values they point to, so sorting an
// array of int*/double* yields an indirect sort of the pointed-to values.
119 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
120 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
121 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
// CV_CMP_PAIRS orders (index, value) pairs by their float value.
123 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
124 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
// Loads and preprocesses the whole training set into the internal packed
// representation used during tree construction:
//  - validates inputs, resolves sample/variable index subsets,
//  - classifies each variable as categorical or ordered (var_type),
//  - for categorical vars: compacts labels to 0..c_count-1 and records the
//    original labels in cat_map (offsets in cat_ofs, counts in cat_count),
//  - for ordered vars: stores (sample index, value) pairs sorted by value,
//  - allocates node/split heaps and auxiliary buffers,
//  - optionally generates shuffled cross-validation fold labels.
// Raises CV_StsBadArg / CV_StsOutOfRange via CV_ERROR on invalid data.
126 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
127 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
128 const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
129 bool _shared, bool _add_weights )
131 CvMat* sample_idx = 0;
132 CvMat* var_type0 = 0;
136 CV_FUNCNAME( "CvDTreeTrainData::set_data" );
140 int sample_all = 0, r_type = 0, cv_n;
141 int total_c_count = 0;
142 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
143 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
146 const int *sidx = 0, *vidx = 0;
153 CV_CALL( set_params( _params ));
155 // check parameter types and sizes
156 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
// row-sample layout: consecutive elements of one sample are one matrix row
157 if( _tflag == CV_ROW_SAMPLE )
159 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
162 ms_step = _missing_mask->step, mv_step = 1;
// column-sample layout: one sample per matrix column
166 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
169 mv_step = _missing_mask->step, ms_step = 1;
172 sample_count = sample_all;
// optional subset of samples: sidx maps compacted index -> original sample
177 CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
178 sidx = sample_idx->data.i;
179 sample_count = sample_idx->rows + sample_idx->cols - 1;
// optional subset of variables, same convention
184 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
185 vidx = var_idx->data.i;
186 var_count = var_idx->rows + var_idx->cols - 1;
// NOTE(review): in the vector test below, && binds tighter than || —
// i.e. "(rows != 1 && cols != 1)" is one error condition; correct, but
// parentheses would silence compiler warnings and aid readers.
189 if( !CV_IS_MAT(_responses) ||
190 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
191 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
192 _responses->rows != 1 && _responses->cols != 1 ||
193 _responses->rows + _responses->cols - 1 != sample_all )
194 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
195 "floating-point vector containing as many elements as "
196 "the total number of samples in the training data matrix" );
198 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
// var_type has var_count+2 slots: the extra two index responses and cv labels
199 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
204 is_classifier = r_type == CV_VAR_CATEGORICAL;
206 // step 0. calc the number of categorical vars
// encoding: categorical var vi gets its category index (>= 0); ordered vars
// get a negative value (post-decrement), later recovered via ~ (bitwise not)
207 for( vi = 0; vi < var_count; vi++ )
209 var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
210 cat_var_count++ : ord_var_count--;
// ord_var_count was decremented below zero; ~x turns -(n+1) into n
213 ord_var_count = ~ord_var_count;
214 cv_n = params.cv_folds;
215 // set the two last elements of var_type array to be able
216 // to locate responses and cross-validation labels using
217 // the corresponding get_* functions.
218 var_type->data.i[var_count] = cat_var_count;
219 var_type->data.i[var_count+1] = cat_var_count+1;
221 // in case of single ordered predictor we need dummy cv_labels
222 // for safe split_node_data() operation
// (relies on && binding tighter than ||, same caveat as above)
223 have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
224 have_weights = _add_weights;
// buffer layout per node: 2 ints per ordered var (pair), 1 per categorical
// var, 1 for responses, plus optional cv-labels and weights rows
226 buf_size = (ord_var_count*2 + cat_var_count + 1 +
227 (have_cv_labels ? 1 : 0) + (have_weights ? 1 : 0))*sample_count + 2;
// shared data (e.g. tree ensembles) needs an extra buffer row
229 buf_count = shared ? 3 : 2;
230 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
231 CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
232 CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
// initial guess for total category-map size; grown on demand below
233 CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
235 // now calculate the maximum size of split,
236 // create memory storage that will keep nodes and splits of the decision tree
237 // allocate root node and the buffer for the whole training data
// a split stores a bitmask of categories: one int covers 32, extra ints beyond
238 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
239 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
240 tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
241 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
242 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
243 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
245 temp_block_size = nv_size = var_count*sizeof(int);
248 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
249 CV_ERROR( CV_StsOutOfRange,
// NOTE(review): message wording is garbled — reads "The many folds";
// likely intended "Too many folds in cross-validation for such a small dataset".
250 "The many folds in cross-validation for such a small dataset" );
// per-node CV block: cv_n ints (Tn) + 2*cv_n doubles (risk, error)
252 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
253 temp_block_size = MAX(temp_block_size, cv_size);
256 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
257 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
258 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
260 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
262 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
// scratch array of pointers used for the indirect sort of categorical labels
263 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
267 // transform the training data to convenient representation
// vi == var_count is the response "variable"; input vars are 0..var_count-1
268 for( vi = 0; vi <= var_count; vi++ )
271 const uchar* mask = 0;
272 int m_step = 0, step;
273 const int* idata = 0;
274 const float* fdata = 0;
277 if( vi < var_count ) // analyze i-th input variable
279 int vi0 = vidx ? vidx[vi] : vi;
280 ci = get_var_type(vi);
281 step = ds_step; m_step = ms_step;
282 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
283 idata = _train_data->data.i + vi0*dv_step;
285 fdata = _train_data->data.fl + vi0*dv_step;
287 mask = _missing_mask->data.ptr + vi0*mv_step;
289 else // analyze _responses
292 step = CV_IS_MAT_CONT(_responses->type) ?
293 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
294 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
295 idata = _responses->data.i;
297 fdata = _responses->data.fl;
300 if( vi < var_count && ci >= 0 ||
301 vi == var_count && is_classifier ) // process categorical variable or response
303 int c_count, prev_label, prev_i;
304 int* c_map, *dst = get_cat_var_data( data_root, vi );
// gather raw (integer) category labels; INT_MAX marks a missing value
307 for( i = 0; i < sample_count; i++ )
309 int val = INT_MAX, si = sidx ? sidx[i] : i;
310 if( !mask || !mask[si*m_step] )
313 val = idata[si*step];
// float input: value must be an exact (small) integer
316 float t = fdata[si*step];
320 sprintf( err, "%d-th value of %d-th (categorical) "
321 "variable is not an integer", i, vi );
322 CV_ERROR( CV_StsBadArg, err );
328 sprintf( err, "%d-th value of %d-th (categorical) "
329 "variable is too large", i, vi );
330 CV_ERROR( CV_StsBadArg, err );
335 int_ptr[i] = dst + i;
338 // sort all the values, including the missing measurements
339 // that should all move to the end
340 icvSortIntPtr( int_ptr, sample_count, 0 );
341 //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
343 c_count = num_valid > 0;
345 // count the categories
346 for( i = 1; i < num_valid; i++ )
347 c_count += *int_ptr[i] != *int_ptr[i-1];
350 max_c_count = MAX( max_c_count, c_count );
351 cat_count->data.i[ci] = c_count;
352 cat_ofs->data.i[ci] = total_c_count;
354 // resize cat_map, if need
// grow-by-1.5x reallocation; old contents are copied over
355 if( cat_map->cols < total_c_count + c_count )
358 CV_CALL( cat_map = cvCreateMat( 1,
359 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
360 for( i = 0; i < total_c_count; i++ )
361 cat_map->data.i[i] = tmp_map->data.i[i];
362 cvReleaseMat( &tmp_map );
365 c_map = cat_map->data.i + total_c_count;
366 total_c_count += c_count;
368 // compact the class indices and build the map
// ~label guarantees the first element differs from prev_label
369 prev_label = ~*int_ptr[0];
372 for( i = 0, prev_i = -1; i < num_valid; i++ )
374 int cur_label = *int_ptr[i];
375 if( cur_label != prev_label )
377 c_map[++c_count] = prev_label = cur_label;
// write the compacted category index back through the sorted pointers
380 *int_ptr[i] = c_count;
383 // replace labels for missing values with -1
384 for( ; i < sample_count; i++ )
387 else if( ci < 0 ) // process ordered variable
389 CvPair32s32f* dst = get_ord_var_data( data_root, vi );
// build (sample index, value) pairs; out-of-range values are rejected
391 for( i = 0; i < sample_count; i++ )
394 int si = sidx ? sidx[i] : i;
395 if( !mask || !mask[si*m_step] )
398 val = (float)idata[si*step];
400 val = fdata[si*step];
402 if( fabs(val) >= ord_nan )
404 sprintf( err, "%d-th value of %d-th (ordered) "
405 "variable (=%g) is too large", i, vi, val );
406 CV_ERROR( CV_StsBadArg, err );
// sort the pairs by value (missing entries are pushed to the end)
414 icvSortPairs( dst, sample_count, 0 );
416 else // special case: process ordered response,
417 // it will be stored similarly to categorical vars (i.e. no pairs)
419 float* dst = get_ord_responses( data_root );
421 for( i = 0; i < sample_count; i++ )
424 int si = sidx ? sidx[i] : i;
426 val = (float)idata[si*step];
428 val = fdata[si*step];
430 if( fabs(val) >= ord_nan )
432 sprintf( err, "%d-th value of %d-th (ordered) "
433 "variable (=%g) is out of range", i, vi, val );
434 CV_ERROR( CV_StsBadArg, err );
// regression: no classes, so the class-count slot is zero
439 cat_count->data.i[cat_var_count] = 0;
440 cat_ofs->data.i[cat_var_count] = total_c_count;
441 num_valid = sample_count;
445 data_root->set_num_valid(vi, num_valid);
// generate cross-validation fold labels: assign 0..cv_n-1 cyclically...
450 int* dst = get_cv_labels(data_root);
453 for( i = vi = 0; i < sample_count; i++ )
// wrap vi back to 0 once it reaches cv_n (branch-free masking trick)
456 vi &= vi < cv_n ? -1 : 0;
// ...then shuffle them with random transpositions
459 for( i = 0; i < sample_count; i++ )
461 int a = cvRandInt(r) % sample_count;
462 int b = cvRandInt(r) % sample_count;
463 CV_SWAP( dst[a], dst[b], vi );
// shrink cat_map's logical width to what was actually used
467 cat_map->cols = MAX( total_c_count, 1 );
// split heap element size depends on the largest category count seen
469 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
470 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
471 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
473 have_priors = is_classifier && params.priors;
// classifier bookkeeping: priors vector and the "counts" scratch matrix
476 int m = get_num_classes(), rows = 4;
478 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
479 for( i = 0; i < m; i++ )
481 double val = have_priors ? params.priors[i] : 1.;
483 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
484 priors->data.db[i] = val;
// normalize priors so they sum to 1
489 cvScale( priors, priors, 1./sum );
491 if( cat_var_count > 0 || params.cv_folds > 0 )
493 // need storage for cjk (see find_split_cat_gini) and risks/errors
494 rows += MAX( max_c_count, params.cv_folds ) + 1;
495 // add buffer for k-means clustering
496 if( m > 2 && max_c_count > params.max_categories )
497 rows += params.max_categories + (max_c_count+m-1)/m;
500 CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
// per-sample direction (-1/0/+1) and split scratch buffers
503 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
504 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
// cleanup of temporaries (runs on both success and error paths)
509 cvReleaseMat( &sample_idx );
510 cvReleaseMat( &var_type0 );
511 cvReleaseMat( &tmp_map );
// Builds a root node over a subset of the loaded samples (used e.g. when the
// data object is shared between trees). If _subsample_idx is NULL, a shallow
// copy of data_root is returned; otherwise per-variable data is re-gathered
// for the selected samples (duplicates allowed) into buffer row 1.
// Returns the new root node; raises CV_StsError if no data was set.
515 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
517 CvDTreeNode* root = 0;
518 CvMat* isubsample_idx = 0;
519 CvMat* subsample_co = 0;
521 CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
526 CV_ERROR( CV_StsError, "No training data has been set" );
529 CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
531 if( !isubsample_idx )
533 // make a copy of the root node
536 root = new_node( 0, 1, 0, 0 );
// share the existing num_valid / cv_* arrays rather than reallocating
539 root->num_valid = temp.num_valid;
540 if( root->num_valid )
542 for( i = 0; i < var_count; i++ )
543 root->num_valid[i] = data_root->num_valid[i];
545 root->cv_Tn = temp.cv_Tn;
546 root->cv_node_risk = temp.cv_node_risk;
547 root->cv_node_error = temp.cv_node_error;
551 int* sidx = isubsample_idx->data.i;
552 // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
553 int* co, cur_ofs = 0;
554 int vi, i, total = data_root->sample_count;
555 int count = isubsample_idx->rows + isubsample_idx->cols - 1;
// the subsampled node lives in buffer row 1 (storage_idx = 1)
556 root = new_node( 0, count, 1, 0 );
558 CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
559 cvZero( subsample_co );
560 co = subsample_co->data.i;
// first pass: count occurrences of each original sample; second: offsets
561 for( i = 0; i < count; i++ )
563 for( i = 0; i < total; i++ )
// copy each variable (plus response, plus cv labels if present)
574 for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
576 int ci = get_var_type(vi);
578 if( ci >= 0 || vi >= var_count )
// categorical data (and responses/cv labels): direct gather by sidx
580 const int* src = get_cat_var_data( data_root, vi );
581 int* dst = get_cat_var_data( root, vi );
584 for( i = 0; i < count; i++ )
586 int val = src[sidx[i]];
// negative compact labels mark missing values
588 num_valid += val >= 0;
592 root->set_num_valid(vi, num_valid);
// ordered data: walk the value-sorted pairs and emit each selected sample
// count_i times, preserving the sorted order in the destination
596 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
597 CvPair32s32f* dst = get_ord_var_data( root, vi );
598 int j = 0, idx, count_i;
599 int num_valid = data_root->get_num_valid(vi);
601 for( i = 0; i < num_valid; i++ )
607 float val = src[i].val;
608 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
616 root->set_num_valid(vi, j);
// then the trailing entries with missing values
618 for( ; i < total; i++ )
624 float val = src[i].val;
625 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
638 cvReleaseMat( &isubsample_idx );
639 cvReleaseMat( &subsample_co );
// Reconstructs (a subset of) the training samples in plain row-major form
// from the internal packed representation:
//  values   - count x var_count floats (caller-allocated),
//  missing  - parallel mask, 1 = value missing (may be NULL),
//  responses- count floats; for classifiers either the compact class index
//             or the original label (via cat_map) depending on get_class_idx.
645 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
646 float* values, uchar* missing,
647 float* responses, bool get_class_idx )
649 CvMat* subsample_idx = 0;
650 CvMat* subsample_co = 0;
652 CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
656 int i, vi, total = sample_count, count = total, cur_ofs = 0;
662 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
663 sidx = subsample_idx->data.i;
// co: per-original-sample (occurrence count, output offset) pairs, same
// duplicated-index handling as subsample_data()
664 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
665 co = subsample_co->data.i;
666 cvZero( subsample_co );
667 count = subsample_idx->cols + subsample_idx->rows - 1;
668 for( i = 0; i < count; i++ )
670 for( i = 0; i < total; i++ )
672 int count_i = co[i*2];
675 co[i*2+1] = cur_ofs*var_count;
// start from "all missing"; found values clear their mask entries
681 memset( missing, 1, count*var_count );
683 for( vi = 0; vi < var_count; vi++ )
685 int ci = get_var_type(vi);
686 if( ci >= 0 ) // categorical
// dst/m walk column vi of the row-major output (stride = var_count)
688 float* dst = values + vi;
689 uchar* m = missing + vi;
690 const int* src = get_cat_var_data(data_root, vi);
692 for( i = 0; i < count; i++, dst += var_count, m += var_count )
694 int idx = sidx ? sidx[i] : i;
// ordered: scatter from the value-sorted pairs back to sample positions
702 float* dst = values + vi;
703 uchar* m = missing + vi;
704 const CvPair32s32f* src = get_ord_var_data(data_root, vi);
705 int count1 = data_root->get_num_valid(vi);
707 for( i = 0; i < count1; i++ )
714 cur_ofs = co[idx*2+1];
717 cur_ofs = idx*var_count;
720 float val = src[i].val;
721 for( ; count_i > 0; count_i--, cur_ofs += var_count )
// responses: classification path (class index or original label)...
734 const int* src = get_class_labels(data_root);
735 for( i = 0; i < count; i++ )
737 int idx = sidx ? sidx[i] : i;
738 int val = get_class_idx ? src[idx] :
739 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
740 responses[i] = (float)val;
// ...or regression path (raw float responses)
745 const float* src = get_ord_responses(data_root);
746 for( i = 0; i < count; i++ )
748 int idx = sidx ? sidx[i] : i;
749 responses[i] = src[idx];
755 cvReleaseMat( &subsample_idx );
756 cvReleaseMat( &subsample_co );
// Allocates and initializes a tree node from node_heap.
//  parent      - parent node or 0 for a root (depth 0),
//  count       - number of training samples reaching the node,
//  storage_idx - row of the shared `buf` matrix holding this node's data,
//  offset      - offset of the node's data within that row.
// Also allocates the per-variable valid-sample counters and, when pruning
// with cross-validation is enabled, the per-fold Tn/risk/error arrays.
760 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
761 int storage_idx, int offset )
763 CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
765 node->sample_count = count;
766 node->depth = parent ? parent->depth + 1 : 0;
767 node->parent = parent;
768 node->left = node->right = 0;
774 node->buf_idx = storage_idx;
775 node->offset = offset;
777 node->num_valid = (int*)cvSetNew( nv_heap );
780 node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
781 node->complexity = 0;
783 if( params.cv_folds > 0 && cv_heap )
785 int cv_n = params.cv_folds;
// one cv_heap element packs: cv_n ints (Tn), then 2*cv_n aligned doubles
787 node->cv_Tn = (int*)cvSetNew( cv_heap );
788 node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
789 node->cv_node_error = node->cv_node_risk + cv_n;
// no cross-validation: leave the CV pointers null
795 node->cv_node_risk = 0;
796 node->cv_node_error = 0;
// Allocates a split on an ordered variable from split_heap.
//  cmp_val     - threshold (samples with value < cmp_val go left),
//  split_point - index of the boundary in the value-sorted sample order,
//  inversed    - nonzero to swap the left/right branches,
//  quality     - split quality score used for ranking.
803 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
804 int split_point, int inversed, float quality )
806 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
808 split->ord.c = cmp_val;
809 split->ord.split_point = split_point;
810 split->inversed = inversed;
811 split->quality = quality;
// Allocates a split on a categorical variable from split_heap with an empty
// (all-zero) category subset bitmask; the caller fills in the subset bits.
// The mask is sized for max_c_count categories (32 per int).
818 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
820 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
821 int i, n = (max_c_count + 31)/32;
825 split->quality = quality;
826 for( i = 0; i < n; i++ )
827 split->subset[i] = 0;
// Returns a node and its whole chain of splits (primary + surrogates) to
// their heaps. Node data buffers are released first via free_node_data().
834 void CvDTreeTrainData::free_node( CvDTreeNode* node )
836 CvDTreeSplit* split = node->split;
837 free_node_data( node );
// walk the linked list of splits, freeing each element
840 CvDTreeSplit* next = split->next;
841 cvSetRemoveByPtr( split_heap, split );
845 cvSetRemoveByPtr( node_heap, node );
// Releases a node's per-variable valid-sample counters back to nv_heap.
// Intentionally leaves the cv_* arrays alone (see comment below).
849 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
851 if( node->num_valid )
853 cvSetRemoveByPtr( nv_heap, node->num_valid );
856 // do not free cv_* fields, as all the cross-validation related data is released at once.
// Frees the (potentially large) training-time buffers that are not needed to
// keep/use a trained tree: scratch matrices and the temporary storage holding
// nv_heap/cv_heap. The tree itself (tree_storage) is untouched.
860 void CvDTreeTrainData::free_train_data()
862 cvReleaseMat( &counts );
863 cvReleaseMat( &buf );
864 cvReleaseMat( &direction );
865 cvReleaseMat( &split_buf );
// releasing temp_storage invalidates both heaps allocated from it
866 cvReleaseMemStorage( &temp_storage );
867 cv_heap = nv_heap = 0;
// Releases everything this object owns (training buffers, tree storage, all
// descriptor matrices) and resets scalar state, returning the object to its
// just-constructed state.
871 void CvDTreeTrainData::clear()
875 cvReleaseMemStorage( &tree_storage );
877 cvReleaseMat( &var_idx );
878 cvReleaseMat( &var_type );
879 cvReleaseMat( &cat_count );
880 cvReleaseMat( &cat_ofs );
881 cvReleaseMat( &cat_map );
882 cvReleaseMat( &priors );
// heaps lived inside the released storages; just null the pointers
884 node_heap = split_heap = 0;
886 sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
887 have_cv_labels = have_priors = is_classifier = false;
889 buf_count = buf_size = 0;
// Number of response classes for a classifier (stored in the last cat_count
// slot by set_data()); 0 for regression.
898 int CvDTreeTrainData::get_num_classes() const
900 return is_classifier ? cat_count->data.i[cat_var_count] : 0;
// Type code of variable vi as encoded by set_data(): >= 0 is the categorical
// variable's index, negative (~index) means ordered.
904 int CvDTreeTrainData::get_var_type(int vi) const
906 return var_type->data.i[vi];
// Pointer to node n's value-sorted (sample index, value) pairs for ordered
// variable vi, located inside the shared `buf` matrix. Each ordered variable
// occupies 2*sample_count ints (one CvPair32s32f per sample).
910 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
// ~type recovers the ordered-variable index from the negative encoding
912 int oi = ~get_var_type(vi);
913 assert( 0 <= oi && oi < ord_var_count );
914 return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
915 n->offset + oi*n->sample_count*2);
// Compact class labels for node n's samples; stored as the pseudo-categorical
// variable at index var_count (see set_data()).
919 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
921 return get_cat_var_data( n, var_count );
// Regression responses for node n's samples; same buffer slot as class
// labels, reinterpreted as floats (set_data() stores them that way).
925 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
927 return (float*)get_cat_var_data( n, var_count );
// Cross-validation fold labels (slot var_count+1), or 0 when pruning via
// cross-validation is disabled.
931 int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
933 return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
// Pointer to node n's compact labels for categorical variable vi (also used
// for responses / cv labels via the extra var_type slots). Categorical rows
// are laid out after all ordered-variable pair rows in the node's buffer.
937 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
939 int ci = get_var_type(vi);
940 assert( 0 <= ci && ci <= cat_var_count + 1 );
941 return buf->data.i + n->buf_idx*buf->cols + n->offset +
942 (ord_var_count*2 + ci)*n->sample_count;
// Per-sample weights (slot after cv labels, if those exist), or 0 when the
// data was built without weights.
946 float* CvDTreeTrainData::get_weights( CvDTreeNode* n )
948 return have_weights ?
949 (float*)get_cat_var_data( n, var_count + 1 + (params.cv_folds > 0) ) : 0;
// Buffer row to use for n's children: cycle to the next row; when wrapping,
// shared data reserves row 0 for the pristine root copy, so wrap to 1.
953 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
955 int idx = n->buf_idx + 1;
956 if( idx >= buf_count )
957 idx = shared ? 1 : 0;
962 /////////////////////// Decision Tree /////////////////////////
// Releases the tree's own results (variable importance) and resets the
// pruning state; -1 means "no pruned sequence selected".
972 void CvDTree::clear()
974 cvReleaseMat( &var_importance );
984 pruned_tree_idx = -1;
// Trains a single tree from raw arrays: builds a private (non-shared)
// CvDTreeTrainData from the inputs and delegates to do_train() on the full
// sample set (subsample index 0 = all samples). Returns true on success.
994 bool CvDTree::train( const CvMat* _train_data, int _tflag,
995 const CvMat* _responses, const CvMat* _var_idx,
996 const CvMat* _sample_idx, const CvMat* _var_type,
997 const CvMat* _missing_mask, CvDTreeParams _params )
1001 CV_FUNCNAME( "CvDTree::train" );
// _shared = false: this tree owns the training-data object
1006 data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1007 _var_idx, _sample_idx, _var_type,
1008 _missing_mask, _params, false );
1009 CV_CALL( result = do_train(0));
// Trains from an already-prepared, shared CvDTreeTrainData (ensemble use):
// marks the data shared and trains on the given sample subset.
1017 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1019 bool result = false;
1021 CV_FUNCNAME( "CvDTree::train" );
1027 data->shared = true;
1028 CV_CALL( result = do_train(_subsample_idx));
// Core training driver: creates the root over the requested sample subset,
// grows the tree recursively, prunes via cross-validation if configured,
// then drops training-only buffers. Returns true on success.
1036 bool CvDTree::do_train( const CvMat* _subsample_idx )
1038 bool result = false;
1040 CV_FUNCNAME( "CvDTree::train" );
1044 root = data->subsample_data( _subsample_idx );
1046 try_split_node(root);
1048 if( data->params.cv_folds > 0 )
1052 data->free_train_data();
// Direction of category `idx` under a categorical split's bitmask `subset`:
// evaluates to -1 (left) when the category's bit is set, +1 (right) when it
// is clear. 32 categories per int; idx>>5 selects the word, idx&31 the bit.
1062 #define DTREE_CAT_DIR(idx,subset) \
1063 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
// Recursively grows the tree from `node`: computes the node value, checks the
// stopping criteria (min sample count, max depth, purity / regression range),
// finds the best split, optionally collects surrogate splits sorted by
// quality, then splits the data and recurses into both children.
1065 void CvDTree::try_split_node( CvDTreeNode* node )
1067 CvDTreeSplit* best_split = 0;
1068 int i, n = node->sample_count, vi;
1069 bool can_split = true;
1070 double quality_scale;
1072 calc_node_value( node );
// stop criterion 1: node too small or too deep
1074 if( node->sample_count <= data->params.min_sample_count ||
1075 node->depth >= data->params.max_depth )
1078 if( can_split && data->is_classifier )
1080 // check if we have a "pure" node,
1081 // we assume that cls_count is filled by calc_node_value()
1082 int* cls_count = data->counts->data.i;
1083 int nz = 0, m = data->get_num_classes();
1084 for( i = 0; i < m; i++ )
1085 nz += cls_count[i] != 0;
1086 if( nz == 1 ) // there is only one class
// regression analogue of purity: response spread below the accuracy threshold
1089 else if( can_split )
1091 const float* responses = data->get_ord_responses( node );
1092 float diff = responses[n-1] - responses[0];
1093 if( diff < data->params.regression_accuracy )
1099 best_split = find_best_split(node);
1100 // TODO: check the split quality ...
1101 node->split = best_split;
// no usable split found -> make this node a leaf and free its data
1104 if( !can_split || !best_split )
1106 data->free_node_data(node);
// compute per-sample directions; also yields the surrogate quality scale
1110 quality_scale = calc_node_dir( node );
1112 if( data->params.use_surrogates )
1114 // find all the surrogate splits
1115 // and sort them by their similarity to the primary one
1116 for( vi = 0; vi < data->var_count; vi++ )
1118 CvDTreeSplit* split;
1119 int ci = data->get_var_type(vi);
// skip the variable used by the primary split
1121 if( vi == best_split->var_idx )
1125 split = find_surrogate_split_cat( node, vi );
1127 split = find_surrogate_split_ord( node, vi );
// insert into the split list, keeping it sorted by descending quality
1132 CvDTreeSplit* prev_split = node->split;
1133 split->quality = (float)(split->quality*quality_scale);
1135 while( prev_split->next &&
1136 prev_split->next->quality > split->quality )
1137 prev_split = prev_split->next;
1138 split->next = prev_split->next;
1139 prev_split->next = split;
1144 split_node_data( node );
1145 try_split_node( node->left );
1146 try_split_node( node->right );
1150 // calculate direction (left(-1),right(1),missing(0))
1151 // for each sample using the best split
1152 // the function returns scale coefficients for surrogate split quality factors.
1153 // the scale is applied to normalize surrogate split quality relatively to the
1154 // best (primary) split quality. That is, if a surrogate split is absolutely
1155 // identical to the primary split, its quality will be set to the maximum value =
1156 // quality of the primary split; otherwise, it will be lower.
1157 // besides, the function compute node->maxlr,
1158 // minimum possible quality (w/o considering the above mentioned scale)
1159 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1160 // are not discarded.
// Fills data->direction with -1/+1/0 (left/right/missing) per sample using
// the node's primary split, accumulates the (possibly prior-weighted) left
// and right masses L and R, stores node->maxlr = max(L,R), and returns
// quality/(L+R), the scale applied to surrogate split qualities.
// (See the block comment above for the full rationale.)
1161 double CvDTree::calc_node_dir( CvDTreeNode* node )
1163 char* dir = (char*)data->direction->data.ptr;
1164 int i, n = node->sample_count, vi = node->split->var_idx;
1167 assert( !node->split->inversed );
1169 if( data->get_var_type(vi) >= 0 ) // split on categorical var
1171 const int* labels = data->get_cat_var_data(node,vi);
1172 const int* subset = node->split->subset;
1174 if( !data->have_priors )
1176 int sum = 0, sum_abs = 0;
1178 for( i = 0; i < n; i++ )
// d is -1/+1 by the split's category mask; 0 for missing (idx < 0)
1180 int idx = labels[i];
1181 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
// d & 1 is 1 for d = +/-1, 0 for d = 0 -> counts non-missing samples
1182 sum += d; sum_abs += d & 1;
// recover counts: R = #(d=+1), L = #(d=-1)
1186 R = (sum_abs + sum) >> 1;
1187 L = (sum_abs - sum) >> 1;
// weighted variant: same accumulation, each sample scaled by its class prior
1191 const int* responses = data->get_class_labels(node);
1192 const double* priors = data->priors->data.db;
1193 double sum = 0, sum_abs = 0;
1195 for( i = 0; i < n; i++ )
1197 int idx = labels[i];
1198 double w = priors[responses[i]];
1199 int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
1200 sum += d*w; sum_abs += (d & 1)*w;
1204 R = (sum_abs + sum) * 0.5;
1205 L = (sum_abs - sum) * 0.5;
1208 else // split on ordered var
1210 const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
1211 int split_point = node->split->ord.split_point;
// n1 = number of samples with a valid (non-missing) value of vi
1212 int n1 = node->get_num_valid(vi);
1214 assert( 0 <= split_point && split_point < n1-1 );
1216 if( !data->have_priors )
// first split_point+1 sorted samples go left, the rest of the valid go right
1218 for( i = 0; i <= split_point; i++ )
1219 dir[sorted[i].i] = (char)-1;
1220 for( ; i < n1; i++ )
1221 dir[sorted[i].i] = (char)1;
// samples with missing vi get direction 0
1223 dir[sorted[i].i] = (char)0;
1226 R = n1 - split_point + 1;
// prior-weighted variant of the same three passes
1230 const int* responses = data->get_class_labels(node);
1231 const double* priors = data->priors->data.db;
1234 for( i = 0; i <= split_point; i++ )
1236 int idx = sorted[i].i;
1237 double w = priors[responses[idx]];
1238 dir[idx] = (char)-1;
1242 for( ; i < n1; i++ )
1244 int idx = sorted[i].i;
1245 double w = priors[responses[idx]];
1251 dir[sorted[i].i] = (char)0;
1255 node->maxlr = MAX( L, R );
1256 return node->split->quality/(L + R);
// Tries every usable variable and returns the highest-quality split (or 0 if
// none). Dispatches on variable type (categorical/ordered) and task
// (classification/regression); losing candidate splits are returned to the
// split heap immediately.
1260 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1263 CvDTreeSplit *best_split = 0, *split = 0, *t;
1265 for( vi = 0; vi < data->var_count; vi++ )
1267 int ci = data->get_var_type(vi);
// a variable with <= 1 valid value cannot separate anything
1268 if( node->get_num_valid(vi) <= 1 )
1271 if( data->is_classifier )
1274 split = find_split_cat_class( node, vi );
1276 split = find_split_ord_class( node, vi );
1281 split = find_split_cat_reg( node, vi );
1283 split = find_split_ord_reg( node, vi );
// keep the better of (best_split, split); free the loser
1288 if( !best_split || best_split->quality < split->quality )
1289 CV_SWAP( best_split, split, t );
1291 cvSetRemoveByPtr( data->split_heap, split );
// Best threshold split on ordered variable vi for classification. Sweeps the
// value-sorted samples left-to-right, incrementally maintaining per-class
// counts (lc/rc) and the sums of squared counts (lsum2/rsum2) so each
// candidate's Gini-style score lsum2/L + rsum2/R is O(1) to update. A
// threshold is only legal between two strictly different values (epsilon
// comparison). Returns a new split or 0 if no position improves the score.
1299 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
1301 const float epsilon = FLT_EPSILON*2;
1302 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1303 const int* responses = data->get_class_labels(node);
1304 int n = node->sample_count;
1305 int n1 = node->get_num_valid(vi);
1306 int m = data->get_num_classes();
// class-count scratch rows live in data->counts; lc follows rc0
1307 const int* rc0 = data->counts->data.i;
1308 int* lc = (int*)(rc0 + m);
1311 double lsum2 = 0, rsum2 = 0, best_val = 0;
1312 const double* priors = data->have_priors ? data->priors->data.db : 0;
1314 // init arrays of class instance counters on both sides of the split
1315 for( i = 0; i < m; i++ )
1321 // compensate for missing values
// (rows n1..n-1 of `sorted` are the samples with missing vi)
1322 for( i = n1; i < n; i++ )
1323 rc[responses[sorted[i].i]]--;
// unweighted branch: everything starts on the right side
1329 for( i = 0; i < m; i++ )
1330 rsum2 += (double)rc[i]*rc[i];
1332 for( i = 0; i < n1 - 1; i++ )
1334 int idx = responses[sorted[i].i];
// move one sample from right to left; update squared sums incrementally:
// (c+1)^2 = c^2 + 2c + 1 and (c-1)^2 = c^2 - 2c + 1
1337 lv = lc[idx]; rv = rc[idx];
1340 lc[idx] = lv + 1; rc[idx] = rv - 1;
// candidate threshold valid only between distinct values
1342 if( sorted[i].val + epsilon < sorted[i+1].val )
1344 double val = lsum2/L + rsum2/R;
1345 if( best_val < val )
// prior-weighted branch: same sweep with per-class weights p = priors[k]
1355 double L = 0, R = 0;
1356 for( i = 0; i < m; i++ )
1358 double wv = rc[i]*priors[i];
1363 for( i = 0; i < n1 - 1; i++ )
1365 int idx = responses[sorted[i].i];
1367 double p = priors[idx], p2 = p*p;
1369 lv = lc[idx]; rv = rc[idx];
1370 lsum2 += p2*(lv*2 + 1);
1371 rsum2 -= p2*(rv*2 - 1);
1372 lc[idx] = lv + 1; rc[idx] = rv - 1;
1374 if( sorted[i].val + epsilon < sorted[i+1].val )
1376 double val = lsum2/L + rsum2/R;
1377 if( best_val < val )
// threshold is the midpoint between the two neighboring distinct values
1386 return best_i >= 0 ? data->new_split_ord( vi,
1387 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1388 0, (float)best_val ) : 0;
// Weighted k-means clustering of category histograms, used to reduce a
// categorical variable with many categories down to k <= max_categories
// pseudo-categories before split search.
//  vectors - n category rows, each an m-bin class histogram,
//  csums   - k x m output cluster centers (integer sums),
//  labels  - n output cluster assignments.
// Rows/clusters are compared via normalized histograms (v_weights/c_weights
// are reciprocal row sums). Iterates until assignments stop changing or
// max_iters is reached.
1392 void CvDTree::cluster_categories( const int* vectors, int n, int m,
1393 int* csums, int k, int* labels )
1395 // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
1396 int iters = 0, max_iters = 100;
1398 double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
1399 double *v_weights = buf, *c_weights = buf + k;
1400 bool modified = true;
1401 CvRNG* r = &data->rng;
1403 // assign labels randomly
// idx cycles 0..k-1 (branch-free wrap below)
1404 for( i = idx = 0; i < n; i++ )
1407 const int* v = vectors + i*m;
1409 idx &= idx < k ? -1 : 0;
1411 // compute weight of each vector
// weight = 1 / (total samples in the row); empty rows get weight 0
1412 for( j = 0; j < m; j++ )
1414 v_weights[i] = sum ? 1./sum : 0.;
// shuffle the initial assignment with random transpositions
1417 for( i = 0; i < n; i++ )
1419 int i1 = cvRandInt(r) % n;
1420 int i2 = cvRandInt(r) % n;
1421 CV_SWAP( labels[i1], labels[i2], j );
1424 for( iters = 0; iters <= max_iters; iters++ )
// recompute cluster sums from the current assignment
1427 for( i = 0; i < k; i++ )
1429 for( j = 0; j < m; j++ )
1433 for( i = 0; i < n; i++ )
1435 const int* v = vectors + i*m;
1436 int* s = csums + labels[i]*m;
1437 for( j = 0; j < m; j++ )
1441 // exit the loop here, when we have up-to-date csums
1442 if( iters == max_iters || !modified )
1447 // calculate weight of each cluster
1448 for( i = 0; i < k; i++ )
1450 const int* s = csums + i*m;
1452 for( j = 0; j < m; j++ )
1454 c_weights[i] = sum ? 1./sum : 0;
1457 // now for each vector determine the closest cluster
1458 for( i = 0; i < n; i++ )
1460 const int* v = vectors + i*m;
1461 double alpha = v_weights[i];
1462 double min_dist2 = DBL_MAX;
1465 for( idx = 0; idx < k; idx++ )
1467 const int* s = csums + idx*m;
// squared distance between the two normalized histograms
1468 double dist2 = 0., beta = c_weights[idx];
1469 for( j = 0; j < m; j++ )
1471 double t = v[j]*alpha - s[j]*beta;
1474 if( min_dist2 > dist2 )
// reassign; `modified` stays true while any label changes
1481 if( min_idx != labels[i] )
1483 labels[i] = min_idx;
// Finds the best split of a categorical variable vi for a classification
// node: builds the category-x-class contingency table c_{jk}, optionally
// compresses categories via cluster_categories() when there are more than
// params.max_categories of them, then searches category subsets.
// For 2-class problems it sorts categories and scans prefixes; otherwise it
// enumerates subsets in Gray-code order so each step updates one category.
// Returns a new CvDTreeSplit, or 0 if no useful split was found.
// NOTE(review): interior lines are elided in this view; comments describe
// only the visible code.
1489 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
1491 CvDTreeSplit* split;
1492 const int* labels = data->get_cat_var_data(node, vi);
1493 const int* responses = data->get_class_labels(node);
1494 int ci = data->get_var_type(vi);
1495 int n = node->sample_count;
1496 int m = data->get_num_classes();
// _mi is the true category count; mi may be reduced by clustering below.
1497 int _mi = data->cat_count->data.i[ci], mi = _mi;
1498 const int* rc0 = data->counts->data.i;
1499 int* lc = (int*)(rc0 + m);
1501 int* _cjk = rc + m*2, *cjk = _cjk;
1502 double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
1503 int* cluster_labels = 0;
1506 double L = 0, R = 0;
1507 double best_val = 0;
1508 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
1509 const double* priors = data->priors->data.db;
1511 // init array of counters:
1512 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
// j starts from -1: the extra row in front collects samples with missing values.
1513 for( j = -1; j < mi; j++ )
1514 for( k = 0; k < m; k++ )
1517 for( i = 0; i < n; i++ )
1520 int k = responses[i];
// Too many categories: cluster them down to mi <= max_categories groups and
// work with the clustered contingency table instead.
1526 if( mi > data->params.max_categories )
1528 mi = MIN(data->params.max_categories, n);
1530 cluster_labels = cjk + mi*m;
1531 cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
// 2-class case: sort categories (by pointer into cjk; each entry points at
// the class-1 count) so a prefix scan over the sorted order is optimal.
1539 int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
1540 for( j = 0; j < mi; j++ )
1541 int_ptr[j] = cjk + j*2 + 1;
1542 icvSortIntPtr( int_ptr, mi, 0 );
1547 for( k = 0; k < m; k++ )
1550 for( j = 0; j < mi; j++ )
1551 sum += cjk[j*m + k];
// Prior-weighted total per category, used as the category's weight.
1556 for( j = 0; j < mi; j++ )
1559 for( k = 0; k < m; k++ )
1560 sum += cjk[j*m + k]*priors[k];
// Subset enumeration: subset_i indexes either sorted prefixes (2-class)
// or Gray codes (multi-class).
1565 for( ; subset_i < subset_n; subset_i++ )
1569 double lsum2 = 0, rsum2 = 0;
// 2-class branch: recover the category index from the sorted pointer.
1572 idx = (int)(int_ptr[subset_i] - cjk)/2;
// Multi-class branch: consecutive Gray codes differ in exactly one bit,
// so only one category moves between the left/right sides per step.
1575 int graycode = (subset_i>>1)^subset_i;
1576 int diff = graycode ^ prevcode;
1578 // determine index of the changed bit.
// Bit-index via float trick: load the isolated diff into a float and read
// the exponent field (u.i >> 23) to get log2 of the changed bit.
1580 idx = diff >= (1 << 16) ? 16 : 0;
1581 u.f = (float)(((diff >> 16) | diff) & 65535);
1582 idx += (u.i >> 23) - 127;
1583 subtract = graycode < prevcode;
1584 prevcode = graycode;
// Skip categories with (near) zero weight — they cannot affect the split.
1588 weight = c_weights[idx];
1589 if( weight < FLT_EPSILON )
// Move category idx's class counts from right to left (t added to left) ...
1594 for( k = 0; k < m; k++ )
1597 int lval = lc[k] + t;
1598 int rval = rc[k] - t;
1599 double p = priors[k], p2 = p*p;
1600 lsum2 += p2*lval*lval;
1601 rsum2 += p2*rval*rval;
1602 lc[k] = lval; rc[k] = rval;
// ... or back from left to right when the Gray code step subtracts.
1609 for( k = 0; k < m; k++ )
1612 int lval = lc[k] - t;
1613 int rval = rc[k] + t;
1614 double p = priors[k], p2 = p*p;
1615 lsum2 += p2*lval*lval;
1616 rsum2 += p2*rval*rval;
1617 lc[k] = lval; rc[k] = rval;
// Gini-style quality: sum of squared class weights over each side's weight.
1623 if( L > FLT_EPSILON && R > FLT_EPSILON )
1625 double val = lsum2/L + rsum2/R;
1626 if( best_val < val )
1629 best_subset = subset_i;
1634 if( best_subset < 0 )
1637 split = data->new_split_cat( vi, (float)best_val );
// 2-class: mark the best prefix of the sorted categories as "left".
1641 for( i = 0; i <= best_subset; i++ )
1643 idx = (int)(int_ptr[i] - cjk) >> 1;
1644 split->subset[idx >> 5] |= 1 << (idx & 31);
// Multi-class: expand the winning bit subset back to original categories
// (through cluster_labels when clustering was applied).
1649 for( i = 0; i < _mi; i++ )
1651 idx = cluster_labels ? cluster_labels[i] : i;
1652 if( best_subset & (1 << idx) )
1653 split->subset[i >> 5] |= 1 << (i & 31);
// Finds the best threshold split of ordered variable vi for a regression
// node by scanning the pre-sorted values and maximizing
// lsum^2/L + rsum^2/R (equivalent to minimizing within-side squared error).
// Returns a new CvDTreeSplit or 0 when no valid split exists.
// NOTE(review): interior lines are elided in this view.
1661 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
1663 const float epsilon = FLT_EPSILON*2;
1664 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
1665 const float* responses = data->get_ord_responses(node);
1666 int n = node->sample_count;
// n1 = samples with a valid (non-missing) value of vi; they come first in `sorted`.
1667 int n1 = node->get_num_valid(vi);
// rsum starts as the node total (value*n) and is corrected below.
1669 double best_val = 0, lsum = 0, rsum = node->value*n;
1672 // compensate for missing values
1673 for( i = n1; i < n; i++ )
1674 rsum -= responses[sorted[i].i];
1676 // find the optimal split
1677 for( i = 0; i < n1 - 1; i++ )
1679 float val = responses[sorted[i].i];
// Only consider a threshold between two strictly different values.
1684 if( sorted[i].val + epsilon < sorted[i+1].val )
1686 double val = lsum*lsum/L + rsum*rsum/R;
1687 if( best_val < val )
// Threshold is placed midway between the two neighboring values.
1695 return best_i >= 0 ? data->new_split_ord( vi,
1696 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1697 0, (float)best_val ) : 0;
// Finds the best split of categorical variable vi for a regression node.
// Categories are ordered by their mean response (a classical optimality
// result for regression on categorical predictors), then prefixes of that
// order are scanned for the best lsum^2/L + rsum^2/R value.
// Returns a new CvDTreeSplit or 0 when no useful split exists.
// NOTE(review): interior lines are elided in this view.
1701 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
1703 CvDTreeSplit* split;
1704 const int* labels = data->get_cat_var_data(node, vi);
1705 const float* responses = data->get_ord_responses(node);
1706 int ci = data->get_var_type(vi);
1707 int n = node->sample_count;
1708 int mi = data->cat_count->data.i[ci];
// The +1 / "+ 1" shift reserves slot [-1] for samples with missing values.
1709 double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
1710 int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
1711 double** sum_ptr = 0;
1712 int i, L = 0, R = 0;
1713 double best_val = 0, lsum = 0, rsum = 0;
1714 int best_subset = -1, subset_i;
1716 for( i = -1; i < mi; i++ )
1717 sum[i] = counts[i] = 0;
1719 // calculate sum response and weight of each category of the input var
1720 for( i = 0; i < n; i++ )
1722 int idx = labels[i];
1723 double s = sum[idx] + responses[i];
1724 int nc = counts[idx] + 1;
1729 // calculate average response in each category
1730 for( i = 0; i < mi; i++ )
1734 sum[i] /= MAX(counts[i],1);
1735 sum_ptr[i] = sum + i;
// Sort category pointers by mean response.
1738 icvSortDblPtr( sum_ptr, mi, 0 );
1740 // revert back to unnormalized sum
1741 // (there should be a very little loss of accuracy)
1742 for( i = 0; i < mi; i++ )
1743 sum[i] *= counts[i];
// Scan prefixes of the mean-sorted category order.
1745 for( subset_i = 0; subset_i < mi-1; subset_i++ )
1747 int idx = (int)(sum_ptr[subset_i] - sum);
1748 int ni = counts[idx];
1752 double s = sum[idx];
1758 double val = lsum*lsum/L + rsum*rsum/R;
1759 if( best_val < val )
1762 best_subset = subset_i;
1768 if( best_subset < 0 )
1771 split = data->new_split_cat( vi, (float)best_val );
// Mark the winning prefix's categories in the split's subset bitmask.
1772 for( i = 0; i <= best_subset; i++ )
1774 int idx = (int)(sum_ptr[i] - sum);
1775 split->subset[idx >> 5] |= 1 << (idx & 31);
// Finds the best ordered-variable surrogate split for a node: the threshold
// on vi that best predicts the direction (left/right) assigned by the
// node's primary split. Two code paths: an integer one (no priors) and a
// weighted one (with class priors). A surrogate may also be "inversed"
// (agree with the primary split when sending samples the opposite way).
// NOTE(review): interior lines are elided in this view.
1782 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
1784 const float epsilon = FLT_EPSILON*2;
1785 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
// dir[i] in {-1,0,1}: primary-split direction per sample (0 = undefined).
1786 const char* dir = (char*)data->direction->data.ptr;
1787 int n1 = node->get_num_valid(vi);
1788 // LL - number of samples that both the primary and the surrogate splits send to the left
1789 // LR - ... primary split sends to the left and the surrogate split sends to the right
1790 // RL - ... primary split sends to the right and the surrogate split sends to the left
1791 // RR - ... both send to the right
1792 int i, best_i = -1, best_inversed = 0;
1795 if( !data->have_priors )
1797 int LL = 0, RL = 0, LR, RR;
// A surrogate is only useful if it beats the default direction (maxlr).
1798 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
1799 int sum = 0, sum_abs = 0;
1801 for( i = 0; i < n1; i++ )
1803 int d = dir[sorted[i].i];
// d & 1 is 1 for d = +/-1 and 0 for d = 0, so sum_abs counts defined samples.
1804 sum += d; sum_abs += d & 1;
1807 // sum_abs = R + L; sum = R - L
1808 RR = (sum_abs + sum) >> 1;
1809 LR = (sum_abs - sum) >> 1;
1811 // initially all the samples are sent to the right by the surrogate split,
1812 // LR of them are sent to the left by primary split, and RR - to the right.
1813 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1814 for( i = 0; i < n1 - 1; i++ )
1816 int d = dir[sorted[i].i];
// Agreement (LL+RR) — straight surrogate candidate.
1821 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1824 best_i = i; best_inversed = 0;
// Disagreement (RL+LR) — inversed surrogate candidate.
1830 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
1833 best_i = i; best_inversed = 1;
1837 best_val = _best_val;
// Weighted variant: identical structure, but every sample contributes its
// class-prior weight instead of 1.
1841 double LL = 0, RL = 0, LR, RR;
1842 double worst_val = node->maxlr;
1843 double sum = 0, sum_abs = 0;
1844 const double* priors = data->priors->data.db;
1845 const int* responses = data->get_class_labels(node);
1846 best_val = worst_val;
1848 for( i = 0; i < n1; i++ )
1850 int idx = sorted[i].i;
1851 double w = priors[responses[idx]];
1853 sum += d*w; sum_abs += (d & 1)*w;
1856 // sum_abs = R + L; sum = R - L
1857 RR = (sum_abs + sum)*0.5;
1858 LR = (sum_abs - sum)*0.5;
1860 // initially all the samples are sent to the right by the surrogate split,
1861 // LR of them are sent to the left by primary split, and RR - to the right.
1862 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
1863 for( i = 0; i < n1 - 1; i++ )
1865 int idx = sorted[i].i;
1866 double w = priors[responses[idx]];
1872 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1875 best_i = i; best_inversed = 0;
1881 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
1884 best_i = i; best_inversed = 1;
1890 return best_i >= 0 ? data->new_split_ord( vi,
1891 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
1892 best_inversed, (float)best_val ) : 0;
// Finds the best categorical surrogate split: for each category of vi,
// counts (or prior-weights) how many samples the primary split sends left
// vs. right, then sends each whole category in its majority direction.
// The split is discarded (returns 0) unless its quality exceeds the
// default-direction baseline node->maxlr.
// NOTE(review): interior lines are elided in this view.
1896 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
1898 const int* labels = data->get_cat_var_data(node, vi);
1899 const char* dir = (char*)data->direction->data.ptr;
1900 int n = node->sample_count;
1901 // LL - number of samples that both the primary and the surrogate splits send to the left
1902 // LR - ... primary split sends to the left and the surrogate split sends to the right
1903 // RL - ... primary split sends to the right and the surrogate split sends to the left
1904 // RR - ... both send to the right
1905 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
1906 int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
1907 double best_val = 0;
// lc/rc are shifted by 1 so index -1 absorbs samples with missing category.
1908 double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
1909 double* rc = lc + mi + 1;
1911 for( i = -1; i < mi; i++ )
1914 // for each category calculate the weight of samples
1915 // sent to the left (lc) and to the right (rc) by the primary split
1916 if( !data->have_priors )
// Integer path: _lc accumulates sum(d) (= R - L), _rc accumulates
// sum(d & 1) (= R + L); the two are decoded into L/R counts below.
1918 int* _lc = data->counts->data.i + 1;
1919 int* _rc = _lc + mi + 1;
1921 for( i = -1; i < mi; i++ )
1922 _lc[i] = _rc[i] = 0;
1924 for( i = 0; i < n; i++ )
1926 int idx = labels[i];
1928 int sum = _lc[idx] + d;
1929 int sum_abs = _rc[idx] + (d & 1);
1930 _lc[idx] = sum; _rc[idx] = sum_abs;
1933 for( i = 0; i < mi; i++ )
1936 int sum_abs = _rc[i];
// L = (|.|-sum)/2, R = (|.|+sum)/2 — same decoding as the ordered variant.
1937 lc[i] = (sum_abs - sum) >> 1;
1938 rc[i] = (sum_abs + sum) >> 1;
// Weighted path: identical bookkeeping with per-sample prior weights.
1943 const double* priors = data->priors->data.db;
1944 const int* responses = data->get_class_labels(node);
1946 for( i = 0; i < n; i++ )
1948 int idx = labels[i];
1949 double w = priors[responses[i]];
1951 double sum = lc[idx] + d*w;
1952 double sum_abs = rc[idx] + (d & 1)*w;
1953 lc[idx] = sum; rc[idx] = sum_abs;
1956 for( i = 0; i < mi; i++ )
1959 double sum_abs = rc[i];
1960 lc[i] = (sum_abs - sum) * 0.5;
1961 rc[i] = (sum_abs + sum) * 0.5;
1965 // 2. now form the split.
1966 // in each category send all the samples to the same direction as majority
1967 for( i = 0; i < mi; i++ )
1969 double lval = lc[i], rval = rc[i];
1972 split->subset[i >> 5] |= 1 << (i & 31);
1979 split->quality = (float)best_val;
// Not better than the default direction: release the split and return 0.
1980 if( split->quality <= node->maxlr )
1981 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
// Computes the node's value, risk, and (when cv_folds > 1) the per-fold
// cross-validation values/risks/errors used later by prune_cv().
// Classification: value = majority class (prior-weighted); risk = weighted
// misclassified count. Regression: value = mean response; risk = sum of
// squared errors. The detailed per-fold formulas are in the inline comments
// preserved below.
// NOTE(review): interior lines are elided in this view.
1987 void CvDTree::calc_node_value( CvDTreeNode* node )
1989 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
1990 const int* cv_labels = data->get_cv_labels(node);
1992 if( data->is_classifier )
1994 // in case of classification tree:
1995 // * node value is the label of the class that has the largest weight in the node.
1996 // * node risk is the weighted number of misclassified samples,
1997 // * j-th cross-validation fold value and risk are calculated as above,
1998 // but using the samples with cv_labels(*)!=j.
1999 // * j-th cross-validation fold error is calculated as the weighted number of
2000 // misclassified samples with cv_labels(*)==j.
2002 // compute the number of instances of each class
2003 int* cls_count = data->counts->data.i;
2004 const int* responses = data->get_class_labels(node);
2005 int m = data->get_num_classes();
2006 int* cv_cls_count = cls_count + m;
2007 double max_val = -1, total_weight = 0;
2009 double* priors = data->priors->data.db;
2011 for( k = 0; k < m; k++ )
2016 for( i = 0; i < n; i++ )
2017 cls_count[responses[i]]++;
// With cross-validation: build per-fold class histograms, then derive the
// overall histogram as their sum.
2021 for( j = 0; j < cv_n; j++ )
2022 for( k = 0; k < m; k++ )
2023 cv_cls_count[j*m + k] = 0;
2025 for( i = 0; i < n; i++ )
2027 j = cv_labels[i]; k = responses[i];
2028 cv_cls_count[j*m + k]++;
2031 for( j = 0; j < cv_n; j++ )
2032 for( k = 0; k < m; k++ )
2033 cls_count[k] += cv_cls_count[j*m + k];
// Pick the class with the largest prior-weighted count.
2036 for( k = 0; k < m; k++ )
2038 double val = cls_count[k]*priors[k];
2039 total_weight += val;
2047 node->class_idx = max_k;
// Map the normalized class index back to the original class label.
2048 node->value = data->cat_map->data.i[
2049 data->cat_ofs->data.i[data->cat_var_count] + max_k];
2050 node->node_risk = total_weight - max_val;
// Per-fold value/risk/error: the fold's "training" histogram is the node
// histogram minus the fold's own samples (val = cls_count - val_k below).
2052 for( j = 0; j < cv_n; j++ )
2054 double sum_k = 0, sum = 0, max_val_k = 0;
2055 max_val = -1; max_k = -1;
2057 for( k = 0; k < m; k++ )
2059 double w = priors[k];
2060 double val_k = cv_cls_count[j*m + k]*w;
2061 double val = cls_count[k]*w - val_k;
2072 node->cv_Tn[j] = INT_MAX;
2073 node->cv_node_risk[j] = sum - max_val;
2074 node->cv_node_error[j] = sum_k - max_val_k;
2079 // in case of regression tree:
2080 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2081 // n is the number of samples in the node.
2082 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2083 // * j-th cross-validation fold value and risk are calculated as above,
2084 // but using the samples with cv_labels(*)!=j.
2085 // * j-th cross-validation fold error is calculated
2086 // using samples with cv_labels(*)==j as the test subset:
2087 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2088 // where node_value_j is the node value calculated
2089 // as described in the previous bullet, and summation is done
2090 // over the samples with cv_labels(*)==j.
2092 double sum = 0, sum2 = 0;
2093 const float* values = data->get_ord_responses(node);
2094 double *cv_sum = 0, *cv_sum2 = 0;
2099 // if cross-validation is not used, we even do not compute node_risk
2100 // (so the tree sequence T1>...>{root} may not be built).
2101 for( i = 0; i < n; i++ )
// With cross-validation: accumulate per-fold sum, sum of squares, and count.
2106 cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2107 cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2108 cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2110 for( j = 0; j < cv_n; j++ )
2112 cv_sum[j] = cv_sum2[j] = 0.;
2116 for( i = 0; i < n; i++ )
2119 double t = values[i];
2120 double s = cv_sum[j] + t;
2121 double s2 = cv_sum2[j] + t*t;
2122 int nc = cv_count[j] + 1;
2128 for( j = 0; j < cv_n; j++ )
// SSE via the identity sum((y - mean)^2) = sum(y^2) - mean*sum(y).
2134 node->node_risk = sum2 - (sum/n)*sum;
2137 node->value = sum/n;
2139 for( j = 0; j < cv_n; j++ )
// For fold j: s/s2/c are the fold's own stats; si/s2i/ci the complement's.
2141 double s = cv_sum[j], si = sum - s;
2142 double s2 = cv_sum2[j], s2i = sum2 - s2;
2143 int c = cv_count[j], ci = n - c;
// r is the fold-j node value (mean over the complement samples).
2144 double r = si/MAX(ci,1);
2145 node->cv_node_risk[j] = s2i - r*r*ci;
2146 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2147 node->cv_Tn[j] = INT_MAX;
// Assigns a direction to every sample in the node: samples whose primary
// split variable is missing are first routed via surrogate splits (if
// enabled), and any still-undirected samples get the node's default
// direction (the side with the larger sample count).
// NOTE(review): interior lines are elided in this view.
2153 void CvDTree::complete_node_dir( CvDTreeNode* node )
2155 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
// nz = number of samples with a missing value of the primary split variable.
2156 int nz = n - node->get_num_valid(node->split->var_idx);
2157 char* dir = (char*)data->direction->data.ptr;
2159 // try to complete direction using surrogate splits
2160 if( nz && data->params.use_surrogates )
// Walk surrogate splits in priority order until all directions are filled.
2162 CvDTreeSplit* split = node->split->next;
2163 for( ; split != 0 && nz; split = split->next )
2165 int inversed_mask = split->inversed ? -1 : 0;
2166 vi = split->var_idx;
2168 if( data->get_var_type(vi) >= 0 ) // split on categorical var
2170 const int* labels = data->get_cat_var_data(node, vi);
2171 const int* subset = split->subset;
2173 for( i = 0; i < n; i++ )
// Only fill in samples that are still undirected and have this var present.
2176 if( !dir[i] && (idx = labels[i]) >= 0 )
2178 int d = DTREE_CAT_DIR(idx,subset);
// Conditionally negate d when the surrogate is inversed (branchless).
2179 dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2185 else // split on ordered var
2187 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2188 int split_point = split->ord.split_point;
2189 int n1 = node->get_num_valid(vi);
2191 assert( 0 <= split_point && split_point < n-1 );
2193 for( i = 0; i < n1; i++ )
2195 int idx = sorted[i].i;
2198 int d = i <= split_point ? -1 : 1;
2199 dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2208 // find the default direction for the rest
2211 for( i = nr = 0; i < n; i++ )
// d0: -1 if left majority, +1 if right majority, 0 on a tie.
2214 d0 = nl > nr ? -1 : nr > nl;
2217 // make sure that every sample is directed either to the left or to the right
2218 for( i = 0; i < n; i++ )
2228 dir[i] = (char)d; // remap (-1,1) to (0,1)
// Physically splits the node's sample data into two child nodes according
// to the directions computed by complete_node_dir(): ordered variables are
// copied so each half stays sorted; categorical variables, responses and
// cv labels are relocated via the new_idx table. The parent's now-redundant
// data buffers are freed at the end.
// NOTE(review): interior lines are elided in this view.
2233 void CvDTree::split_node_data( CvDTreeNode* node )
2235 int vi, i, n = node->sample_count, nl, nr;
2236 char* dir = (char*)data->direction->data.ptr;
2237 CvDTreeNode *left = 0, *right = 0;
2238 int* new_idx = data->split_buf->data.i;
2239 int new_buf_idx = data->get_child_buf_idx( node );
2241 complete_node_dir(node);
2243 for( i = nl = nr = 0; i < n; i++ )
2246 // initialize new indices for splitting ordered variables
// Branchless select: new index is the running left count (d==0) or the
// running right count (d==1).
2247 new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
// Children share the new buffer; the right child's data starts after the
// left child's nl rows of every per-sample array.
2252 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2253 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2254 (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
2256 // split ordered variables, keep both halves sorted.
2257 for( vi = 0; vi < data->var_count; vi++ )
2259 int ci = data->get_var_type(vi);
2260 int n1 = node->get_num_valid(vi);
2261 CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
// tl/tr save the elements just past each destination range; they are
// restored at the end because both destinations may overrun by one slot.
2262 CvPair32s32f tl, tr;
2267 src = data->get_ord_var_data(node, vi);
2268 ldst0 = ldst = data->get_ord_var_data(left, vi);
2269 rdst0 = rdst = data->get_ord_var_data(right, vi);
2270 tl = ldst0[nl]; tr = rdst0[nr];
// Copy valid (sorted) entries; writing to both sides and advancing only
// one pointer keeps each half in sorted order.
2273 for( i = 0; i < n1; i++ )
2276 float val = src[i].val;
2279 ldst->i = rdst->i = idx;
2280 ldst->val = rdst->val = val;
2285 left->set_num_valid(vi, (int)(ldst - ldst0));
2286 right->set_num_valid(vi, (int)(rdst - rdst0));
// Remaining entries are the missing-value ones.
2294 ldst->i = rdst->i = idx;
2295 ldst->val = rdst->val = ord_nan;
2300 ldst0[nl] = tl; rdst0[nr] = tr;
2303 // split categorical vars, responses and cv_labels using new_idx relocation table
2304 for( vi = 0; vi <= data->var_count + data->have_cv_labels + data->have_weights; vi++ )
2306 int ci = data->get_var_type(vi);
2307 int n1 = node->get_num_valid(vi), nr1 = 0;
2308 int *src, *ldst0, *rdst0, *ldst, *rdst;
2314 src = data->get_cat_var_data(node, vi);
2315 ldst0 = ldst = data->get_cat_var_data(left, vi);
2316 rdst0 = rdst = data->get_cat_var_data(right, vi);
2317 tl = ldst0[nl]; tr = rdst0[nr];
2319 for( i = 0; i < n; i++ )
2323 *ldst = *rdst = val;
// Count valid values that went to the right child.
2326 nr1 += (val >= 0)&d;
2329 if( vi < data->var_count )
2331 left->set_num_valid(vi, n1 - nr1);
2332 right->set_num_valid(vi, nr1);
2335 ldst0[nl] = tl; rdst0[nr] = tr;
2338 // deallocate the parent node data that is not needed anymore
2339 data->free_node_data(node);
// Cost-complexity pruning with cross-validation (CART-style): builds the
// nested tree sequence for the full data and for each CV fold, evaluates
// each fold's error across the alpha grid, and selects the best tree index
// (optionally by the 1-SE rule), then cuts the pruned branches.
// NOTE(review): interior lines are elided in this view.
2343 void CvDTree::prune_cv()
2349 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2350 // 2. choose the best tree index (if need, apply 1SE rule).
2351 // 3. store the best index and cut the branches.
2353 CV_FUNCNAME( "CvDTree::prune_cv" );
2357 int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2358 // currently, 1SE for regression is not implemented
2359 bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2361 double min_err = 0, min_err_se = 0;
// ab: growable row vector of alpha values, one per tree in the sequence.
2364 CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2366 // build the main tree sequence, calculate alpha's
2369 double min_alpha = update_tree_rnc(tree_count, -1);
2370 if( cut_tree(tree_count, -1, min_alpha) )
// Grow ab by 1.5x when full, copying the old contents.
2373 if( ab->cols <= tree_count )
2375 CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2376 for( ti = 0; ti < ab->cols; ti++ )
2377 temp->data.db[ti] = ab->data.db[ti];
2378 cvReleaseMat( &ab );
2383 ab->data.db[tree_count] = min_alpha;
// Replace each alpha by the geometric mean of neighbors (standard CART
// trick to get representative alphas between breakpoints).
2386 ab->data.db[0] = 0.;
2387 for( ti = 1; ti < tree_count-1; ti++ )
2388 ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2389 ab->data.db[tree_count-1] = DBL_MAX*0.5;
2391 CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2392 err = err_jk->data.db;
// For each fold, walk its own tree sequence and record the error of the
// fold's tree matching each alpha in ab.
2394 for( j = 0; j < cv_n; j++ )
2397 for( ; tk < tree_count; tj++ )
2399 double min_alpha = update_tree_rnc(tj, j);
2400 if( cut_tree(tj, j, min_alpha) )
2401 min_alpha = DBL_MAX;
2403 for( ; tk < tree_count; tk++ )
2405 if( ab->data.db[tk] > min_alpha )
2407 err[j*tree_count + tk] = root->tree_error;
// Choose the tree index with the smallest summed CV error; with 1-SE,
// prefer the smallest tree within one standard error of the minimum.
2412 for( ti = 0; ti < tree_count; ti++ )
2415 for( j = 0; j < cv_n; j++ )
2416 sum_err += err[j*tree_count + ti];
2417 if( ti == 0 || sum_err < min_err )
2422 min_err_se = sqrt( sum_err*(n - sum_err) );
2424 else if( sum_err < min_err + min_err_se )
2428 pruned_tree_idx = min_idx;
2429 free_prune_data(data->params.truncate_pruned_tree != 0);
2433 cvReleaseMat( &err_jk );
2434 cvReleaseMat( &ab );
2435 cvReleaseMat( &temp );
// One pruning pass over tree T (of fold `fold`, or the main tree when
// fold < 0): recomputes each internal node's subtree risk, error, and
// complexity bottom-up (iterative depth-first traversal without recursion),
// computes alpha = (node risk - subtree risk)/(complexity - 1) per parent,
// and returns the minimum alpha — the next cost-complexity breakpoint.
// NOTE(review): interior lines are elided in this view.
2439 double CvDTree::update_tree_rnc( int T, int fold )
2441 CvDTreeNode* node = root;
2442 double min_alpha = DBL_MAX;
2446 CvDTreeNode* parent;
// Leaf test: node already pruned at or below T, or is a true leaf.
2449 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2450 if( t <= T || !node->left )
2452 node->complexity = 1;
2453 node->tree_risk = node->node_risk;
2454 node->tree_error = 0.;
// For a fold-specific pass, leaves report the fold's risk/error instead.
2457 node->tree_risk = node->cv_node_risk[fold];
2458 node->tree_error = node->cv_node_error[fold];
// Ascend while coming back from a right child, folding child stats into
// the parent and updating the parent's alpha.
2465 for( parent = node->parent; parent && parent->right == node;
2466 node = parent, parent = parent->parent )
2468 parent->complexity += node->complexity;
2469 parent->tree_risk += node->tree_risk;
2470 parent->tree_error += node->tree_error;
2472 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2473 - parent->tree_risk)/(parent->complexity - 1);
2474 min_alpha = MIN( min_alpha, parent->alpha );
// Coming back from a left child: seed the parent with the left subtree's
// stats, then descend into the right subtree.
2480 parent->complexity = node->complexity;
2481 parent->tree_risk = node->tree_risk;
2482 parent->tree_error = node->tree_error;
2483 node = parent->right;
// Prunes (for tree index T of fold `fold`) every node whose alpha is within
// FLT_EPSILON of min_alpha by stamping its cv_Tn (or Tn) with T, using the
// same non-recursive depth-first walk as update_tree_rnc. The return value
// is consumed by prune_cv() to detect when the sequence is exhausted.
// NOTE(review): interior lines are elided in this view; the exact return
// condition is not visible here — verify against the full source.
2490 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2492 CvDTreeNode* node = root;
2498 CvDTreeNode* parent;
2501 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2502 if( t <= T || !node->left )
// This node is the weakest link: mark it pruned at level T.
2504 if( node->alpha <= min_alpha + FLT_EPSILON )
2507 node->cv_Tn[fold] = T;
2517 for( parent = node->parent; parent && parent->right == node;
2518 node = parent, parent = parent->parent )
2524 node = parent->right;
// Releases per-node cross-validation arrays after pruning and, when
// cut_tree(=truncate) is requested, physically deletes the subtrees below
// nodes pruned at or below pruned_tree_idx. Finishes by clearing the whole
// cross-validation heap in one call.
// NOTE(review): interior lines are elided in this view.
2531 void CvDTree::free_prune_data(bool cut_tree)
2533 CvDTreeNode* node = root;
2537 CvDTreeNode* parent;
2540 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2541 // as we will clear the whole cross-validation heap at the end
2543 node->cv_node_error = node->cv_node_risk = 0;
// Same bottom-up traversal pattern as the other tree walks in this file.
2549 for( parent = node->parent; parent && parent->right == node;
2550 node = parent, parent = parent->parent )
// Truncation: drop both children of a parent that the pruning cut off.
2552 if( cut_tree && parent->Tn <= pruned_tree_idx )
2554 data->free_node( parent->left );
2555 data->free_node( parent->right );
2556 parent->left = parent->right = 0;
2563 node = parent->right;
// Free all per-node CV buffers at once.
2567 cvClearSet( data->cv_heap );
// Releases the whole tree when the training data is shared: forcing
// pruned_tree_idx to INT_MIN makes free_prune_data(true) treat every node
// as pruned, then the root itself is returned to the node storage.
2571 void CvDTree::free_tree()
2573 if( root && data && data->shared )
2575 pruned_tree_idx = INT_MIN;
2576 free_prune_data(true);
2577 data->free_node(root);
// Predicts the leaf node for a single sample: validates the input vector,
// then descends from the root following primary splits, falling back to
// surrogate splits (and finally the larger-child default direction) when a
// value is missing. With preprocessed_input=false, categorical values are
// translated through cat_map via binary search (cached in catbuf so each
// variable is translated at most once per call).
// NOTE(review): interior lines are elided in this view.
2583 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2584 const CvMat* _missing, bool preprocessed_input ) const
2586 CvDTreeNode* result = 0;
2589 CV_FUNCNAME( "CvDTree::predict" );
2593 int i, step, mstep = 0;
2594 const float* sample;
2596 CvDTreeNode* node = root;
2603 CV_ERROR( CV_StsError, "The tree has not been trained yet" );
// The sample must be a 1-D CV_32FC1 vector of var_all (raw input) or
// var_count (preprocessed) elements.
2605 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2606 _sample->cols != 1 && _sample->rows != 1 ||
2607 _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2608 _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2609 CV_ERROR( CV_StsBadArg,
2610 "the input sample must be 1d floating-point vector with the same "
2611 "number of elements as the total number of variables used for training" );
2613 sample = _sample->data.fl;
2614 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
2616 if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2618 int n = data->cat_count->cols;
2619 catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2620 for( i = 0; i < n; i++ )
// Optional missing-value mask must be an 8-bit vector shaped like _sample.
2626 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2627 !CV_ARE_SIZES_EQ(_missing, _sample) )
2628 CV_ERROR( CV_StsBadArg,
2629 "the missing data mask must be 8-bit vector of the same size as input sample" );
2630 m = _missing->data.ptr;
2631 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2634 vtype = data->var_type->data.i;
2635 vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2636 cmap = data->cat_map->data.i;
2637 cofs = data->cat_ofs->data.i;
// Descend until a leaf, honoring the pruned-tree index.
2639 while( node->Tn > pruned_tree_idx && node->left )
2641 CvDTreeSplit* split = node->split;
// Try the primary split first, then surrogates, until a direction is found.
2643 for( ; !dir && split != 0; split = split->next )
2645 int vi = split->var_idx;
2647 int i = vidx ? vidx[vi] : vi;
2648 float val = sample[i*step];
2649 if( m && m[i*mstep] )
2651 if( ci < 0 ) // ordered
2652 dir = val <= split->ord.c ? -1 : 1;
2656 if( preprocessed_input )
// Binary search of ival within cmap[cofs[ci] ..] (raw category -> index).
2663 int a = c = cofs[ci];
2665 int ival = cvRound(val);
2667 CV_ERROR( CV_StsBadArg,
2668 "one of input categorical variable is not an integer" );
2673 if( ival < cmap[c] )
2675 else if( ival > cmap[c] )
2681 if( c < 0 || ival != cmap[c] )
// Cache the translated index for any later split on the same variable.
2684 catbuf[ci] = c -= cofs[ci];
2687 dir = DTREE_CAT_DIR(c, split->subset);
2690 if( split->inversed )
// No split could decide: follow the larger child (default direction).
2696 double diff = node->right->sample_count - node->left->sample_count;
2697 dir = diff < 0 ? -1 : 1;
2699 node = dir < 0 ? node->left : node->right;
// Lazily computes and caches per-variable importance: traverses the
// non-pruned part of the tree and, at each internal node, adds every
// split's quality (primary and surrogates) to its variable's score.
// Returns the cached 1 x var_count CV_64F matrix.
// NOTE(review): interior lines are elided in this view; any normalization
// of the scores is not visible here.
2710 const CvMat* CvDTree::get_var_importance()
2712 if( !var_importance )
2714 CvDTreeNode* node = root;
2718 var_importance = cvCreateMat( 1, data->var_count, CV_64F );
2719 cvZero( var_importance );
2720 importance = var_importance->data.db;
2724 CvDTreeNode* parent;
// Depth-first: descend left as far as the non-pruned tree allows.
2725 for( ;; node = node->left )
2727 CvDTreeSplit* split = node->split;
2729 if( !node->left || node->Tn <= pruned_tree_idx )
2732 for( ; split != 0; split = split->next )
2733 importance[split->var_idx] += split->quality;
2736 for( parent = node->parent; parent && parent->right == node;
2737 node = parent, parent = parent->parent )
2743 node = parent->right;
2747 return var_importance;
// Saves the tree to a file: opens a cvFileStorage for writing, delegates
// to write() under the given name (default "my_dtree"), and releases the
// storage on the way out.
2751 void CvDTree::save( const char* filename, const char* name )
2753 CvFileStorage* fs = 0;
2755 CV_FUNCNAME( "CvDTree::save" );
2759 CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_WRITE ));
2761 CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2763 write( fs, name ? name : "my_dtree" );
2767 cvReleaseFileStorage( &fs );
// Serializes the training-data parameters (variable counts, training
// params, priors, var_idx, per-variable types, and category maps) into the
// open file storage, mirroring what read_train_data_params expects.
// NOTE(review): interior lines are elided in this view.
2771 void CvDTree::write_train_data_params( CvFileStorage* fs )
2773 CV_FUNCNAME( "CvDTree::write_train_data_params" );
2777 int vi, vcount = data->var_count;
2779 cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
2780 cvWriteInt( fs, "var_all", data->var_all );
2781 cvWriteInt( fs, "var_count", data->var_count );
2782 cvWriteInt( fs, "ord_var_count", data->ord_var_count );
2783 cvWriteInt( fs, "cat_var_count", data->cat_var_count );
2785 cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
2786 cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
// Classifier and regression trees persist different accuracy parameters.
2788 if( data->is_classifier )
2790 cvWriteInt( fs, "max_categories", data->params.max_categories );
2794 cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
2797 cvWriteInt( fs, "max_depth", data->params.max_depth );
2798 cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
2799 cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
// Pruning options only make sense when cross-validation was used.
2801 if( data->params.cv_folds > 1 )
2803 cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
2804 cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
2808 cvWrite( fs, "priors", data->priors );
2810 cvEndWriteStruct( fs );
2813 cvWrite( fs, "var_idx", data->var_idx );
// Per-variable type flags: 1 for categorical, 0 for ordered.
2815 cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
2817 for( vi = 0; vi < vcount; vi++ )
2818 cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
2820 cvEndWriteStruct( fs );
2822 if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
2824 CV_ASSERT( data->cat_count != 0 );
2825 cvWrite( fs, "cat_count", data->cat_count );
2826 cvWrite( fs, "cat_map", data->cat_map );
// Serializes a single split as a flow map: variable index, quality, then
// either a category list ("in"/"not_in", using whichever notation yields
// the shorter list) or an ordered threshold ("le"/"gt" depending on
// inversion).
// NOTE(review): interior lines are elided in this view.
2833 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
2837 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
2838 cvWriteInt( fs, "var", split->var_idx );
2839 cvWriteReal( fs, "quality", split->quality );
2841 ci = data->get_var_type(split->var_idx);
2842 if( ci >= 0 ) // split on a categorical var
// Count how many categories go right to decide which notation is shorter.
2844 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
2845 for( i = 0; i < n; i++ )
2846 to_right += DTREE_CAT_DIR(i,split->subset) > 0;
2848 // ad-hoc rule when to use inverse categorical split notation
2849 // to achieve more compact and clear representation
2850 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
2852 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
2853 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
// List only the categories that go against the default direction.
2854 for( i = 0; i < n; i++ )
2856 int dir = DTREE_CAT_DIR(i,split->subset);
2857 if( dir*default_dir < 0 )
2858 cvWriteInt( fs, 0, i );
2860 cvEndWriteStruct( fs );
2863 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
2865 cvEndWriteStruct( fs );
// Serializes one tree node: basic statistics (depth, sample count, value,
// class index for classifiers), pruning bookkeeping (Tn, complexity, alpha,
// risks), and the list of splits (primary first, then surrogates).
// NOTE(review): interior lines are elided in this view.
2869 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
2871 CvDTreeSplit* split;
2873 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2875 cvWriteInt( fs, "depth", node->depth );
2876 cvWriteInt( fs, "sample_count", node->sample_count );
2877 cvWriteReal( fs, "value", node->value );
2879 if( data->is_classifier )
2880 cvWriteInt( fs, "norm_class_idx", node->class_idx );
2882 cvWriteInt( fs, "Tn", node->Tn );
2883 cvWriteInt( fs, "complexity", node->complexity );
2884 cvWriteReal( fs, "alpha", node->alpha );
2885 cvWriteReal( fs, "node_risk", node->node_risk );
2886 cvWriteReal( fs, "tree_risk", node->tree_risk );
2887 cvWriteReal( fs, "tree_error", node->tree_error );
2891 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
2893 for( split = node->split; split != 0; split = split->next )
2894 write_split( fs, split );
2896 cvEndWriteStruct( fs );
2899 cvEndWriteStruct( fs );
// Writes all tree nodes to the file storage in depth-first order, using the
// same iterative parent-pointer traversal as the other tree walks here.
// NOTE(review): interior lines are elided in this view.
2903 void CvDTree::write_tree_nodes( CvFileStorage* fs )
2905 CV_FUNCNAME( "CvDTree::write_tree_nodes" );
2909 CvDTreeNode* node = root;
2911 // traverse the tree and save all the nodes in depth-first order
2914 CvDTreeNode* parent;
2917 write_node( fs, node );
2923 for( parent = node->parent; parent && parent->right == node;
2924 node = parent, parent = parent->parent )
2930 node = parent->right;
// Serializes the whole tree under `name`: training-data parameters, the
// selected (pruned) tree index, variable importance (computed on demand),
// and the node sequence.
2937 void CvDTree::write( CvFileStorage* fs, const char* name )
2939 CV_FUNCNAME( "CvDTree::write" );
2943 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
2945 write_train_data_params( fs );
2947 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
// Ensure var_importance is populated before writing it.
2948 get_var_importance();
2949 cvWrite( fs, "var_importance", var_importance );
2951 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
2952 write_tree_nodes( fs );
2953 cvEndWriteStruct( fs );
2955 cvEndWriteStruct( fs );
// Loads a tree from a file: looks up the node by `name`, or — when the
// name is absent/not found — falls back to the first top-level entry of
// the storage. The storage is released on exit.
// NOTE(review): interior lines (including the read() call) are elided in
// this view.
2961 void CvDTree::load( const char* filename, const char* name )
2963 CvFileStorage* fs = 0;
2965 CV_FUNCNAME( "CvDTree::load" );
2969 CvFileNode* tree = 0;
2971 CV_CALL( fs = cvOpenFileStorage( filename, 0, CV_STORAGE_READ ));
2973 CV_ERROR( CV_StsError, "Could not open the file storage. Check the path and permissions" );
2976 tree = cvGetFileNodeByName( fs, 0, name );
// Fallback: take the first element of the storage's root sequence.
2979 CvFileNode* root = cvGetRootFileNode( fs );
2980 if( root->data.seq->total > 0 )
2981 tree = (CvFileNode*)cvGetSeqElem( root->data.seq, 0 );
2988 cvReleaseFileStorage( &fs );
// Reconstructs a fresh CvDTreeTrainData object from the "training data
// parameters" portion of the stored tree: scalar counts, optional training
// parameters, the optional var_idx vector, the mandatory var_type sequence,
// the categorical maps (cat_count/cat_map/cat_ofs), and finally the memory
// storage plus node/split heaps needed to hold the nodes that are read next.
// NOTE(review): this view is elided; several guard `if`s, braces and CV_CALL
// wrappers between the visible statements are missing.
2992 void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
2994 CV_FUNCNAME( "CvDTree::read_train_data_params" );
2998 CvDTreeParams params;
2999 CvFileNode *tparams_node, *vartype_node;
3001 int is_classifier, vi, cat_var_count, ord_var_count;
3002 int max_split_size, tree_block_size;
// The tree owns a brand-new training-data object when loading from file.
3004 data = new CvDTreeTrainData;
// ---- scalar counts -------------------------------------------------------
3006 data->is_classifier = is_classifier = cvReadIntByName( fs, node, "is_classifier" ) != 0;
3007 data->var_all = cvReadIntByName( fs, node, "var_all" );
// var_count defaults to var_all when the subset tag is absent.
3008 data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
3009 data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
3010 data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
// ---- training parameters (optional) --------------------------------------
3012 tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
3014 if( tparams_node ) // training parameters are not necessary
3016 data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
3020 data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
3024 data->params.regression_accuracy =
3025 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
3028 data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
3029 data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
3030 data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
// Pruning-related flags are only meaningful with cross-validation enabled.
3032 if( data->params.cv_folds > 1 )
3034 data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
3035 data->params.truncate_pruned_tree =
3036 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
// Priors, if present, must be a stored matrix.
// NOTE(review): the error text below has a grammar slip ("must stored");
// left untouched here because it is a runtime string.
3039 data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
3040 if( data->priors && !CV_IS_MAT(data->priors) )
3041 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
// ---- var_idx (optional active-variable subset) ---------------------------
3044 CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
// Must be a 1-d CV_32SC1 vector of exactly var_count elements.
// (The `cols != 1 && rows != 1` form deliberately rejects only matrices
// that are 1-d in neither dimension.)
3047 if( !CV_IS_MAT(data->var_idx) ||
3048 data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
3049 data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
3050 CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
3051 CV_ERROR( CV_StsParseError,
3052 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
// Every index must fall inside [0, var_all).
3054 for( vi = 0; vi < data->var_count; vi++ )
3055 if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
3056 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
3059 ////// read var type
3060 CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
// var_type is a mandatory sequence of 0/1 flags, one per variable:
// 1 = categorical, 0 = ordered.
3062 vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
3063 if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
3064 vartype_node->data.seq->total != data->var_count )
3065 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
3067 cvStartReadSeq( vartype_node->data.seq, &reader );
3071 for( vi = 0; vi < data->var_count; vi++ )
3073 CvFileNode* n = (CvFileNode*)reader.ptr;
// `n->data.i & ~1` rejects anything other than 0 or 1.
3074 if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
3075 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
// Categorical vars get a running non-negative index (cat_var_count++);
// ordered vars get a running negative code (ord_var_count--), later
// recovered below via the bitwise complement.
3076 data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
3077 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
// ~x turns the final negative running code back into a positive count.
3080 ord_var_count = ~ord_var_count;
3081 if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
3082 CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
// ---- categorical maps (needed also for class labels of a classifier) -----
3085 if( data->cat_var_count > 0 || is_classifier )
3087 int ccount, max_c_count = 0, total_c_count = 0;
3088 CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
3089 CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
// cat_count has one entry per categorical var, plus one extra entry for
// the class labels when this is a classifier (hence + is_classifier).
3091 if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
3092 data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
3093 CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
3094 data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
3095 data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
3096 CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
3097 CV_ERROR( CV_StsParseError,
3098 "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
3100 ccount = cat_var_count + is_classifier;
// cat_ofs[i] is the prefix-sum offset of var i's categories inside cat_map.
3102 CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
3103 data->cat_ofs->data.i[0] = 0;
3105 for( vi = 0; vi < ccount; vi++ )
3107 int val = data->cat_count->data.i[vi];
3109 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
3110 max_c_count = MAX( max_c_count, val );
3111 data->cat_ofs->data.i[vi+1] = total_c_count += val;
// cat_map must contain exactly one entry per category across all vars.
3114 if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
3115 CV_ERROR( CV_StsBadSize,
3116 "cat_map vector length is not equal to the total number of categories in all categorical vars" );
3118 data->max_c_count = max_c_count;
// ---- storage for nodes and splits ----------------------------------------
// A split embeds a 32-bit subset mask for the first 32 categories; vars with
// more categories need extra ints, hence the MAX(0, max_c_count - 33)/32 term.
3121 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
3122 (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
3124 tree_block_size = MAX(sizeof(CvDTreeNode)*8, max_split_size);
3125 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
3126 CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
3127 CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
3128 sizeof(CvDTreeNode), data->tree_storage ));
3129 CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
3130 max_split_size, data->tree_storage ));
// Parses one split description (a file-storage map) and returns a freshly
// allocated CvDTreeSplit taken from data->split_heap.
//   - Categorical variable: the category subset is read from an "in" tag, or
//     from a "not_in" tag (in which case the resulting bitmask is inverted).
//   - Ordered variable: the threshold is read from "le", or from "gt" with
//     split->inversed set.
// The split "quality" field is read for both kinds.
// NOTE(review): this view is elided; the branches that pick between
// "in"/"not_in" and "le"/"gt" are only partly visible.
3136 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3138 CvDTreeSplit* split = 0;
3140 CV_FUNCNAME( "CvDTree::read_split" );
3146 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3147 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
// Split variable index, validated against the active variable count.
3149 vi = cvReadIntByName( fs, fnode, "var", -1 );
3150 if( (unsigned)vi >= (unsigned)data->var_count )
3151 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
// get_var_type: >= 0 means categorical (value is the categorical index).
3153 ci = data->get_var_type(vi);
3154 if( ci >= 0 ) // split on categorical var
3156 int i, n = data->cat_count->data.i[ci], inversed = 0;
3159 split = data->new_split_cat( vi, 0 );
3160 inseq = cvGetFileNodeByName( fs, fnode, "in" );
// Fallback: the complement form of the subset.
3163 inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3166 if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
3167 CV_ERROR( CV_StsParseError,
3168 "Either 'in' or 'not_in' tags should be inside a categorical split data" );
// Set a bit for every listed category index.
3170 cvStartReadSeq( inseq->data.seq, &reader );
3172 for( i = 0; i < reader.seq->total; i++ )
3174 CvFileNode* inode = (CvFileNode*)reader.ptr;
3175 int val = inode->data.i;
3176 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3177 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3179 split->subset[val >> 5] |= 1 << (val & 31);
3180 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3183 // for categorical splits we do not use inversed splits,
3184 // instead we inverse the variable set in the split
// XOR with -1 flips every bit of the subset mask ((n+31)>>5 ints cover
// all n categories).
3186 for( i = 0; i < (n + 31) >> 5; i++ )
3187 split->subset[i] ^= -1;
// Ordered variable: threshold comparison, possibly inversed ("gt" form).
3191 CvFileNode* cmp_node;
3192 split = data->new_split_ord( vi, 0, 0, 0, 0 );
3194 cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3197 cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3198 split->inversed = 1;
3201 split->ord.c = (float)cvReadReal( cmp_node );
// Common to both split kinds.
3204 split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
// Parses one tree node (a file-storage map) into a CvDTreeNode allocated from
// data->node_heap and attached to <parent>. Reads the scalar node fields,
// validates the stored depth against the depth derived from <parent>, and
// rebuilds the node's split list (primary split first, surrogates chained via
// split->next) using read_split().
// NOTE(review): this view is elided; the branch separating the first split
// from subsequent ones is only partly visible.
3212 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3214 CvDTreeNode* node = 0;
3216 CV_FUNCNAME( "CvDTree::read_node" );
3223 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3224 CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3226 CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
// Cross-check: the stored depth must match the depth implied by the parent
// chain (new_node computed node->depth from <parent>).
3227 depth = cvReadIntByName( fs, fnode, "depth", -1 );
3228 if( depth != node->depth )
3229 CV_ERROR( CV_StsParseError, "incorrect node depth" );
3231 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3232 node->value = cvReadRealByName( fs, fnode, "value" );
// Classifiers additionally store the normalized class index.
3233 if( data->is_classifier )
3234 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
// Cost-complexity pruning bookkeeping fields.
3236 node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3237 node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3238 node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3239 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3240 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3241 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
// Optional split sequence: absent for leaves.
3243 splits = cvGetFileNodeByName( fs, fnode, "splits" );
3247 CvDTreeSplit* last_split = 0;
3249 if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3250 CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3252 cvStartReadSeq( splits->data.seq, &reader );
3253 for( i = 0; i < reader.seq->total; i++ )
3255 CvDTreeSplit* split;
3256 CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
// First split becomes node->split; the rest are appended to the chain.
3258 node->split = last_split = split;
3260 last_split = last_split->next = split;
3262 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
// Rebuilds the tree structure from the flat depth-first "nodes" sequence.
// A stack-allocated dummy _root acts as sentinel parent of the real root
// (read_node is passed 0 instead of &_root for the first node). As each node
// is read it is hooked in as the parent's left or right child; once a parent
// has both children, the walk climbs back up until it finds an ancestor that
// still lacks a right child.
// NOTE(review): this view is elided; the left/right attachment branch and the
// descent step between the visible lines are missing.
3272 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3274 CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3280 CvDTreeNode* parent = &_root;
3282 parent->left = parent->right = parent->parent = 0;
3284 cvStartReadSeq( fnode->data.seq, &reader );
3286 for( i = 0; i < reader.seq->total; i++ )
// The sentinel must not leak out as a real parent pointer.
3290 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3292 parent->left = node;
3294 parent->right = node;
// Climb while the current subtree is complete (right child present).
3299 while( parent && parent->right )
3300 parent = parent->parent;
3303 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3312 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3314 CV_FUNCNAME( "CvDTree::read" );
3318 CvFileNode* tree_nodes;
3321 read_train_data_params( fs, fnode );
3323 tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
3324 if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3325 CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3327 pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
3329 read_tree_nodes( fs, tree_nodes );
3330 get_var_importance(); // recompute variable importance