1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// File-scope constants used by CvDTreeTrainData.
// ord_nan: set_data() rejects any ordered (non-categorical) input value whose
// magnitude reaches this threshold (fabs(val) >= ord_nan).
46 static const float ord_nan = FLT_MAX*0.5f;
// Memory-storage block sizing in set_data(): the required block size is padded
// by block_size_delta and never allowed to drop below min_block_size.
47 static const int min_block_size = 1 << 16;
48 static const int block_size_delta = 1 << 10;
// Default constructor: sets every CvMat / CvMemStorage pointer that this
// object later releases (see clear() / free_train_data()) to null.
// NOTE(review): original lines 51 and 55-59 are not shown in this listing,
// so any further initialization done there cannot be described here.
50 CvDTreeTrainData::CvDTreeTrainData()
52 var_idx = var_type = cat_count = cat_ofs = cat_map =
53 priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
54 tree_storage = temp_storage = 0;
// Constructs the object and immediately loads training data.
// All owned matrix/storage pointers are nulled first, then every argument is
// forwarded unchanged to set_data(). set_data()'s trailing _update_data
// parameter is not passed, so its declared default applies
// (presumably false -- TODO confirm against the header).
60 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
61 const CvMat* _responses, const CvMat* _var_idx,
62 const CvMat* _sample_idx, const CvMat* _var_type,
63 const CvMat* _missing_mask, const CvDTreeParams& _params,
64 bool _shared, bool _add_labels )
66 var_idx = var_type = cat_count = cat_ofs = cat_map =
67 priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
69 tree_storage = temp_storage = 0;
71 set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
72 _var_type, _missing_mask, _params, _shared, _add_labels );
// Destructor.
// NOTE(review): its body (original lines 77-81) is not shown in this listing.
76 CvDTreeTrainData::~CvDTreeTrainData()
// Validates and normalizes the tree-building parameters held in `params`.
// Visible checks (each failure raises CV_StsOutOfRange through CV_ERROR):
//   - max_categories must be >= 2; it is then capped at 15;
//   - max_depth must be >= 0; it is then capped at 25;
//   - min_sample_count is raised to at least 1 (no error);
//   - cv_folds must be >= 0; the value 1 gets special handling on the
//     following line, which is not shown;
//   - regression_accuracy must be >= 0.
// Returns a bool whose computation is on lines not shown here
// (presumably true on success -- TODO confirm).
// NOTE(review): the lines copying _params into `params` are not visible.
82 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
86 CV_FUNCNAME( "CvDTreeTrainData::set_params" );
93 if( params.max_categories < 2 )
94 CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
95 params.max_categories = MIN( params.max_categories, 15 );
97 if( params.max_depth < 0 )
98 CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
99 params.max_depth = MIN( params.max_depth, 25 );
101 params.min_sample_count = MAX(params.min_sample_count,1);
103 if( params.cv_folds < 0 )
104 CV_ERROR( CV_StsOutOfRange,
105 "params.cv_folds should be =0 (the tree is not pruned) "
106 "or n>0 (tree is pruned using n-fold cross-validation)" );
108 if( params.cv_folds == 1 )
111 if( params.regression_accuracy < 0 )
112 CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
// File-local sort routines generated with CV_IMPLEMENT_QSORT_EX. Each
// instantiation takes a less-than predicate macro and, as its last argument,
// the type of an extra user-data parameter; the predicates below refer to that
// parameter as `aux` (presumably the name the macro gives it -- confirm).
// icvSortIntPtr / icvSortDblPtr: sort arrays of pointers by the pointed-to value.
121 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
122 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
123 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
// icvSortIntAux / icvSortUShAux: sort index arrays (int / unsigned short)
// by the float values aux[index].
125 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
126 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
127 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )
// icvSortPairs: sort CvPair16u32s entries by the int that their `.i` member
// points to (the `.u` member is carried along).
129 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
130 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
// Builds the internal training representation used by the tree learner.
//
// Two modes:
//  * _update_data && data_root already set: a temporary CvDTreeTrainData is
//    built from the new arguments and checked for compatibility with *this
//    (same var_count, var_type, cat_count and cat_map). Its priors, buffers,
//    direction/split scratch and temp storage are then moved into *this, and
//    a new data_root is created for the new sample count.
//  * otherwise: parameters are validated (set_params), the input matrices are
//    checked, and every working variable is converted into `buf`.
//
// Layout of `buf`: buf_count rows of buf_size elements each, CV_16UC1 when
// is_buf_16u, else CV_32SC1. Each row holds work_var_count + 1 slices of
// sample_count entries:
//   - slice vi for input variable vi (categorical: compacted category codes;
//     ordered: sample positions sorted by value);
//   - the class-label slice for classifiers;
//   - the cross-validation label slice (last work variable) when have_labels;
//   - finally the sample indices, at slice work_var_count.
// Missing values are stored as 65535 in 16-bit buffers and, per the comments
// below, as -1 in 32-bit buffers.
//
// Errors are reported through CV_ERROR / CV_CALL. The temporaries listed at
// the end of the function are released there.
//
// NOTE(review): many short original lines (declarations, braces, else
// branches, __BEGIN__/__END__) are not shown in this listing; the comments
// below describe only what the visible lines do.
132 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
133 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
134 const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
135 bool _shared, bool _add_labels, bool _update_data )
137 CvMat* sample_indices = 0;
138 CvMat* var_type0 = 0;
141 CvPair16u32s* pair16u32s_ptr = 0;
142 CvDTreeTrainData* data = 0;
145 unsigned short* udst = 0;
148 CV_FUNCNAME( "CvDTreeTrainData::set_data" );
152 int sample_all = 0, r_type = 0, cv_n;
153 int total_c_count = 0;
154 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
155 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
158 const int *sidx = 0, *vidx = 0;
// Update mode: build the new data in a temporary object, then take over its buffers.
160 if( _update_data && data_root )
162 data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
163 _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
165 // compare new and old train data
166 if( !(data->var_count == var_count &&
167 cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
168 cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
169 cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
170 CV_ERROR( CV_StsBadArg,
171 "The new training data must have the same types of the input and output variables "
172 "and the same categories for categorical variables" );
174 cvReleaseMat( &priors );
175 cvReleaseMat( &priors_mult );
176 cvReleaseMat( &buf );
177 cvReleaseMat( &direction );
178 cvReleaseMat( &split_buf );
179 cvReleaseMemStorage( &temp_storage );
// Move ownership from the temporary object; its pointers are nulled so that
// deleting it does not free them.
181 priors = data->priors; data->priors = 0;
182 priors_mult = data->priors_mult; data->priors_mult = 0;
183 buf = data->buf; data->buf = 0;
184 buf_count = data->buf_count; buf_size = data->buf_size;
185 sample_count = data->sample_count;
187 direction = data->direction; data->direction = 0;
188 split_buf = data->split_buf; data->split_buf = 0;
189 temp_storage = data->temp_storage; data->temp_storage = 0;
190 nv_heap = data->nv_heap; cv_heap = data->cv_heap;
192 data_root = new_node( 0, sample_count, 0, 0 );
// Fresh build: validate the parameters and the input matrices.
201 CV_CALL( set_params( _params ));
203 // check parameter types and sizes
204 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
206 train_data = _train_data;
207 responses = _responses;
// Element strides between consecutive samples / variables in the data and
// mask matrices depend on whether samples are stored as rows or columns.
209 if( _tflag == CV_ROW_SAMPLE )
211 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
214 ms_step = _missing_mask->step, mv_step = 1;
218 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
221 mv_step = _missing_mask->step, ms_step = 1;
225 sample_count = sample_all;
230 CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
231 sidx = sample_indices->data.i;
232 sample_count = sample_indices->rows + sample_indices->cols - 1;
237 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
238 vidx = var_idx->data.i;
239 var_count = var_idx->rows + var_idx->cols - 1;
// Presumably selects the 16-bit buffer layout (the assignment is on a line not
// shown); 65535 is reserved as the 16-bit missing-value marker.
243 if ( sample_count < 65536 )
246 if( !CV_IS_MAT(_responses) ||
247 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
248 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
249 (_responses->rows != 1 && _responses->cols != 1) ||
250 _responses->rows + _responses->cols - 1 != sample_all )
251 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
252 "floating-point vector containing as many elements as "
253 "the total number of samples in the training data matrix" );
256 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
258 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
264 is_classifier = r_type == CV_VAR_CATEGORICAL;
266 // step 0. calc the number of categorical vars
// Categorical variables get consecutive non-negative codes; ordered variables
// get decreasing negative codes. ord_var_count's starting value is on a line
// not shown (presumably -1), so that ~ord_var_count below is the ordered count.
267 for( vi = 0; vi < var_count; vi++ )
269 var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
270 cat_var_count++ : ord_var_count--;
273 ord_var_count = ~ord_var_count;
274 cv_n = params.cv_folds;
275 // set the two last elements of var_type array to be able
276 // to locate responses and cross-validation labels using
277 // the corresponding get_* functions.
278 var_type->data.i[var_count] = cat_var_count;
279 var_type->data.i[var_count+1] = cat_var_count+1;
281 // in case of single ordered predictor we need dummy cv_labels
282 // for safe split_node_data() operation
283 have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
285 work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
286 + (have_labels ? 1 : 0); // for cv_labels
288 buf_size = (work_var_count + 1 /*for sample_indices*/) * sample_count;
290 buf_count = shared ? 2 : 1;
294 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
295 CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
299 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
300 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
// One cat_count / cat_ofs entry per categorical variable (+1 for the class
// labels). cat_map starts with room for max_categories codes per entry and is
// grown on demand further below.
303 size = is_classifier ? (cat_var_count+1) : cat_var_count;
304 size = !size ? 1 : size;
305 CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
306 CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
308 size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
309 size = !size ? 1 : size;
310 CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
312 // now calculate the maximum size of split,
313 // create memory storage that will keep nodes and splits of the decision tree
314 // allocate root node and the buffer for the whole training data
315 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
316 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
317 tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
318 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
319 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
320 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
322 nv_size = var_count*sizeof(int);
323 nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
325 temp_block_size = nv_size;
// Each cross-validation fold must receive at least MAX(min_sample_count,10) samples.
329 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
330 CV_ERROR( CV_StsOutOfRange,
331 "Too many folds in cross-validation for such a small dataset" );
333 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
334 temp_block_size = MAX(temp_block_size, cv_size);
337 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
338 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
339 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
341 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
343 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
// Scratch arrays: _fdst holds the float sort keys of ordered variables;
// _idst holds raw integer category values that the 16-bit pair sort goes through.
350 _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
351 if (is_buf_16u && (cat_var_count || is_classifier))
352 _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
354 // transform the training data to convenient representation
// (vi == var_count processes the response vector)
355 for( vi = 0; vi <= var_count; vi++ )
358 const uchar* mask = 0;
359 int m_step = 0, step;
360 const int* idata = 0;
361 const float* fdata = 0;
364 if( vi < var_count ) // analyze i-th input variable
366 int vi0 = vidx ? vidx[vi] : vi;
367 ci = get_var_type(vi);
368 step = ds_step; m_step = ms_step;
369 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
370 idata = _train_data->data.i + vi0*dv_step;
372 fdata = _train_data->data.fl + vi0*dv_step;
374 mask = _missing_mask->data.ptr + vi0*mv_step;
376 else // analyze _responses
379 step = CV_IS_MAT_CONT(_responses->type) ?
380 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
381 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
382 idata = _responses->data.i;
384 fdata = _responses->data.fl;
387 if( (vi < var_count && ci>=0) ||
388 (vi == var_count && is_classifier) ) // process categorical variable or response
390 int c_count, prev_label;
394 udst = (unsigned short*)(buf->data.s + vi*sample_count);
396 idst = buf->data.i + vi*sample_count;
// Missing entries keep val = INT_MAX, so they sort after every valid category.
399 for( i = 0; i < sample_count; i++ )
401 int val = INT_MAX, si = sidx ? sidx[i] : i;
402 if( !mask || !mask[si*m_step] )
405 val = idata[si*step];
408 float t = fdata[si*step];
410 if( fabs(t - val) > FLT_EPSILON )
412 sprintf( err, "%d-th value of %d-th (categorical) "
413 "variable is not an integer", i, vi );
414 CV_ERROR( CV_StsBadArg, err );
420 sprintf( err, "%d-th value of %d-th (categorical) "
421 "variable is too large", i, vi );
422 CV_ERROR( CV_StsBadArg, err );
429 pair16u32s_ptr[i].u = udst + i;
430 pair16u32s_ptr[i].i = _idst + i;
435 int_ptr[i] = idst + i;
// num_valid presumably counts the non-missing entries (its update is on a line not shown).
439 c_count = num_valid > 0;
442 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
443 // count the categories
444 for( i = 1; i < num_valid; i++ )
445 if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
450 icvSortIntPtr( int_ptr, sample_count, 0 );
451 // count the categories
452 for( i = 1; i < num_valid; i++ )
453 c_count += *int_ptr[i] != *int_ptr[i-1];
457 max_c_count = MAX( max_c_count, c_count );
458 cat_count->data.i[ci] = c_count;
459 cat_ofs->data.i[ci] = total_c_count;
461 // resize cat_map, if need
462 if( cat_map->cols < total_c_count + c_count )
465 CV_CALL( cat_map = cvCreateMat( 1,
466 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
467 for( i = 0; i < total_c_count; i++ )
468 cat_map->data.i[i] = tmp_map->data.i[i];
469 cvReleaseMat( &tmp_map );
472 c_map = cat_map->data.i + total_c_count;
473 total_c_count += c_count;
478 // compact the class indices and build the map
479 prev_label = ~*pair16u32s_ptr[0].i;
480 for( i = 0; i < num_valid; i++ )
482 int cur_label = *pair16u32s_ptr[i].i;
483 if( cur_label != prev_label )
484 c_map[++c_count] = prev_label = cur_label;
485 *pair16u32s_ptr[i].u = (unsigned short)c_count;
487 // replace labels for missing values with -1
488 for( ; i < sample_count; i++ )
489 *pair16u32s_ptr[i].u = 65535;
493 // compact the class indices and build the map
494 prev_label = ~*int_ptr[0];
495 for( i = 0; i < num_valid; i++ )
497 int cur_label = *int_ptr[i];
498 if( cur_label != prev_label )
499 c_map[++c_count] = prev_label = cur_label;
500 *int_ptr[i] = c_count;
502 // replace labels for missing values with -1
503 for( ; i < sample_count; i++ )
507 else if( ci < 0 ) // process ordered variable
510 udst = (unsigned short*)(buf->data.s + vi*sample_count);
512 idst = buf->data.i + vi*sample_count;
514 for( i = 0; i < sample_count; i++ )
517 int si = sidx ? sidx[i] : i;
518 if( !mask || !mask[si*m_step] )
521 val = (float)idata[si*step];
523 val = fdata[si*step];
525 if( fabs(val) >= ord_nan )
527 sprintf( err, "%d-th value of %d-th (ordered) "
528 "variable (=%g) is too large", i, vi, val );
529 CV_ERROR( CV_StsBadArg, err );
534 udst[i] = (unsigned short)i;
// Sort this variable's sample positions by their values in _fdst.
541 icvSortUShAux( udst, num_valid, _fdst);
543 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );
547 data_root->set_num_valid(vi, num_valid);
// Sample-index slice: original sample index for every position.
552 udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
554 idst = buf->data.i + work_var_count*sample_count;
556 for (i = 0; i < sample_count; i++)
559 udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
561 idst[i] = sidx ? sidx[i] : i;
// Cross-validation labels (last work variable): fold ids 0..cv_n-1 are assigned
// round-robin, then shuffled by sample_count random swaps.
566 unsigned short* udst = 0;
572 udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
573 for( i = vi = 0; i < sample_count; i++ )
575 udst[i] = (unsigned short)vi++;
576 vi &= vi < cv_n ? -1 : 0;
579 for( i = 0; i < sample_count; i++ )
581 int a = cvRandInt(r) % sample_count;
582 int b = cvRandInt(r) % sample_count;
583 unsigned short unsh = (unsigned short)vi;
584 CV_SWAP( udst[a], udst[b], unsh );
589 idst = buf->data.i + (get_work_var_count()-1)*sample_count;
590 for( i = vi = 0; i < sample_count; i++ )
593 vi &= vi < cv_n ? -1 : 0;
596 for( i = 0; i < sample_count; i++ )
598 int a = cvRandInt(r) % sample_count;
599 int b = cvRandInt(r) % sample_count;
600 CV_SWAP( idst[a], idst[b], vi );
// Reduce the logical size of cat_map to the codes actually used.
606 cat_map->cols = MAX( total_c_count, 1 );
// Size split-heap elements from max_c_count (categorical subset bitmask).
608 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
609 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
610 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
// Class priors: user-supplied (params.priors) or 1 for every class; each must
// be positive. They are then scaled by 1/sum (sum is computed on a line not
// shown, presumably the total of the priors).
612 have_priors = is_classifier && params.priors;
615 int m = get_num_classes();
617 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
618 for( i = 0; i < m; i++ )
620 double val = have_priors ? params.priors[i] : 1.;
622 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
623 priors->data.db[i] = val;
629 cvScale( priors, priors, 1./sum );
631 CV_CALL( priors_mult = cvCloneMat( priors ));
632 CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
// Per-sample scratch arrays sized to sample_count.
636 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
637 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
// Release temporaries (further cleanup lines are not shown in this listing).
649 cvFree( &pair16u32s_ptr);
650 cvReleaseMat( &var_type0 );
651 cvReleaseMat( &sample_indices );
652 cvReleaseMat( &tmp_map );
// Replaces `responses` with a private deep copy held in responses_copy
// (released later by free_train_data()), so training no longer reads the
// caller's response matrix.
// NOTE(review): an existing responses_copy is not released before being
// overwritten, so calling this twice would leak the first copy.
655 void CvDTreeTrainData::do_responses_copy()
657 responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
658 cvCopy( responses, responses_copy);
659 responses = responses_copy;
// Creates and returns a new root node for the training samples listed in
// _subsample_idx (duplicates allowed). When no subset is given, a copy of
// data_root is returned instead.
//
// With a subset, the node is created with count = number of listed indices
// and buffer row 1 (new_node(0, count, 1, 0)). Every work variable is
// re-encoded into that row:
//  * categorical variables, class labels and cv labels: the value of each
//    listed sample is copied (src[sidx[i]]) and the non-missing ones are
//    counted into num_valid;
//  * ordered variables: data_root's sorted order is walked, and each original
//    position is expanded into its count_i consecutive new positions taken
//    from the `co` count/offset table, so the result stays sorted;
//  * finally the original sample indices are remapped through sidx.
// NOTE(review): buf has buf_count = shared ? 2 : 1 rows, so writing to row 1
// assumes the data was built in shared mode -- confirm callers guarantee it.
// NOTE(review): many short lines (loop bodies, the co table fill, branches)
// are not shown in this listing.
662 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
664 CvDTreeNode* root = 0;
665 CvMat* isubsample_idx = 0;
666 CvMat* subsample_co = 0;
668 CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
673 CV_ERROR( CV_StsError, "No training data has been set" );
676 CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
678 if( !isubsample_idx )
680 // make a copy of the root node
683 root = new_node( 0, 1, 0, 0 );
686 root->num_valid = temp.num_valid;
687 if( root->num_valid )
689 for( i = 0; i < var_count; i++ )
690 root->num_valid[i] = data_root->num_valid[i];
692 root->cv_Tn = temp.cv_Tn;
693 root->cv_node_risk = temp.cv_node_risk;
694 root->cv_node_error = temp.cv_node_error;
698 int* sidx = isubsample_idx->data.i;
699 // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
700 int* co, cur_ofs = 0;
702 int work_var_count = get_work_var_count();
703 int count = isubsample_idx->rows + isubsample_idx->cols - 1;
705 root = new_node( 0, count, 1, 0 );
707 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
708 cvZero( subsample_co );
709 co = subsample_co->data.i;
710 for( i = 0; i < count; i++ )
712 for( i = 0; i < sample_count; i++ )
723 cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
724 for( vi = 0; vi < work_var_count; vi++ )
726 int ci = get_var_type(vi);
// categorical input, class labels or cv labels
728 if( ci >= 0 || vi >= var_count )
731 const int* src = get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
735 unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
736 vi*sample_count + root->offset);
737 for( i = 0; i < count; i++ )
739 int val = src[sidx[i]];
740 udst[i] = (unsigned short)val;
// In 16-bit mode get_cat_var_data() widens the missing marker to 65535, not -1,
// so `val >= 0` would count missing values as valid. Compare against the
// 16-bit marker instead, matching the 32-bit branch below where missing is -1.
741 num_valid += val != 65535;
746 int* idst = buf->data.i + root->buf_idx*buf->cols +
747 vi*sample_count + root->offset;
748 for( i = 0; i < count; i++ )
750 int val = src[sidx[i]];
752 num_valid += val >= 0;
757 root->set_num_valid(vi, num_valid);
// ordered input variable: expand data_root's sorted positions through `co`
761 int *src_idx_buf = (int*)(uchar*)inn_buf;
762 float *src_val_buf = (float*)(src_idx_buf + sample_count);
763 int* sample_indices_buf = (int*)(src_val_buf + sample_count);
764 const int* src_idx = 0;
765 const float* src_val = 0;
766 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
767 int j = 0, idx, count_i;
768 int num_valid = data_root->get_num_valid(vi);
// NOTE(review): this 16-bit path offsets by data_root->offset while the 32-bit
// path below uses root->offset. Both nodes are created with offset 0 in this
// file, so the results currently coincide.
772 unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
773 vi*sample_count + data_root->offset);
774 for( i = 0; i < num_valid; i++ )
779 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
780 udst_idx[j] = (unsigned short)cur_ofs;
783 root->set_num_valid(vi, j);
785 for( ; i < sample_count; i++ )
790 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
791 udst_idx[j] = (unsigned short)cur_ofs;
796 int* idst_idx = buf->data.i + root->buf_idx*buf->cols +
797 vi*sample_count + root->offset;
798 for( i = 0; i < num_valid; i++ )
803 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
804 idst_idx[j] = cur_ofs;
807 root->set_num_valid(vi, j);
809 for( ; i < sample_count; i++ )
814 for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
815 idst_idx[j] = cur_ofs;
820 // sample indices subsampling
821 const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
824 unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
825 get_work_var_count()*sample_count + root->offset);
826 for (i = 0; i < count; i++)
827 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
831 int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols +
832 get_work_var_count()*sample_count + root->offset;
833 for (i = 0; i < count; i++)
834 sample_idx_dst[i] = sample_idx_src[sidx[i]];
840 cvReleaseMat( &isubsample_idx );
841 cvReleaseMat( &subsample_co );
// Exports (a subset of) the training data as dense arrays.
//  values    : count x var_count floats, one row per sample (dst += var_count);
//  missing   : optional count x var_count mask, pre-filled with 1 and cleared
//              per entry (categorical: 1 when the stored code is the missing
//              marker -- negative for 32-bit, 65535 for 16-bit buffers);
//  responses : class index (get_class_idx) or original class label mapped back
//              through cat_map for classifiers, ordered responses otherwise
//              (presumably optional; its null check is on a line not shown).
// _subsample_idx selects (possibly repeated) samples. For every original sample,
// `co` stores how many times it was selected and the offset of its first output
// row (cur_ofs*var_count).
// NOTE(review): several short lines (null checks, loop bodies) are not shown.
847 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
848 float* values, uchar* missing,
849 float* responses, bool get_class_idx )
851 CvMat* subsample_idx = 0;
852 CvMat* subsample_co = 0;
854 CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
858 int i, vi, total = sample_count, count = total, cur_ofs = 0;
862 cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
865 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
866 sidx = subsample_idx->data.i;
867 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
868 co = subsample_co->data.i;
869 cvZero( subsample_co );
870 count = subsample_idx->cols + subsample_idx->rows - 1;
871 for( i = 0; i < count; i++ )
873 for( i = 0; i < total; i++ )
875 int count_i = co[i*2];
878 co[i*2+1] = cur_ofs*var_count;
885 memset( missing, 1, count*var_count );
887 for( vi = 0; vi < var_count; vi++ )
889 int ci = get_var_type(vi);
890 if( ci >= 0 ) // categorical
892 float* dst = values + vi;
893 uchar* m = missing ? missing + vi : 0;
894 const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);
896 for( i = 0; i < count; i++, dst += var_count )
898 int idx = sidx ? sidx[i] : i;
903 *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
// ordered: walk the count1 valid entries in sorted order and scatter each value
// to every output row of the corresponding sample
910 float* dst = values + vi;
911 uchar* m = missing ? missing + vi : 0;
912 int count1 = data_root->get_num_valid(vi);
913 float *src_val_buf = (float*)(uchar*)inn_buf;
914 int* src_idx_buf = (int*)(src_val_buf + sample_count);
915 int* sample_indices_buf = src_idx_buf + sample_count;
916 const float *src_val = 0;
917 const int* src_idx = 0;
918 get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
920 for( i = 0; i < count1; i++ )
922 int idx = src_idx[i];
927 cur_ofs = co[idx*2+1];
930 cur_ofs = idx*var_count;
933 float val = src_val[i];
934 for( ; count_i > 0; count_i--, cur_ofs += var_count )
// classifier responses
950 const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
951 for( i = 0; i < count; i++ )
953 int idx = sidx ? sidx[i] : i;
954 int val = get_class_idx ? src[idx] :
955 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
956 responses[i] = (float)val;
// regression responses
961 float* val_buf = (float*)(uchar*)inn_buf;
962 int* sample_idx_buf = (int*)(val_buf + sample_count);
963 const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
964 for( i = 0; i < count; i++ )
966 int idx = sidx ? sidx[i] : i;
967 responses[i] = _values[idx];
974 cvReleaseMat( &subsample_idx );
975 cvReleaseMat( &subsample_co );
// Allocates a tree node from node_heap and initializes it.
//  parent      : parent node or null (depth = parent depth + 1, else 0);
//  count       : number of samples in the node;
//  storage_idx : row of `buf` that holds the node's data (stored as buf_idx);
//  offset      : start of the node's entries inside each buffer slice.
// num_valid comes from nv_heap (the condition guarding it is on a line not shown).
// When cross-validation is enabled and cv_heap exists, one cv_heap element is
// split into cv_Tn (cv_n ints), then cv_node_risk (cv_n doubles, aligned to
// double), then cv_node_error (cv_n doubles); otherwise these pointers are null.
// NOTE(review): lines 988-992, 997-998 and the return are not shown here.
979 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
980 int storage_idx, int offset )
982 CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
984 node->sample_count = count;
985 node->depth = parent ? parent->depth + 1 : 0;
986 node->parent = parent;
987 node->left = node->right = 0;
993 node->buf_idx = storage_idx;
994 node->offset = offset;
996 node->num_valid = (int*)cvSetNew( nv_heap );
999 node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
1000 node->complexity = 0;
1002 if( params.cv_folds > 0 && cv_heap )
1004 int cv_n = params.cv_folds;
1006 node->cv_Tn = (int*)cvSetNew( cv_heap );
1007 node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
1008 node->cv_node_error = node->cv_node_risk + cv_n;
1014 node->cv_node_risk = 0;
1015 node->cv_node_error = 0;
// Allocates a split on ordered variable vi from split_heap: threshold cmp_val,
// split position split_point within the sorted order, inversion flag and quality.
// condensed_idx is set to the sentinel INT_MIN (its meaning is defined elsewhere).
// NOTE(review): the remaining initialization and the return are on lines not shown.
1022 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
1023 int split_point, int inversed, float quality )
1025 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1026 split->var_idx = vi;
1027 split->condensed_idx = INT_MIN;
1028 split->ord.c = cmp_val;
1029 split->ord.split_point = split_point;
1030 split->inversed = inversed;
1031 split->quality = quality;
// Allocates a split on categorical variable vi from split_heap and clears its
// category subset: (max_c_count + 31)/32 ints, i.e. one bit per category.
// NOTE(review): the remaining initialization and the return are on lines not shown.
1038 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
1040 CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1041 int i, n = (max_c_count + 31)/32;
1043 split->var_idx = vi;
1044 split->condensed_idx = INT_MIN;
1045 split->inversed = 0;
1046 split->quality = quality;
1047 for( i = 0; i < n; i++ )
1048 split->subset[i] = 0;
// Returns a node and its chain of splits (linked through split->next) to their
// heaps; per-node data is released first via free_node_data().
// NOTE(review): the loop header walking the split chain is on a line not shown.
1055 void CvDTreeTrainData::free_node( CvDTreeNode* node )
1057 CvDTreeSplit* split = node->split;
1058 free_node_data( node );
1061 CvDTreeSplit* next = split->next;
1062 cvSetRemoveByPtr( split_heap, split );
1066 cvSetRemoveByPtr( node_heap, node );
// Releases a node's num_valid array back to nv_heap and clears the pointer.
1070 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
1072 if( node->num_valid )
1074 cvSetRemoveByPtr( nv_heap, node->num_valid );
1075 node->num_valid = 0;
1077 // do not free cv_* fields, as all the cross-validation related data is released at once.
// Releases the data needed only during training: class counts, the work
// buffer, direction/split scratch arrays, temp_storage (which backs nv_heap and
// cv_heap, hence the two heap pointers are nulled) and the response copy.
1081 void CvDTreeTrainData::free_train_data()
1083 cvReleaseMat( &counts );
1084 cvReleaseMat( &buf );
1085 cvReleaseMat( &direction );
1086 cvReleaseMat( &split_buf );
1087 cvReleaseMemStorage( &temp_storage );
1088 cvReleaseMat( &responses_copy );
1089 cv_heap = nv_heap = 0;
// Resets the object: releases tree_storage (backing node_heap and split_heap),
// the variable/category description matrices and the priors, then zeroes all
// counters and flags.
// NOTE(review): lines 1094-1096, 1098, 1106, 1108 and 1111 are not shown
// (possibly further cleanup such as a free_train_data() call -- confirm).
1093 void CvDTreeTrainData::clear()
1097 cvReleaseMemStorage( &tree_storage );
1099 cvReleaseMat( &var_idx );
1100 cvReleaseMat( &var_type );
1101 cvReleaseMat( &cat_count );
1102 cvReleaseMat( &cat_ofs );
1103 cvReleaseMat( &cat_map );
1104 cvReleaseMat( &priors );
1105 cvReleaseMat( &priors_mult );
1107 node_heap = split_heap = 0;
1109 sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
1110 have_labels = have_priors = is_classifier = false;
1112 buf_count = buf_size = 0;
// Number of response classes: the category count stored for the response
// slot (index cat_var_count) for classifiers, 0 for regression.
1121 int CvDTreeTrainData::get_num_classes() const
1123 return is_classifier ? cat_count->data.i[cat_var_count] : 0;
// Type code of work variable vi as built by set_data(): >= 0 is the index of a
// categorical variable, < 0 marks an ordered variable. Entries var_count and
// var_count+1 locate the responses and the cross-validation labels.
1127 int CvDTreeTrainData::get_var_type(int vi) const
1129 return var_type->data.i[vi];
// For ordered variable vi restricted to node n, returns:
//  *sorted_indices : the node's sample positions sorted by value. This points
//                    directly into a 32-bit buf; for a 16-bit buf it is a
//                    widened copy in sorted_indices_buf;
//  *ord_values     : the matching values, written to ord_values_buf.
// Positions are mapped to train_data rows/columns through get_sample_indices().
// Filling stops at the first missing entry (-1 for 32-bit, 65535 for 16-bit
// buffers), so only the leading valid entries of ord_values_buf are written.
// Buffers need room for n->sample_count entries.
// NOTE(review): values are read via train_data->data.fl, which assumes a
// CV_32FC1 train_data matrix -- confirm cvCheckTrainData enforces that.
1132 void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
1133 const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
1135 int vidx = var_idx ? var_idx->data.i[vi] : vi;
1136 int node_sample_count = n->sample_count;
1137 int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
1139 const int* sample_indices = get_sample_indices(n, sample_indices_buf);
1142 *sorted_indices = buf->data.i + n->buf_idx*buf->cols +
1143 vi*sample_count + n->offset;
1145 const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
1146 vi*sample_count + n->offset );
1147 for( int i = 0; i < node_sample_count; i++ )
1148 sorted_indices_buf[i] = short_indices[i];
1149 *sorted_indices = sorted_indices_buf;
// samples stored as rows: variable vidx is a column
1152 if( tflag == CV_ROW_SAMPLE )
1154 for( int i = 0; i < node_sample_count &&
1155 ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1157 int idx = (*sorted_indices)[i];
1158 idx = sample_indices[idx];
1159 ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
// samples stored as columns: variable vidx is a row
1163 for( int i = 0; i < node_sample_count &&
1164 ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1166 int idx = (*sorted_indices)[i];
1167 idx = sample_indices[idx];
1168 ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
1171 *ord_values = ord_values_buf;
// Class-label slice of node n (work variable var_count) as ints; see
// get_cat_var_data() for buffer semantics.
// NOTE(review): line 1177 (possibly a guard for non-classifiers) is not shown.
1175 const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
1178 return get_cat_var_data( n, var_count, labels_buf);
// Original sample indices of node n's positions (the slice right after the
// last work variable), returned as ints.
1182 const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
1184 return get_cat_var_data( n, get_work_var_count(), indices_buf );
// Returns the ordered (regression) responses of node n's samples, in the node's
// sample order, written into values_buf (room for n->sample_count floats).
// Sample rows come from get_sample_indices(). The loop stops at the first entry
// equal to the missing marker (-1 for 32-bit, 65535 for 16-bit buffers).
// set_data() accepts both CV_32FC1 and CV_32SC1 response vectors, so each
// element is read according to the matrix type. Integer responses are
// converted to float rather than having their storage reinterpreted as float.
1187 const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf )
1189 int sample_count = n->sample_count;
1190 int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
1191 const int* indices = get_sample_indices(n, sample_indices_buf);
1193 for( int i = 0; i < sample_count &&
1194 (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
1196 int idx = indices[i];
1197 values_buf[i] = CV_MAT_TYPE(responses->type) == CV_32SC1 ? (float)responses->data.i[idx * r_step] : responses->data.fl[idx * r_step];
// Cross-validation fold labels of node n (the last work variable) as ints.
// NOTE(review): lines 1205-1206 (possibly a have_labels guard) are not shown.
1204 const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
1207 return get_cat_var_data( n, get_work_var_count()- 1, labels_buf);
// Returns slice vi of node n's buffer row (starting at n->offset) as ints.
// For a 32-bit buf this points directly into buf. For a 16-bit buf the
// n->sample_count values are widened into cat_values_buf; in that case the
// missing marker appears as 65535, not -1, so callers must check for both.
// NOTE(review): the branch condition and the return are on lines not shown.
1212 const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
1214 const int* cat_values = 0;
1216 cat_values = buf->data.i + n->buf_idx*buf->cols +
1217 vi*sample_count + n->offset;
1219 const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
1220 vi*sample_count + n->offset);
1221 for( int i = 0; i < n->sample_count; i++ )
1222 cat_values_buf[i] = short_values[i];
1223 cat_values = cat_values_buf;
// Buffer row to use for node n's children: the next row after n's, wrapping
// back to row 1 in shared mode or row 0 otherwise.
// NOTE(review): the return statement is on a line not shown.
1229 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
1231 int idx = n->buf_idx + 1;
1232 if( idx >= buf_count )
1233 idx = shared ? 1 : 0;
// Serializes the data description and training parameters to fs:
// variable counts; a "training_params" map (use_1se_rule and
// truncate_pruned_tree only when cv_folds > 1; priors presumably only when
// present, since the guard is not shown); var_idx; "var_type" as a flow
// sequence of 1 (categorical) / 0 (ordered) flags; and cat_count / cat_map
// when categorical data exists.
// NOTE(review): several short lines (e.g. guards before writing priors/var_idx)
// are not shown in this listing.
1238 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
1240 CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1244 int vi, vcount = var_count;
1246 cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1247 cvWriteInt( fs, "var_all", var_all );
1248 cvWriteInt( fs, "var_count", var_count );
1249 cvWriteInt( fs, "ord_var_count", ord_var_count );
1250 cvWriteInt( fs, "cat_var_count", cat_var_count );
1252 cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1253 cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1257 cvWriteInt( fs, "max_categories", params.max_categories );
1261 cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1264 cvWriteInt( fs, "max_depth", params.max_depth );
1265 cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1266 cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1268 if( params.cv_folds > 1 )
1270 cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1271 cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1275 cvWrite( fs, "priors", priors );
1277 cvEndWriteStruct( fs );
1280 cvWrite( fs, "var_idx", var_idx );
1282 cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
// categorical variables have non-negative type codes
1284 for( vi = 0; vi < vcount; vi++ )
1285 cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1287 cvEndWriteStruct( fs );
1289 if( cat_count && (cat_var_count > 0 || is_classifier) )
// NOTE(review): redundant -- the enclosing condition already requires cat_count != 0.
1291 CV_ASSERT( cat_count != 0 );
1292 cvWrite( fs, "cat_count", cat_count );
1293 cvWrite( fs, "cat_map", cat_map );
1300 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1302 CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1306 CvFileNode *tparams_node, *vartype_node;
1308 int vi, max_split_size, tree_block_size;
1310 is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1311 var_all = cvReadIntByName( fs, node, "var_all" );
1312 var_count = cvReadIntByName( fs, node, "var_count", var_all );
1313 cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1314 ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1316 tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1318 if( tparams_node ) // training parameters are not necessary
1320 params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1324 params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1328 params.regression_accuracy =
1329 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1332 params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1333 params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1334 params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1336 if( params.cv_folds > 1 )
1338 params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1339 params.truncate_pruned_tree =
1340 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1343 priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1346 if( !CV_IS_MAT(priors) )
1347 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1348 priors_mult = cvCloneMat( priors );
1352 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1355 if( !CV_IS_MAT(var_idx) ||
1356 (var_idx->cols != 1 && var_idx->rows != 1) ||
1357 var_idx->cols + var_idx->rows - 1 != var_count ||
1358 CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1359 CV_ERROR( CV_StsParseError,
1360 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1362 for( vi = 0; vi < var_count; vi++ )
1363 if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1364 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1367 ////// read var type
1368 CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1372 vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1374 if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1375 var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1378 if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1379 vartype_node->data.seq->total != var_count )
1380 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1382 cvStartReadSeq( vartype_node->data.seq, &reader );
1384 for( vi = 0; vi < var_count; vi++ )
1386 CvFileNode* n = (CvFileNode*)reader.ptr;
1387 if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1388 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1389 var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1390 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1393 var_type->data.i[var_count] = cat_var_count;
1395 ord_var_count = ~ord_var_count;
1396 if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
1397 CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
1400 if( cat_var_count > 0 || is_classifier )
1402 int ccount, total_c_count = 0;
1403 CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1404 CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1406 if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1407 (cat_count->cols != 1 && cat_count->rows != 1) ||
1408 CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1409 cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1410 (cat_map->cols != 1 && cat_map->rows != 1) ||
1411 CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1412 CV_ERROR( CV_StsParseError,
1413 "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1415 ccount = cat_var_count + is_classifier;
1417 CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1418 cat_ofs->data.i[0] = 0;
1421 for( vi = 0; vi < ccount; vi++ )
1423 int val = cat_count->data.i[vi];
1425 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1426 max_c_count = MAX( max_c_count, val );
1427 cat_ofs->data.i[vi+1] = total_c_count += val;
1430 if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1431 CV_ERROR( CV_StsBadSize,
1432 "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1435 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1436 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1438 tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1439 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1440 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1441 CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1442 sizeof(CvDTreeNode), tree_storage ));
1443 CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1444 max_split_size, tree_storage ));
1449 /////////////////////// Decision Tree /////////////////////////
// NOTE(review): presumably part of the CvDTree constructor, whose header line
// is missing from this listing.
1455 default_model_name = "my_tree";
// Releases the tree's resources and resets its state; the lines between the
// two visible statements are missing from this listing.
1461 void CvDTree::clear()
1463 cvReleaseMat( &var_importance );
// -1 marks "no pruned subtree selected"  NOTE(review): inferred from the name -- confirm
1473 pruned_tree_idx = -1;
// Accessors (bodies of get_root/get_data are missing from this listing).
1483 const CvDTreeNode* CvDTree::get_root() const
1489 int CvDTree::get_pruned_tree_idx() const
1491 return pruned_tree_idx;
1495 CvDTreeTrainData* CvDTree::get_data()
// Trains the tree on CvMat inputs: builds a new CvDTreeTrainData from the
// arguments and grows the tree with do_train(0) (no subsample index).
// Returns the result of do_train().
// NOTE(review): the trailing `false` constructor argument presumably means
// "not shared" (the CvDTreeTrainData* overload sets data->shared = true).
// Releasing any previous `data` presumably happens on an omitted line
// (e.g. a clear() call). Confirm.
1501 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1502 const CvMat* _responses, const CvMat* _var_idx,
1503 const CvMat* _sample_idx, const CvMat* _var_type,
1504 const CvMat* _missing_mask, CvDTreeParams _params )
1506 bool result = false;
1508 CV_FUNCNAME( "CvDTree::train" );
1513 data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1514 _var_idx, _sample_idx, _var_type,
1515 _missing_mask, _params, false );
1516 CV_CALL( result = do_train(0) );
1523 bool CvDTree::train( const Mat& _train_data, int _tflag,
1524 const Mat& _responses, const Mat& _var_idx,
1525 const Mat& _sample_idx, const Mat& _var_type,
1526 const Mat& _missing_mask, CvDTreeParams _params )
1528 CvMat tdata = _train_data, responses = _responses, vidx=_var_idx,
1529 sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;
1530 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
1531 vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
// Trains the tree from a CvMLData container: fetches its values, responses,
// missing-value mask, variable types, training-sample indices and variable
// indices, and forwards them to the CvMat* overload with CV_ROW_SAMPLE
// layout. Returns that overload's result.
1535 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
1537 bool result = false;
1539 CV_FUNCNAME( "CvDTree::train" );
1543 const CvMat* values = _data->get_values();
1544 const CvMat* response = _data->get_responses();
1545 const CvMat* missing = _data->get_missing();
1546 const CvMat* var_types = _data->get_var_types();
1547 const CvMat* train_sidx = _data->get_train_sample_idx();
1548 const CvMat* var_idx = _data->get_var_idx();
1550 CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1551 train_sidx, var_types, missing, _params ) );
// Trains the tree on an externally supplied CvDTreeTrainData: marks the data
// as shared and grows the tree on the optional _subsample_idx subset.
// NOTE(review): storing _data into `data` presumably happens on an omitted
// line; "shared" presumably means the tree must not free the data (e.g. when
// an ensemble reuses it). Confirm against the class.
1558 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1560 bool result = false;
1562 CV_FUNCNAME( "CvDTree::train" );
1568 data->shared = true;
1569 CV_CALL( result = do_train(_subsample_idx));
// Grows the tree: builds the root from the (optionally subsampled) training
// data, splits it recursively, runs prune_cv() when cv_folds > 0, and finally
// releases the training buffers with free_train_data().
1577 bool CvDTree::do_train( const CvMat* _subsample_idx )
1579 bool result = false;
1581 CV_FUNCNAME( "CvDTree::do_train" );
1585 root = data->subsample_data( _subsample_idx );
1587 CV_CALL( try_split_node(root));
1589 if( data->params.cv_folds > 0 )
1590 CV_CALL( prune_cv());
1593 data->free_train_data();
// Recursively splits `node`:
//   1. computes its value;
//   2. checks the stopping rules;
//   3. finds the best primary split;
//   4. optionally collects surrogate splits;
//   5. splits the node's data and recurses into both children.
// A node that cannot or need not be split becomes a leaf and its per-node
// training data is released.
// NOTE(review): several short statements (can_split = false assignments,
// `continue`, `return`, `if( split )`) are missing from this listing.
1603 void CvDTree::try_split_node( CvDTreeNode* node )
1605 CvDTreeSplit* best_split = 0;
1606 int i, n = node->sample_count, vi;
1607 bool can_split = true;
1608 double quality_scale;
1610 calc_node_value( node );
// stopping rules: too few samples, or maximum depth reached
1612 if( node->sample_count <= data->params.min_sample_count ||
1613 node->depth >= data->params.max_depth )
1616 if( can_split && data->is_classifier )
1618 // check if we have a "pure" node,
1619 // we assume that cls_count is filled by calc_node_value()
1620 int* cls_count = data->counts->data.i;
1621 int nz = 0, m = data->get_num_classes();
1622 for( i = 0; i < m; i++ )
1623 nz += cls_count[i] != 0;
1624 if( nz == 1 ) // there is only one class
1627 else if( can_split )
// regression: stop once sqrt(node_risk)/n drops below regression_accuracy
1629 if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1635 best_split = find_best_split(node);
1636 // TODO: check the split quality ...
1637 node->split = best_split;
// leaf: release the node's training data (presumably followed by an omitted return)
1639 if( !can_split || !best_split )
1641 data->free_node_data(node);
// record each sample's direction under the primary split (also sets
// node->maxlr); the returned factor rescales surrogate qualities
1645 quality_scale = calc_node_dir( node );
1646 if( data->params.use_surrogates )
1648 // find all the surrogate splits
1649 // and sort them by their similarity to the primary one
1650 for( vi = 0; vi < data->var_count; vi++ )
1652 CvDTreeSplit* split;
1653 int ci = data->get_var_type(vi);
// the primary split's own variable is skipped
1655 if( vi == best_split->var_idx )
// categorical (ci >= 0) vs ordered search; the selecting test is on omitted lines
1659 split = find_surrogate_split_cat( node, vi );
1661 split = find_surrogate_split_ord( node, vi );
// insert after the primary split so that surrogates stay sorted by
// decreasing (scaled) quality
1666 CvDTreeSplit* prev_split = node->split;
1667 split->quality = (float)(split->quality*quality_scale);
1669 while( prev_split->next &&
1670 prev_split->next->quality > split->quality )
1671 prev_split = prev_split->next;
1672 split->next = prev_split->next;
1673 prev_split->next = split;
// distribute the samples between the children and recurse into both
1677 split_node_data( node );
1678 try_split_node( node->left );
1679 try_split_node( node->right );
1683 // calculate direction (left(-1),right(1),missing(0))
1684 // for each sample using the best split
1685 // the function returns scale coefficients for surrogate split quality factors.
1686 // the scale is applied to normalize surrogate split quality relatively to the
1687 // best (primary) split quality. That is, if a surrogate split is absolutely
1688 // identical to the primary split, its quality will be set to the maximum value =
1689 // quality of the primary split; otherwise, it will be lower.
1690 // besides, the function compute node->maxlr,
1691 // minimum possible quality (w/o considering the above mentioned scale)
1692 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1693 // are not discarded.
1694 double CvDTree::calc_node_dir( CvDTreeNode* node )
1696 char* dir = (char*)data->direction->data.ptr;
1697 int i, n = node->sample_count, vi = node->split->var_idx;
1700 assert( !node->split->inversed );
1702 if( data->get_var_type(vi) >= 0 ) // split on categorical var
1704 cv::AutoBuffer<int> inn_buf(n*(data->have_priors ? 1 : 2));
1705 int* labels_buf = (int*)inn_buf;
1706 const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1707 const int* subset = node->split->subset;
1708 if( !data->have_priors )
1710 int sum = 0, sum_abs = 0;
1712 for( i = 0; i < n; i++ )
1714 int idx = labels[i];
1715 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1716 CV_DTREE_CAT_DIR(idx,subset) : 0;
1717 sum += d; sum_abs += d & 1;
1721 R = (sum_abs + sum) >> 1;
1722 L = (sum_abs - sum) >> 1;
1726 const double* priors = data->priors_mult->data.db;
1727 double sum = 0, sum_abs = 0;
1728 int* responses_buf = labels_buf + n;
1729 const int* responses = data->get_class_labels(node, responses_buf);
1731 for( i = 0; i < n; i++ )
1733 int idx = labels[i];
1734 double w = priors[responses[i]];
1735 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1736 sum += d*w; sum_abs += (d & 1)*w;
1740 R = (sum_abs + sum) * 0.5;
1741 L = (sum_abs - sum) * 0.5;
1744 else // split on ordered var
1746 int split_point = node->split->ord.split_point;
1747 int n1 = node->get_num_valid(vi);
1748 cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
1749 float* val_buf = (float*)(uchar*)inn_buf;
1750 int* sorted_buf = (int*)(val_buf + n);
1751 int* sample_idx_buf = sorted_buf + n;
1752 const float* val = 0;
1753 const int* sorted = 0;
1754 data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
1756 assert( 0 <= split_point && split_point < n1-1 );
1758 if( !data->have_priors )
1760 for( i = 0; i <= split_point; i++ )
1761 dir[sorted[i]] = (char)-1;
1762 for( ; i < n1; i++ )
1763 dir[sorted[i]] = (char)1;
1765 dir[sorted[i]] = (char)0;
1768 R = n1 - split_point + 1;
1772 const double* priors = data->priors_mult->data.db;
1773 int* responses_buf = sample_idx_buf + n;
1774 const int* responses = data->get_class_labels(node, responses_buf);
1777 for( i = 0; i <= split_point; i++ )
1779 int idx = sorted[i];
1780 double w = priors[responses[idx]];
1781 dir[idx] = (char)-1;
1785 for( ; i < n1; i++ )
1787 int idx = sorted[i];
1788 double w = priors[responses[idx]];
1794 dir[sorted[i]] = (char)0;
1797 node->maxlr = MAX( L, R );
1798 return node->split->quality/(L + R);
// DTreeBestSplitFinder: parallel_reduce body that searches the best primary
// split of one node over a range of variables.
//
// Constructor: allocates two scratch splits of split_heap->elem_size bytes.
// That is the split-heap element size, so the categorical subset bitmask fits.
// bestSplit starts with quality -1, so any split found beats it.
// NOTE(review): the buffers come from `new char[]`; if the members release
// them with `delete` (e.g. via a smart pointer to CvDTreeSplit) rather than
// `delete[]` on a char*, that is a new[]/delete mismatch. Check the class
// declaration.
1805 DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
1809 splitSize = tree->get_data()->split_heap->elem_size;
1811 bestSplit = (CvDTreeSplit*)(new char[splitSize]);
1812 memset((CvDTreeSplit*)bestSplit, 0, splitSize);
1813 bestSplit->quality = -1;
1814 bestSplit->condensed_idx = INT_MIN;
1815 split = (CvDTreeSplit*)(new char[splitSize]);
1816 memset((CvDTreeSplit*)split, 0, splitSize);
1817 //haveSplit = false;
// Splitting constructor used by parallel_reduce: starts from a copy of the
// current best split of `finder` (tree/node are presumably copied from
// `finder` on omitted lines).
1820 DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
1824 splitSize = tree->get_data()->split_heap->elem_size;
1826 bestSplit = (CvDTreeSplit*)(new char[splitSize]);
1827 memcpy((CvDTreeSplit*)(bestSplit), (const CvDTreeSplit*)finder.bestSplit, splitSize);
1828 split = (CvDTreeSplit*)(new char[splitSize]);
1829 memset((CvDTreeSplit*)split, 0, splitSize);
// Scans the variables [range.begin(), range.end()) and keeps the best split.
// inn_buf is scratch space shared by the find_split_* calls:
// 2*n*(sizeof(int)+sizeof(float)) bytes.
1832 void DTreeBestSplitFinder::operator()(const BlockedRange& range)
1834 int vi, vi1 = range.begin(), vi2 = range.end();
1835 int n = node->sample_count;
1836 CvDTreeTrainData* data = tree->get_data();
1837 AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
1839 for( vi = vi1; vi < vi2; vi++ )
1842 int ci = data->get_var_type(vi);
// a variable with at most one valid value cannot split the node
1843 if( node->get_num_valid(vi) <= 1 )
// dispatch on task type and on variable kind (categorical when ci >= 0);
// the current best quality is the threshold a new split must beat
1846 if( data->is_classifier )
1849 res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1851 res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1856 res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1858 res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1861 if( res && bestSplit->quality < split->quality )
1862 memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );
// Reduction step: keep whichever partial result has the higher quality.
1866 void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
1868 if( bestSplit->quality < rhs.bestSplit->quality )
1869 memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)rhs.bestSplit, splitSize );
1874 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1876 DTreeBestSplitFinder finder( this, node );
1878 cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
1880 CvDTreeSplit *bestSplit = data->new_split_cat( 0, -1.0f );
1881 memcpy( bestSplit, finder.bestSplit, finder.splitSize );
// Searches the best split of `node` on ordered variable `vi` for a
// classification tree.
//   - Candidate thresholds lie between consecutive sorted values that differ
//     by more than epsilon.
//   - Each candidate is scored by lsum2/L + rsum2/R
//     (= (lsum2*R + rsum2*L)/(L*R)). lsum2/rsum2 are the sums of squared class
//     counts on each side, prior-weighted when priors are used; maximizing
//     this minimizes the weighted Gini impurity.
//   - A split is produced only when its quality exceeds init_quality. It is
//     written into _split when given, otherwise allocated from the split heap.
//     (The guard and the null return are on omitted lines.)
//   - _ext_buf, when non-null, must hold n*(3*sizeof(int)+sizeof(float))
//     bytes: values, sorted indices, sample indices and class labels.
// NOTE(review): several declarations/updates (rc, L, R, best_i, lv, rv,
// the `if( !_ext_buf )` before allocate) are on omitted lines.
1886 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
1887 float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
1889 const float epsilon = FLT_EPSILON*2;
1890 int n = node->sample_count;
1891 int n1 = node->get_num_valid(vi);
1892 int m = data->get_num_classes();
1894 int base_size = 2*m*sizeof(int);
1895 cv::AutoBuffer<uchar> inn_buf(base_size);
1897 inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
1898 uchar* base_buf = (uchar*)inn_buf;
1899 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1900 float* values_buf = (float*)ext_buf;
1901 int* sorted_indices_buf = (int*)(values_buf + n);
1902 int* sample_indices_buf = sorted_indices_buf + n;
1903 const float* values = 0;
1904 const int* sorted_indices = 0;
1905 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
1906 &sorted_indices, sample_indices_buf );
1907 int* responses_buf = sample_indices_buf + n;
1908 const int* responses = data->get_class_labels( node, responses_buf );
// per-class counts of the whole node (presumably filled by calc_node_value)
1910 const int* rc0 = data->counts->data.i;
1911 int* lc = (int*)base_buf;
1914 double lsum2 = 0, rsum2 = 0, best_val = init_quality;
1915 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1917 // init arrays of class instance counters on both sides of the split
1918 for( i = 0; i < m; i++ )
1924 // compensate for missing values
1925 for( i = n1; i < n; i++ )
1927 rc[responses[sorted_indices[i]]]--;
// unweighted scan: move samples one by one from right to left
1934 for( i = 0; i < m; i++ )
1935 rsum2 += (double)rc[i]*rc[i];
1937 for( i = 0; i < n1 - 1; i++ )
1939 int idx = responses[sorted_indices[i]];
1942 lv = lc[idx]; rv = rc[idx];
1945 lc[idx] = lv + 1; rc[idx] = rv - 1;
// only split between distinct values
1947 if( values[i] + epsilon < values[i+1] )
1949 double val = (lsum2*R + rsum2*L)/((double)L*R);
1950 if( best_val < val )
// prior-weighted variant of the same scan
1960 double L = 0, R = 0;
1961 for( i = 0; i < m; i++ )
1963 double wv = rc[i]*priors[i];
1968 for( i = 0; i < n1 - 1; i++ )
1970 int idx = responses[sorted_indices[i]];
1972 double p = priors[idx], p2 = p*p;
1974 lv = lc[idx]; rv = rc[idx];
// incremental update: (lv+1)^2 - lv^2 = 2*lv+1 and rv^2 - (rv-1)^2 = 2*rv-1
1975 lsum2 += p2*(lv*2 + 1);
1976 rsum2 -= p2*(rv*2 - 1);
1977 lc[idx] = lv + 1; rc[idx] = rv - 1;
1979 if( values[i] + epsilon < values[i+1] )
1981 double val = (lsum2*R + rsum2*L)/((double)L*R);
1982 if( best_val < val )
1991 CvDTreeSplit* split = 0;
1994 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1995 split->var_idx = vi;
// threshold halfway between the two neighbouring distinct values
1996 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
1997 split->ord.split_point = best_i;
1998 split->inversed = 0;
1999 split->quality = (float)best_val;
// Clusters the n integer vectors `vectors` (n x m, row-major) into k clusters
// with a k-means-like procedure. As called from find_split_cat_class, each
// vector holds the per-class counts of one category.
//   - Initial labels: i for the first k vectors, random for the rest; all
//     labels are then randomly shuffled.
//   - Each iteration recomputes the cluster sums `csums` (k x m) and then
//     reassigns every vector to the cluster with the nearest normalized sum.
//   - Vectors and sums are scaled by 1/(element sum), so distances compare
//     distributions rather than raw counts.
//   - Stops after max_iters iterations or when no label changed.
// Output: labels[i] in [0, k).
// NOTE(review): some accumulations (sum, dist2 += t*t, modified) are on
// omitted lines. The weights are allocated on the stack with cvStackAlloc.
2005 void CvDTree::cluster_categories( const int* vectors, int n, int m,
2006 int* csums, int k, int* labels )
2008 // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2009 int iters = 0, max_iters = 100;
2011 double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
2012 double *v_weights = buf, *c_weights = buf + n;
2013 bool modified = true;
2014 CvRNG* r = &data->rng;
2016 // assign labels randomly
2017 for( i = 0; i < n; i++ )
2020 const int* v = vectors + i*m;
2021 labels[i] = i < k ? i : (cvRandInt(r) % k);
2023 // compute weight of each vector
2024 for( j = 0; j < m; j++ )
2026 v_weights[i] = sum ? 1./sum : 0.;
// shuffle the initial labels
2029 for( i = 0; i < n; i++ )
2031 int i1 = cvRandInt(r) % n;
2032 int i2 = cvRandInt(r) % n;
2033 CV_SWAP( labels[i1], labels[i2], j );
2036 for( iters = 0; iters <= max_iters; iters++ )
// recompute the per-cluster sums from the current labels
2039 for( i = 0; i < k; i++ )
2041 for( j = 0; j < m; j++ )
2045 for( i = 0; i < n; i++ )
2047 const int* v = vectors + i*m;
2048 int* s = csums + labels[i]*m;
2049 for( j = 0; j < m; j++ )
2053 // exit the loop here, when we have up-to-date csums
2054 if( iters == max_iters || !modified )
2059 // calculate weight of each cluster
2060 for( i = 0; i < k; i++ )
2062 const int* s = csums + i*m;
2064 for( j = 0; j < m; j++ )
2066 c_weights[i] = sum ? 1./sum : 0;
2069 // now for each vector determine the closest cluster
2070 for( i = 0; i < n; i++ )
2072 const int* v = vectors + i*m;
2073 double alpha = v_weights[i];
2074 double min_dist2 = DBL_MAX;
2077 for( idx = 0; idx < k; idx++ )
2079 const int* s = csums + idx*m;
2080 double dist2 = 0., beta = c_weights[idx];
2081 for( j = 0; j < m; j++ )
// difference of the normalized vector and normalized cluster sum
2083 double t = v[j]*alpha - s[j]*beta;
2086 if( min_dist2 > dist2 )
2093 if( min_idx != labels[i] )
2095 labels[i] = min_idx;
// Searches the best split of `node` on categorical variable `vi` for a
// classification tree.
//   - Builds the counts c_{jk} (category j x class k); j = -1 collects
//     samples whose value is missing.
//   - When the variable has more than max_categories categories, they are
//     first grouped into MIN(max_categories, n) clusters by
//     cluster_categories(). The extra buffer is only reserved when m > 2.
//   - In the two-class path the categories are sorted (icvSortIntPtr over
//     cjk[j*2+1]) and only prefixes of the sorted order are scored.
//   - Otherwise every subset is enumerated in Gray-code order, so each step
//     moves exactly one category between the sides.
//   - Candidates are scored by the prior-weighted (lsum2*R + rsum2*L)/(L*R)
//     criterion; the split must beat init_quality.
// NOTE(review): the branch conditions (presumably m == 2 / m > 2), several
// declarations (rc, int_ptr, u, weight, t, i/j/k/idx) and the `continue`s
// are on omitted lines. priors_mult is dereferenced unconditionally here,
// unlike find_split_ord_class -- presumably always allocated for
// classifiers; confirm.
2101 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
2102 CvDTreeSplit* _split, uchar* _ext_buf )
2104 int ci = data->get_var_type(vi);
2105 int n = node->sample_count;
2106 int m = data->get_num_classes();
2107 int _mi = data->cat_count->data.i[ci], mi = _mi;
2109 int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
2110 if( m > 2 && mi > data->params.max_categories )
2111 base_size += (m*min(data->params.max_categories, n) + mi)*sizeof(int);
2113 base_size += mi*sizeof(int*);
2114 cv::AutoBuffer<uchar> inn_buf(base_size);
2116 inn_buf.allocate(base_size + 2*n*sizeof(int));
2117 uchar* base_buf = (uchar*)inn_buf;
2118 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2120 int* lc = (int*)base_buf;
2122 int* _cjk = rc + m*2, *cjk = _cjk;
2123 double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
2125 int* labels_buf = (int*)ext_buf;
2126 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2127 int* responses_buf = labels_buf + n;
2128 const int* responses = data->get_class_labels(node, responses_buf);
2130 int* cluster_labels = 0;
2133 double L = 0, R = 0;
2134 double best_val = init_quality;
2135 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
2136 const double* priors = data->priors_mult->data.db;
2138 // init array of counters:
2139 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
2140 for( j = -1; j < mi; j++ )
2141 for( k = 0; k < m; k++ )
2144 for( i = 0; i < n; i++ )
// 65535 marks a missing value in 16-bit buffers
2146 j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
// too many categories: cluster them so that subset enumeration stays tractable
2153 if( mi > data->params.max_categories )
2155 mi = MIN(data->params.max_categories, n);
2156 cjk = (int*)(c_weights + _mi);
2157 cluster_labels = cjk + m*mi;
2158 cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
// two-class path: order categories by cjk[j*2+1] and scan prefixes only
2166 int_ptr = (int**)(c_weights + _mi);
2167 for( j = 0; j < mi; j++ )
2168 int_ptr[j] = cjk + j*2 + 1;
2169 icvSortIntPtr( int_ptr, mi, 0 );
2174 for( k = 0; k < m; k++ )
2177 for( j = 0; j < mi; j++ )
2178 sum += cjk[j*m + k];
// per-category weights (prior-weighted sample counts)
2183 for( j = 0; j < mi; j++ )
2186 for( k = 0; k < m; k++ )
2187 sum += cjk[j*m + k]*priors[k];
2192 for( ; subset_i < subset_n; subset_i++ )
2196 double lsum2 = 0, rsum2 = 0;
2199 idx = (int)(int_ptr[subset_i] - cjk)/2;
// Gray code: consecutive codes differ in exactly one bit (one category)
2202 int graycode = (subset_i>>1)^subset_i;
2203 int diff = graycode ^ prevcode;
2205 // determine index of the changed bit.
// diff has a single set bit; its index is read from the exponent of the
// bit's value converted to float (16-bit halves handled separately)
2207 idx = diff >= (1 << 16) ? 16 : 0;
2208 u.f = (float)(((diff >> 16) | diff) & 65535);
2209 idx += (u.i >> 23) - 127;
2210 subtract = graycode < prevcode;
2211 prevcode = graycode;
2215 weight = c_weights[idx];
// empty category: moving it changes nothing (presumably skipped)
2216 if( weight < FLT_EPSILON )
// move category idx from the right to the left side
2221 for( k = 0; k < m; k++ )
2224 int lval = lc[k] + t;
2225 int rval = rc[k] - t;
2226 double p = priors[k], p2 = p*p;
2227 lsum2 += p2*lval*lval;
2228 rsum2 += p2*rval*rval;
2229 lc[k] = lval; rc[k] = rval;
// move category idx from the left back to the right side
2236 for( k = 0; k < m; k++ )
2239 int lval = lc[k] - t;
2240 int rval = rc[k] + t;
2241 double p = priors[k], p2 = p*p;
2242 lsum2 += p2*lval*lval;
2243 rsum2 += p2*rval*rval;
2244 lc[k] = lval; rc[k] = rval;
// both sides must be non-empty
2250 if( L > FLT_EPSILON && R > FLT_EPSILON )
2252 double val = (lsum2*R + rsum2*L)/((double)L*R);
2253 if( best_val < val )
2256 best_subset = subset_i;
2261 CvDTreeSplit* split = 0;
2262 if( best_subset >= 0 )
2264 split = _split ? _split : data->new_split_cat( 0, -1.0f );
2265 split->var_idx = vi;
2266 split->quality = (float)best_val;
2267 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// two-class path: the left side is the best prefix of the sorted order
2270 for( i = 0; i <= best_subset; i++ )
2272 idx = (int)(int_ptr[i] - cjk) >> 1;
2273 split->subset[idx >> 5] |= 1 << (idx & 31);
// subset path: map (possibly clustered) categories back to original ones
2278 for( i = 0; i < _mi; i++ )
2280 idx = cluster_labels ? cluster_labels[i] : i;
2281 if( best_subset & (1 << idx) )
2282 split->subset[i >> 5] |= 1 << (i & 31);
// Searches the best split of `node` on ordered variable `vi` for a
// regression tree.
//   - rsum starts as node->value*n minus the responses of samples whose value
//     is missing. node->value presumably holds the mean response; confirm.
//   - Samples are then moved from right to left in sorted order.
//   - Each threshold between distinct values is scored by lsum^2/L + rsum^2/R
//     (= (lsum*lsum*R + rsum*rsum*L)/(L*R)). For fixed data this is maximal
//     when the within-side squared error is minimal.
//   - The split must beat init_quality; it is written into _split when given,
//     otherwise allocated from the split heap.
// NOTE(review): the `if( !_ext_buf )` guard, L/R/best_i declarations and the
// per-sample updates are on omitted lines.
2290 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2292 const float epsilon = FLT_EPSILON*2;
2293 int n = node->sample_count;
2294 int n1 = node->get_num_valid(vi);
2296 cv::AutoBuffer<uchar> inn_buf;
2298 inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
2299 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2300 float* values_buf = (float*)ext_buf;
2301 int* sorted_indices_buf = (int*)(values_buf + n);
2302 int* sample_indices_buf = sorted_indices_buf + n;
2303 const float* values = 0;
2304 const int* sorted_indices = 0;
2305 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2306 float* responses_buf = (float*)(sample_indices_buf + n);
2307 const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
2310 double best_val = init_quality, lsum = 0, rsum = node->value*n;
2313 // compensate for missing values
2314 for( i = n1; i < n; i++ )
2315 rsum -= responses[sorted_indices[i]];
2317 // find the optimal split
2318 for( i = 0; i < n1 - 1; i++ )
2320 float t = responses[sorted_indices[i]];
// only split between distinct values
2325 if( values[i] + epsilon < values[i+1] )
2327 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2328 if( best_val < val )
2336 CvDTreeSplit* split = 0;
2339 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2340 split->var_idx = vi;
// threshold halfway between the two neighbouring distinct values
2341 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2342 split->ord.split_point = best_i;
2343 split->inversed = 0;
2344 split->quality = (float)best_val;
// Searches the best split of `node` on categorical variable `vi` for a
// regression tree.
//   - The categories are sorted by their mean response.
//   - Only the mi-1 splits into a prefix/suffix of that order are scored,
//     using the lsum^2/L + rsum^2/R criterion.
//   - Samples whose category is missing (negative, or 65535 in 16-bit
//     buffers) are accumulated into sum[-1]/counts[-1].
//   - The split must beat init_quality; it is written into _split when given,
//     otherwise allocated from the split heap.
// NOTE(review): sum_ptr = (double**)(counts + mi) lies 12*(mi+1) bytes past
// the aligned base, so it is not 8-byte aligned when mi is even. Misaligned
// pointer stores may be slow or fault on strict-alignment 64-bit targets;
// confirm.
2349 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2351 int ci = data->get_var_type(vi);
2352 int n = node->sample_count;
2353 int mi = data->cat_count->data.i[ci];
2355 int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
2356 cv::AutoBuffer<uchar> inn_buf(base_size);
2358 inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
2359 uchar* base_buf = (uchar*)inn_buf;
2360 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2361 int* labels_buf = (int*)ext_buf;
2362 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2363 float* responses_buf = (float*)(labels_buf + n);
2364 int* sample_indices_buf = (int*)(responses_buf + n);
2365 const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
// sum[-1..mi-1] and counts[-1..mi-1]: index -1 collects missing values
2367 double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2368 int* counts = (int*)(sum + mi) + 1;
2369 double** sum_ptr = (double**)(counts + mi);
2370 int i, L = 0, R = 0;
2371 double best_val = init_quality, lsum = 0, rsum = 0;
2372 int best_subset = -1, subset_i;
2374 for( i = -1; i < mi; i++ )
2375 sum[i] = counts[i] = 0;
2377 // calculate sum response and weight of each category of the input var
2378 for( i = 0; i < n; i++ )
2380 int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
2381 double s = sum[idx] + responses[i];
2382 int nc = counts[idx] + 1;
2387 // calculate average response in each category
2388 for( i = 0; i < mi; i++ )
2392 sum[i] /= MAX(counts[i],1);
2393 sum_ptr[i] = sum + i;
2396 icvSortDblPtr( sum_ptr, mi, 0 );
2398 // revert back to unnormalized sums
2399 // (there should be a very little loss of accuracy)
2400 for( i = 0; i < mi; i++ )
2401 sum[i] *= counts[i];
// move categories one at a time (in mean-response order) to the left
2403 for( subset_i = 0; subset_i < mi-1; subset_i++ )
2405 int idx = (int)(sum_ptr[subset_i] - sum);
2406 int ni = counts[idx];
2410 double s = sum[idx];
2416 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2417 if( best_val < val )
2420 best_subset = subset_i;
2426 CvDTreeSplit* split = 0;
2427 if( best_subset >= 0 )
2429 split = _split ? _split : data->new_split_cat( 0, -1.0f);
2430 split->var_idx = vi;
2431 split->quality = (float)best_val;
2432 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// left side = the first best_subset+1 categories of the sorted order
2433 for( i = 0; i <= best_subset; i++ )
2435 int idx = (int)(sum_ptr[i] - sum);
2436 split->subset[idx >> 5] |= 1 << (idx & 31);
// Searches the best surrogate split on ordered variable `vi` for a node that
// already has a primary split (per-sample directions in data->direction, as
// filled by calc_node_dir).
//   - The chosen split is the one whose left/right assignment agrees with the
//     primary split on the largest (prior-weighted) number of samples, either
//     directly (LL + RR) or inverted (RL + LR).
//   - Only thresholds between distinct values are tried.
//   - Returns a new split from the split heap when that agreement exceeds
//     node->maxlr, otherwise 0.
// NOTE(review): the `if( !_ext_buf )` guard, the `else` of the priors test
// and the per-sample LL/LR/RL/RR updates are on omitted lines.
2442 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
2444 const float epsilon = FLT_EPSILON*2;
2445 const char* dir = (char*)data->direction->data.ptr;
2446 int n = node->sample_count, n1 = node->get_num_valid(vi);
2447 cv::AutoBuffer<uchar> inn_buf;
2449 inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
2450 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2451 float* values_buf = (float*)ext_buf;
2452 int* sorted_indices_buf = (int*)(values_buf + n);
2453 int* sample_indices_buf = sorted_indices_buf + n;
2454 const float* values = 0;
2455 const int* sorted_indices = 0;
2456 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2457 // LL - number of samples that both the primary and the surrogate splits send to the left
2458 // LR - ... primary split sends to the left and the surrogate split sends to the right
2459 // RL - ... primary split sends to the right and the surrogate split sends to the left
2460 // RR - ... both send to the right
2461 int i, best_i = -1, best_inversed = 0;
2464 if( !data->have_priors )
2466 int LL = 0, RL = 0, LR, RR;
// a surrogate must do better than node->maxlr
2467 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2468 int sum = 0, sum_abs = 0;
2470 for( i = 0; i < n1; i++ )
2472 int d = dir[sorted_indices[i]];
2473 sum += d; sum_abs += d & 1;
2476 // sum_abs = R + L; sum = R - L
2477 RR = (sum_abs + sum) >> 1;
2478 LR = (sum_abs - sum) >> 1;
2480 // initially all the samples are sent to the right by the surrogate split,
2481 // LR of them are sent to the left by primary split, and RR - to the right.
2482 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2483 for( i = 0; i < n1 - 1; i++ )
2485 int d = dir[sorted_indices[i]];
2490 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
2493 best_i = i; best_inversed = 0;
2499 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
2502 best_i = i; best_inversed = 1;
2506 best_val = _best_val;
// prior-weighted variant of the same search
2510 double LL = 0, RL = 0, LR, RR;
2511 double worst_val = node->maxlr;
2512 double sum = 0, sum_abs = 0;
2513 const double* priors = data->priors_mult->data.db;
2514 int* responses_buf = sample_indices_buf + n;
2515 const int* responses = data->get_class_labels(node, responses_buf);
2516 best_val = worst_val;
2518 for( i = 0; i < n1; i++ )
2520 int idx = sorted_indices[i];
2521 double w = priors[responses[idx]];
2523 sum += d*w; sum_abs += (d & 1)*w;
2526 // sum_abs = R + L; sum = R - L
2527 RR = (sum_abs + sum)*0.5;
2528 LR = (sum_abs - sum)*0.5;
2530 // initially all the samples are sent to the right by the surrogate split,
2531 // LR of them are sent to the left by primary split, and RR - to the right.
2532 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2533 for( i = 0; i < n1 - 1; i++ )
2535 int idx = sorted_indices[i];
2536 double w = priors[responses[idx]];
2542 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
2545 best_i = i; best_inversed = 0;
2551 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
2554 best_i = i; best_inversed = 1;
2559 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2560 (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
2564 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
2566 const char* dir = (char*)data->direction->data.ptr;
2567 int n = node->sample_count;
2568 int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2570 int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
2571 cv::AutoBuffer<uchar> inn_buf(base_size);
2573 inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
2574 uchar* base_buf = (uchar*)inn_buf;
2575 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2577 int* labels_buf = (int*)ext_buf;
2578 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2579 // LL - number of samples that both the primary and the surrogate splits send to the left
2580 // LR - ... primary split sends to the left and the surrogate split sends to the right
2581 // RL - ... primary split sends to the right and the surrogate split sends to the left
2582 // RR - ... both send to the right
2583 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2584 double best_val = 0;
2585 double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2586 double* rc = lc + mi + 1;
2588 for( i = -1; i < mi; i++ )
2591 // for each category calculate the weight of samples
2592 // sent to the left (lc) and to the right (rc) by the primary split
2593 if( !data->have_priors )
2595 int* _lc = (int*)rc + 1;
2596 int* _rc = _lc + mi + 1;
2598 for( i = -1; i < mi; i++ )
2599 _lc[i] = _rc[i] = 0;
2601 for( i = 0; i < n; i++ )
2603 int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2605 int sum = _lc[idx] + d;
2606 int sum_abs = _rc[idx] + (d & 1);
2607 _lc[idx] = sum; _rc[idx] = sum_abs;
2610 for( i = 0; i < mi; i++ )
2613 int sum_abs = _rc[i];
2614 lc[i] = (sum_abs - sum) >> 1;
2615 rc[i] = (sum_abs + sum) >> 1;
2620 const double* priors = data->priors_mult->data.db;
2621 int* responses_buf = labels_buf + n;
2622 const int* responses = data->get_class_labels(node, responses_buf);
2624 for( i = 0; i < n; i++ )
2626 int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2627 double w = priors[responses[i]];
2629 double sum = lc[idx] + d*w;
2630 double sum_abs = rc[idx] + (d & 1)*w;
2631 lc[idx] = sum; rc[idx] = sum_abs;
2634 for( i = 0; i < mi; i++ )
2637 double sum_abs = rc[i];
2638 lc[i] = (sum_abs - sum) * 0.5;
2639 rc[i] = (sum_abs + sum) * 0.5;
2643 // 2. now form the split.
2644 // in each category send all the samples to the same direction as majority
2645 for( i = 0; i < mi; i++ )
2647 double lval = lc[i], rval = rc[i];
2650 split->subset[i >> 5] |= 1 << (i & 31);
2658 split->quality = (float)best_val;
2659 if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2660 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2666 void CvDTree::calc_node_value( CvDTreeNode* node )
2668 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2669 int m = data->get_num_classes();
2671 int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
2672 int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
2673 cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
2674 uchar* base_buf = (uchar*)inn_buf;
2675 uchar* ext_buf = base_buf + base_size;
2677 int* cv_labels_buf = (int*)ext_buf;
2678 const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
2680 if( data->is_classifier )
2682 // in case of classification tree:
2683 // * node value is the label of the class that has the largest weight in the node.
2684 // * node risk is the weighted number of misclassified samples,
2685 // * j-th cross-validation fold value and risk are calculated as above,
2686 // but using the samples with cv_labels(*)!=j.
2687 // * j-th cross-validation fold error is calculated as the weighted number of
2688 // misclassified samples with cv_labels(*)==j.
2690 // compute the number of instances of each class
2691 int* cls_count = data->counts->data.i;
2692 int* responses_buf = cv_labels_buf + n;
2693 const int* responses = data->get_class_labels(node, responses_buf);
2694 int* cv_cls_count = (int*)base_buf;
2695 double max_val = -1, total_weight = 0;
2697 double* priors = data->priors_mult->data.db;
2699 for( k = 0; k < m; k++ )
2704 for( i = 0; i < n; i++ )
2705 cls_count[responses[i]]++;
2709 for( j = 0; j < cv_n; j++ )
2710 for( k = 0; k < m; k++ )
2711 cv_cls_count[j*m + k] = 0;
2713 for( i = 0; i < n; i++ )
2715 j = cv_labels[i]; k = responses[i];
2716 cv_cls_count[j*m + k]++;
2719 for( j = 0; j < cv_n; j++ )
2720 for( k = 0; k < m; k++ )
2721 cls_count[k] += cv_cls_count[j*m + k];
2724 if( data->have_priors && node->parent == 0 )
2726 // compute priors_mult from priors, take the sample ratio into account.
2728 for( k = 0; k < m; k++ )
2730 int n_k = cls_count[k];
2731 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2735 for( k = 0; k < m; k++ )
2739 for( k = 0; k < m; k++ )
2741 double val = cls_count[k]*priors[k];
2742 total_weight += val;
2750 node->class_idx = max_k;
2751 node->value = data->cat_map->data.i[
2752 data->cat_ofs->data.i[data->cat_var_count] + max_k];
2753 node->node_risk = total_weight - max_val;
2755 for( j = 0; j < cv_n; j++ )
2757 double sum_k = 0, sum = 0, max_val_k = 0;
2758 max_val = -1; max_k = -1;
2760 for( k = 0; k < m; k++ )
2762 double w = priors[k];
2763 double val_k = cv_cls_count[j*m + k]*w;
2764 double val = cls_count[k]*w - val_k;
2775 node->cv_Tn[j] = INT_MAX;
2776 node->cv_node_risk[j] = sum - max_val;
2777 node->cv_node_error[j] = sum_k - max_val_k;
2782 // in case of regression tree:
2783 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2784 // n is the number of samples in the node.
2785 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2786 // * j-th cross-validation fold value and risk are calculated as above,
2787 // but using the samples with cv_labels(*)!=j.
2788 // * j-th cross-validation fold error is calculated
2789 // using samples with cv_labels(*)==j as the test subset:
2790 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2791 // where node_value_j is the node value calculated
2792 // as described in the previous bullet, and summation is done
2793 // over the samples with cv_labels(*)==j.
2795 double sum = 0, sum2 = 0;
2796 float* values_buf = (float*)(cv_labels_buf + n);
2797 int* sample_indices_buf = (int*)(values_buf + n);
2798 const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
2799 double *cv_sum = 0, *cv_sum2 = 0;
2804 for( i = 0; i < n; i++ )
2806 double t = values[i];
2813 cv_sum = (double*)base_buf;
2814 cv_sum2 = cv_sum + cv_n;
2815 cv_count = (int*)(cv_sum2 + cv_n);
2817 for( j = 0; j < cv_n; j++ )
2819 cv_sum[j] = cv_sum2[j] = 0.;
2823 for( i = 0; i < n; i++ )
2826 double t = values[i];
2827 double s = cv_sum[j] + t;
2828 double s2 = cv_sum2[j] + t*t;
2829 int nc = cv_count[j] + 1;
2835 for( j = 0; j < cv_n; j++ )
2842 node->node_risk = sum2 - (sum/n)*sum;
2843 node->value = sum/n;
2845 for( j = 0; j < cv_n; j++ )
2847 double s = cv_sum[j], si = sum - s;
2848 double s2 = cv_sum2[j], s2i = sum2 - s2;
2849 int c = cv_count[j], ci = n - c;
2850 double r = si/MAX(ci,1);
2851 node->cv_node_risk[j] = s2i - r*r*ci;
2852 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2853 node->cv_Tn[j] = INT_MAX;
2859 void CvDTree::complete_node_dir( CvDTreeNode* node )
2861 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2862 int nz = n - node->get_num_valid(node->split->var_idx);
2863 char* dir = (char*)data->direction->data.ptr;
2865 // try to complete direction using surrogate splits
2866 if( nz && data->params.use_surrogates )
2868 cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
2869 CvDTreeSplit* split = node->split->next;
2870 for( ; split != 0 && nz; split = split->next )
2872 int inversed_mask = split->inversed ? -1 : 0;
2873 vi = split->var_idx;
2875 if( data->get_var_type(vi) >= 0 ) // split on categorical var
2877 int* labels_buf = (int*)(uchar*)inn_buf;
2878 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2879 const int* subset = split->subset;
2881 for( i = 0; i < n; i++ )
2883 int idx = labels[i];
2884 if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
2887 int d = CV_DTREE_CAT_DIR(idx,subset);
2888 dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2894 else // split on ordered var
2896 float* values_buf = (float*)(uchar*)inn_buf;
2897 int* sorted_indices_buf = (int*)(values_buf + n);
2898 int* sample_indices_buf = sorted_indices_buf + n;
2899 const float* values = 0;
2900 const int* sorted_indices = 0;
2901 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2902 int split_point = split->ord.split_point;
2903 int n1 = node->get_num_valid(vi);
2905 assert( 0 <= split_point && split_point < n-1 );
2907 for( i = 0; i < n1; i++ )
2909 int idx = sorted_indices[i];
2912 int d = i <= split_point ? -1 : 1;
2913 dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2922 // find the default direction for the rest
2925 for( i = nr = 0; i < n; i++ )
2928 d0 = nl > nr ? -1 : nr > nl;
2931 // make sure that every sample is directed either to the left or to the right
2932 for( i = 0; i < n; i++ )
2942 dir[i] = (char)d; // remap (-1,1) to (0,1)
2947 void CvDTree::split_node_data( CvDTreeNode* node )
2949 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
2950 char* dir = (char*)data->direction->data.ptr;
2951 CvDTreeNode *left = 0, *right = 0;
2952 int* new_idx = data->split_buf->data.i;
2953 int new_buf_idx = data->get_child_buf_idx( node );
2954 int work_var_count = data->get_work_var_count();
2955 CvMat* buf = data->buf;
2956 cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
2957 int* temp_buf = (int*)(uchar*)inn_buf;
2959 complete_node_dir(node);
2961 for( i = nl = nr = 0; i < n; i++ )
2964 // initialize new indices for splitting ordered variables
2965 new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2970 bool split_input_data;
2971 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2972 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
2974 split_input_data = node->depth + 1 < data->params.max_depth &&
2975 (node->left->sample_count > data->params.min_sample_count ||
2976 node->right->sample_count > data->params.min_sample_count);
2978 // split ordered variables, keep both halves sorted.
2979 for( vi = 0; vi < data->var_count; vi++ )
2981 int ci = data->get_var_type(vi);
2983 if( ci >= 0 || !split_input_data )
2986 int n1 = node->get_num_valid(vi);
2987 float* src_val_buf = (float*)(uchar*)(temp_buf + n);
2988 int* src_sorted_idx_buf = (int*)(src_val_buf + n);
2989 int* src_sample_idx_buf = src_sorted_idx_buf + n;
2990 const float* src_val = 0;
2991 const int* src_sorted_idx = 0;
2992 data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
2994 for(i = 0; i < n; i++)
2995 temp_buf[i] = src_sorted_idx[i];
2997 if (data->is_buf_16u)
2999 unsigned short *ldst, *rdst, *ldst0, *rdst0;
3000 //unsigned short tl, tr;
3001 ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
3002 vi*scount + left->offset);
3003 rdst0 = rdst = (unsigned short*)(ldst + nl);
3006 for( i = 0; i < n1; i++ )
3008 int idx = temp_buf[i];
3013 *rdst = (unsigned short)idx;
3018 *ldst = (unsigned short)idx;
3023 left->set_num_valid(vi, (int)(ldst - ldst0));
3024 right->set_num_valid(vi, (int)(rdst - rdst0));
3029 int idx = temp_buf[i];
3034 *rdst = (unsigned short)idx;
3039 *ldst = (unsigned short)idx;
3046 int *ldst0, *ldst, *rdst0, *rdst;
3047 ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols +
3048 vi*scount + left->offset;
3049 rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols +
3050 vi*scount + right->offset;
3053 for( i = 0; i < n1; i++ )
3055 int idx = temp_buf[i];
3070 left->set_num_valid(vi, (int)(ldst - ldst0));
3071 right->set_num_valid(vi, (int)(rdst - rdst0));
3076 int idx = temp_buf[i];
3093 // split categorical vars, responses and cv_labels using new_idx relocation table
3094 for( vi = 0; vi < work_var_count; vi++ )
3096 int ci = data->get_var_type(vi);
3097 int n1 = node->get_num_valid(vi), nr1 = 0;
3099 if( ci < 0 || (vi < data->var_count && !split_input_data) )
3102 int *src_lbls_buf = temp_buf + n;
3103 const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
3105 for(i = 0; i < n; i++)
3106 temp_buf[i] = src_lbls[i];
3108 if (data->is_buf_16u)
3110 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
3111 vi*scount + left->offset);
3112 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
3113 vi*scount + right->offset);
3115 for( i = 0; i < n; i++ )
3118 int idx = temp_buf[i];
3121 *rdst = (unsigned short)idx;
3123 nr1 += (idx != 65535 )&d;
3127 *ldst = (unsigned short)idx;
3132 if( vi < data->var_count )
3134 left->set_num_valid(vi, n1 - nr1);
3135 right->set_num_valid(vi, nr1);
3140 int *ldst = buf->data.i + left->buf_idx*buf->cols +
3141 vi*scount + left->offset;
3142 int *rdst = buf->data.i + right->buf_idx*buf->cols +
3143 vi*scount + right->offset;
3145 for( i = 0; i < n; i++ )
3148 int idx = temp_buf[i];
3153 nr1 += (idx >= 0)&d;
3163 if( vi < data->var_count )
3165 left->set_num_valid(vi, n1 - nr1);
3166 right->set_num_valid(vi, nr1);
3172 // split sample indices
3173 int *sample_idx_src_buf = temp_buf + n;
3174 const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
3176 for(i = 0; i < n; i++)
3177 temp_buf[i] = sample_idx_src[i];
3179 int pos = data->get_work_var_count();
3180 if (data->is_buf_16u)
3182 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
3183 pos*scount + left->offset);
3184 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
3185 pos*scount + right->offset);
3186 for (i = 0; i < n; i++)
3189 unsigned short idx = (unsigned short)temp_buf[i];
3204 int* ldst = buf->data.i + left->buf_idx*buf->cols +
3205 pos*scount + left->offset;
3206 int* rdst = buf->data.i + right->buf_idx*buf->cols +
3207 pos*scount + right->offset;
3208 for (i = 0; i < n; i++)
3211 int idx = temp_buf[i];
3225 // deallocate the parent node data that is not needed anymore
3226 data->free_node_data(node);
3229 float CvDTree::calc_error( CvMLData* _data, int type, vector<float> *resp )
3232 const CvMat* values = _data->get_values();
3233 const CvMat* response = _data->get_responses();
3234 const CvMat* missing = _data->get_missing();
3235 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
3236 const CvMat* var_types = _data->get_var_types();
3237 int* sidx = sample_idx ? sample_idx->data.i : 0;
3238 int r_step = CV_IS_MAT_CONT(response->type) ?
3239 1 : response->step / CV_ELEM_SIZE(response->type);
3240 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
3241 int sample_count = sample_idx ? sample_idx->cols : 0;
3242 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
3243 float* pred_resp = 0;
3244 if( resp && (sample_count > 0) )
3246 resp->resize( sample_count );
3247 pred_resp = &((*resp)[0]);
3250 if ( is_classifier )
3252 for( int i = 0; i < sample_count; i++ )
3255 int si = sidx ? sidx[i] : i;
3256 cvGetRow( values, &sample, si );
3258 cvGetRow( missing, &miss, si );
3259 float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3262 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
3265 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
3269 for( int i = 0; i < sample_count; i++ )
3272 int si = sidx ? sidx[i] : i;
3273 cvGetRow( values, &sample, si );
3275 cvGetRow( missing, &miss, si );
3276 float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3279 float d = r - response->data.fl[si*r_step];
3282 err = sample_count ? err / (float)sample_count : -FLT_MAX;
3287 void CvDTree::prune_cv()
3293 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3294 // 2. choose the best tree index (if need, apply 1SE rule).
3295 // 3. store the best index and cut the branches.
3297 CV_FUNCNAME( "CvDTree::prune_cv" );
3301 int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3302 // currently, 1SE for regression is not implemented
3303 bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3305 double min_err = 0, min_err_se = 0;
3308 CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3310 // build the main tree sequence, calculate alpha's
3313 double min_alpha = update_tree_rnc(tree_count, -1);
3314 if( cut_tree(tree_count, -1, min_alpha) )
3317 if( ab->cols <= tree_count )
3319 CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3320 for( ti = 0; ti < ab->cols; ti++ )
3321 temp->data.db[ti] = ab->data.db[ti];
3322 cvReleaseMat( &ab );
3327 ab->data.db[tree_count] = min_alpha;
3330 ab->data.db[0] = 0.;
3332 if( tree_count > 0 )
3334 for( ti = 1; ti < tree_count-1; ti++ )
3335 ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3336 ab->data.db[tree_count-1] = DBL_MAX*0.5;
3338 CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3339 err = err_jk->data.db;
3341 for( j = 0; j < cv_n; j++ )
3344 for( ; tk < tree_count; tj++ )
3346 double min_alpha = update_tree_rnc(tj, j);
3347 if( cut_tree(tj, j, min_alpha) )
3348 min_alpha = DBL_MAX;
3350 for( ; tk < tree_count; tk++ )
3352 if( ab->data.db[tk] > min_alpha )
3354 err[j*tree_count + tk] = root->tree_error;
3359 for( ti = 0; ti < tree_count; ti++ )
3362 for( j = 0; j < cv_n; j++ )
3363 sum_err += err[j*tree_count + ti];
3364 if( ti == 0 || sum_err < min_err )
3369 min_err_se = sqrt( sum_err*(n - sum_err) );
3371 else if( sum_err < min_err + min_err_se )
3376 pruned_tree_idx = min_idx;
3377 free_prune_data(data->params.truncate_pruned_tree != 0);
3381 cvReleaseMat( &err_jk );
3382 cvReleaseMat( &ab );
3383 cvReleaseMat( &temp );
3387 double CvDTree::update_tree_rnc( int T, int fold )
3389 CvDTreeNode* node = root;
3390 double min_alpha = DBL_MAX;
3394 CvDTreeNode* parent;
3397 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3398 if( t <= T || !node->left )
3400 node->complexity = 1;
3401 node->tree_risk = node->node_risk;
3402 node->tree_error = 0.;
3405 node->tree_risk = node->cv_node_risk[fold];
3406 node->tree_error = node->cv_node_error[fold];
3413 for( parent = node->parent; parent && parent->right == node;
3414 node = parent, parent = parent->parent )
3416 parent->complexity += node->complexity;
3417 parent->tree_risk += node->tree_risk;
3418 parent->tree_error += node->tree_error;
3420 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3421 - parent->tree_risk)/(parent->complexity - 1);
3422 min_alpha = MIN( min_alpha, parent->alpha );
3428 parent->complexity = node->complexity;
3429 parent->tree_risk = node->tree_risk;
3430 parent->tree_error = node->tree_error;
3431 node = parent->right;
3438 int CvDTree::cut_tree( int T, int fold, double min_alpha )
3440 CvDTreeNode* node = root;
3446 CvDTreeNode* parent;
3449 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3450 if( t <= T || !node->left )
3452 if( node->alpha <= min_alpha + FLT_EPSILON )
3455 node->cv_Tn[fold] = T;
3465 for( parent = node->parent; parent && parent->right == node;
3466 node = parent, parent = parent->parent )
3472 node = parent->right;
3479 void CvDTree::free_prune_data(bool cut_tree)
3481 CvDTreeNode* node = root;
3485 CvDTreeNode* parent;
3488 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3489 // as we will clear the whole cross-validation heap at the end
3491 node->cv_node_error = node->cv_node_risk = 0;
3497 for( parent = node->parent; parent && parent->right == node;
3498 node = parent, parent = parent->parent )
3500 if( cut_tree && parent->Tn <= pruned_tree_idx )
3502 data->free_node( parent->left );
3503 data->free_node( parent->right );
3504 parent->left = parent->right = 0;
3511 node = parent->right;
3515 cvClearSet( data->cv_heap );
3519 void CvDTree::free_tree()
3521 if( root && data && data->shared )
3523 pruned_tree_idx = INT_MIN;
3524 free_prune_data(true);
3525 data->free_node(root);
3530 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
3531 const CvMat* _missing, bool preprocessed_input ) const
3533 CvDTreeNode* result = 0;
3536 CV_FUNCNAME( "CvDTree::predict" );
3540 int i, step, mstep = 0;
3541 const float* sample;
3543 CvDTreeNode* node = root;
3550 CV_ERROR( CV_StsError, "The tree has not been trained yet" );
3552 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
3553 (_sample->cols != 1 && _sample->rows != 1) ||
3554 (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
3555 (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
3556 CV_ERROR( CV_StsBadArg,
3557 "the input sample must be 1d floating-point vector with the same "
3558 "number of elements as the total number of variables used for training" );
3560 sample = _sample->data.fl;
3561 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
3563 if( data->cat_count && !preprocessed_input ) // cache for categorical variables
3565 int n = data->cat_count->cols;
3566 catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
3567 for( i = 0; i < n; i++ )
3573 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
3574 !CV_ARE_SIZES_EQ(_missing, _sample) )
3575 CV_ERROR( CV_StsBadArg,
3576 "the missing data mask must be 8-bit vector of the same size as input sample" );
3577 m = _missing->data.ptr;
3578 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
3581 vtype = data->var_type->data.i;
3582 vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
3583 cmap = data->cat_map ? data->cat_map->data.i : 0;
3584 cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
3586 while( node->Tn > pruned_tree_idx && node->left )
3588 CvDTreeSplit* split = node->split;
3590 for( ; !dir && split != 0; split = split->next )
3592 int vi = split->var_idx;
3594 i = vidx ? vidx[vi] : vi;
3595 float val = sample[i*step];
3596 if( m && m[i*mstep] )
3598 if( ci < 0 ) // ordered
3599 dir = val <= split->ord.c ? -1 : 1;
3603 if( preprocessed_input )
3610 int a = c = cofs[ci];
3611 int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
3613 int ival = cvRound(val);
3615 CV_ERROR( CV_StsBadArg,
3616 "one of input categorical variable is not an integer" );
3623 if( ival < cmap[c] )
3625 else if( ival > cmap[c] )
3631 if( c < 0 || ival != cmap[c] )
3634 catbuf[ci] = c -= cofs[ci];
3637 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
3638 dir = CV_DTREE_CAT_DIR(c, split->subset);
3641 if( split->inversed )
3647 double diff = node->right->sample_count - node->left->sample_count;
3648 dir = diff < 0 ? -1 : 1;
3650 node = dir < 0 ? node->left : node->right;
3661 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
3663 CvMat sample = _sample, mmask = _missing;
3664 return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
3668 const CvMat* CvDTree::get_var_importance()
3670 if( !var_importance )
3672 CvDTreeNode* node = root;
3676 var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3677 cvZero( var_importance );
3678 importance = var_importance->data.db;
3682 CvDTreeNode* parent;
3683 for( ;; node = node->left )
3685 CvDTreeSplit* split = node->split;
3687 if( !node->left || node->Tn <= pruned_tree_idx )
3690 for( ; split != 0; split = split->next )
3691 importance[split->var_idx] += split->quality;
3694 for( parent = node->parent; parent && parent->right == node;
3695 node = parent, parent = parent->parent )
3701 node = parent->right;
3704 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3707 return var_importance;
3711 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
3715 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3716 cvWriteInt( fs, "var", split->var_idx );
3717 cvWriteReal( fs, "quality", split->quality );
3719 ci = data->get_var_type(split->var_idx);
3720 if( ci >= 0 ) // split on a categorical var
3722 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3723 for( i = 0; i < n; i++ )
3724 to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3726 // ad-hoc rule when to use inverse categorical split notation
3727 // to achieve more compact and clear representation
3728 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3730 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3731 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3733 for( i = 0; i < n; i++ )
3735 int dir = CV_DTREE_CAT_DIR(i,split->subset);
3736 if( dir*default_dir < 0 )
3737 cvWriteInt( fs, 0, i );
3739 cvEndWriteStruct( fs );
3742 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3744 cvEndWriteStruct( fs );
3748 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
3750 CvDTreeSplit* split;
3752 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3754 cvWriteInt( fs, "depth", node->depth );
3755 cvWriteInt( fs, "sample_count", node->sample_count );
3756 cvWriteReal( fs, "value", node->value );
3758 if( data->is_classifier )
3759 cvWriteInt( fs, "norm_class_idx", node->class_idx );
3761 cvWriteInt( fs, "Tn", node->Tn );
3762 cvWriteInt( fs, "complexity", node->complexity );
3763 cvWriteReal( fs, "alpha", node->alpha );
3764 cvWriteReal( fs, "node_risk", node->node_risk );
3765 cvWriteReal( fs, "tree_risk", node->tree_risk );
3766 cvWriteReal( fs, "tree_error", node->tree_error );
3770 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3772 for( split = node->split; split != 0; split = split->next )
3773 write_split( fs, split );
3775 cvEndWriteStruct( fs );
3778 cvEndWriteStruct( fs );
3782 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
3784 //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3788 CvDTreeNode* node = root;
3790 // traverse the tree and save all the nodes in depth-first order
3793 CvDTreeNode* parent;
3796 write_node( fs, node );
3802 for( parent = node->parent; parent && parent->right == node;
3803 node = parent, parent = parent->parent )
3809 node = parent->right;
3816 void CvDTree::write( CvFileStorage* fs, const char* name ) const
3818 //CV_FUNCNAME( "CvDTree::write" );
3822 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3824 //get_var_importance();
3825 data->write_params( fs );
3826 //if( var_importance )
3827 //cvWrite( fs, "var_importance", var_importance );
3830 cvEndWriteStruct( fs );
3836 void CvDTree::write( CvFileStorage* fs ) const
3838 //CV_FUNCNAME( "CvDTree::write" );
3842 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3844 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3845 write_tree_nodes( fs );
3846 cvEndWriteStruct( fs );
3852 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3854 CvDTreeSplit* split = 0;
3856 CV_FUNCNAME( "CvDTree::read_split" );
3862 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3863 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3865 vi = cvReadIntByName( fs, fnode, "var", -1 );
3866 if( (unsigned)vi >= (unsigned)data->var_count )
3867 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3869 ci = data->get_var_type(vi);
3870 if( ci >= 0 ) // split on categorical var
3872 int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3875 split = data->new_split_cat( vi, 0 );
3876 inseq = cvGetFileNodeByName( fs, fnode, "in" );
3879 inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3883 (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3884 CV_ERROR( CV_StsParseError,
3885 "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3887 if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3889 val = inseq->data.i;
3890 if( (unsigned)val >= (unsigned)n )
3891 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3893 split->subset[val >> 5] |= 1 << (val & 31);
3897 cvStartReadSeq( inseq->data.seq, &reader );
3899 for( i = 0; i < reader.seq->total; i++ )
3901 CvFileNode* inode = (CvFileNode*)reader.ptr;
3902 val = inode->data.i;
3903 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3904 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3906 split->subset[val >> 5] |= 1 << (val & 31);
3907 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3911 // for categorical splits we do not use inversed splits,
3912 // instead we inverse the variable set in the split
3914 for( i = 0; i < (n + 31) >> 5; i++ )
3915 split->subset[i] ^= -1;
3919 CvFileNode* cmp_node;
3920 split = data->new_split_ord( vi, 0, 0, 0, 0 );
3922 cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3925 cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3926 split->inversed = 1;
3929 split->ord.c = (float)cvReadReal( cmp_node );
3932 split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3940 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3942 CvDTreeNode* node = 0;
3944 CV_FUNCNAME( "CvDTree::read_node" );
3951 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3952 CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3954 CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3955 depth = cvReadIntByName( fs, fnode, "depth", -1 );
3956 if( depth != node->depth )
3957 CV_ERROR( CV_StsParseError, "incorrect node depth" );
3959 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3960 node->value = cvReadRealByName( fs, fnode, "value" );
3961 if( data->is_classifier )
3962 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3964 node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3965 node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3966 node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3967 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3968 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3969 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3971 splits = cvGetFileNodeByName( fs, fnode, "splits" );
3975 CvDTreeSplit* last_split = 0;
3977 if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3978 CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3980 cvStartReadSeq( splits->data.seq, &reader );
3981 for( i = 0; i < reader.seq->total; i++ )
3983 CvDTreeSplit* split;
3984 CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3986 node->split = last_split = split;
3988 last_split = last_split->next = split;
3990 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4000 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
4002 CV_FUNCNAME( "CvDTree::read_tree_nodes" );
4008 CvDTreeNode* parent = &_root;
4010 parent->left = parent->right = parent->parent = 0;
4012 cvStartReadSeq( fnode->data.seq, &reader );
4014 for( i = 0; i < reader.seq->total; i++ )
4018 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
4020 parent->left = node;
4022 parent->right = node;
4027 while( parent && parent->right )
4028 parent = parent->parent;
4031 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4040 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
4042 CvDTreeTrainData* _data = new CvDTreeTrainData();
4043 _data->read_params( fs, fnode );
4045 read( fs, fnode, _data );
4046 get_var_importance();
4050 // a special entry point for reading weak decision trees from the tree ensembles
4051 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
4053 CV_FUNCNAME( "CvDTree::read" );
4057 CvFileNode* tree_nodes;
4062 tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
4063 if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
4064 CV_ERROR( CV_StsParseError, "nodes tag is missing" );
4066 pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
4067 read_tree_nodes( fs, tree_nodes );