/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
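// Training-data container (CvDTreeTrainData) and decision tree (CvDTree)
// implementation. CvDTreeTrainData converts and pre-sorts the training set once
// so that one or more trees (possibly sharing the data, as in ensembles) can be
// grown from it; CvDTree grows, prunes and stores a single tree.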
static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;
CvDTreeTrainData::CvDTreeTrainData()
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
    tree_storage = temp_storage = 0;

CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
CvDTreeTrainData::~CvDTreeTrainData()

bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    if( params.max_categories < 2 )
        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count,1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
            "params.cv_folds should be =0 (the tree is not pruned) "
            "or n>0 (tree is pruned using n-fold cross-validation)" );

    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
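// Specialized quicksort instantiations (via CV_IMPLEMENT_QSORT_EX):
// icvSortIntPtr/icvSortDblPtr order arrays of pointers by the values they point to,
// icvSortIntAux/icvSortUShAux order index arrays by an auxiliary float key array,
// and icvSortPairs orders CvPair16u32s records by their 32-bit integer key.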
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )

#define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )

#define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
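// set_data() validates the parameters and converts the training set into the
// internal buffer layout: for every variable it writes one row of sample_count
// entries into "buf" (16-bit entries when sample_count < 65536, 32-bit otherwise).
// Categorical variables and classification responses are mapped to compact
// category indices via cat_map/cat_count/cat_ofs, ordered variables are stored as
// index arrays sorted by value, and missing values are flagged with -1 (or 65535).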
void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
    CvMat* sample_indices = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    CvPair16u32s* pair16u32s_ptr = 0;
    CvDTreeTrainData* data = 0;
    unsigned short* udst = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    int sample_all = 0, r_type = 0, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    const int *sidx = 0, *vidx = 0;
    if( _update_data && data_root )
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
                "The new training data must have the same types of input and output variables "
                "and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));

    train_data = _train_data;
    responses = _responses;

    if( _tflag == CV_ROW_SAMPLE )
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;

        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;

    sample_count = sample_all;
    var_count = var_all;

    CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
    sidx = sample_indices->data.i;
    sample_count = sample_indices->rows + sample_indices->cols - 1;

    CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
    vidx = var_idx->data.i;
    var_count = var_idx->rows + var_idx->cols - 1;

    is_buf_16u = false;
    if ( sample_count < 65536 )
        is_buf_16u = true;

    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
            "floating-point vector containing as many elements as "
            "the total number of samples in the training data matrix" );
    CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));

    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
    for( vi = 0; vi < var_count; vi++ )
        var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
            cat_var_count++ : ord_var_count--;

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

    work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
    buf_size = (work_var_count + 1)*sample_count;

    buf_count = shared ? 2 : 1;

    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
    CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));

    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
    CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));

    size = is_classifier ? (cat_var_count+1) : cat_var_count;
    size = !size ? 1 : size;
    CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));

    size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
    size = !size ? 1 : size;
    CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));

    temp_block_size = nv_size;

    if( sample_count < cv_n*MAX(params.min_sample_count,10) )
        CV_ERROR( CV_StsOutOfRange,
            "Too many cross-validation folds for such a small dataset" );

    cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
    temp_block_size = MAX(temp_block_size, cv_size);

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));

    CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));

    _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
    if (is_buf_16u && (cat_var_count || is_classifier))
        _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
    // transform the training data to convenient representation
    for( vi = 0; vi <= var_count; vi++ )
        const uchar* mask = 0;
        int m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;

        if( vi < var_count ) // analyze i-th input variable
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        else // analyze _responses
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
                fdata = _responses->data.fl;

        if( (vi < var_count && ci>=0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
            int c_count, prev_label;
            udst = (unsigned short*)(buf->data.s + vi*sample_count);
            idst = buf->data.i + vi*sample_count;

            for( i = 0; i < sample_count; i++ )
                int val = INT_MAX, si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                    val = idata[si*step];
                    float t = fdata[si*step];
                    sprintf( err, "%d-th value of %d-th (categorical) "
                        "variable is not an integer", i, vi );
                    CV_ERROR( CV_StsBadArg, err );

                if( val == INT_MAX )
                    sprintf( err, "%d-th value of %d-th (categorical) "
                        "variable is too large", i, vi );
                    CV_ERROR( CV_StsBadArg, err );

                pair16u32s_ptr[i].u = udst + i;
                pair16u32s_ptr[i].i = _idst + i;
                int_ptr[i] = idst + i;

            c_count = num_valid > 0;
            icvSortPairs( pair16u32s_ptr, sample_count, 0 );
            // count the categories
            for( i = 1; i < num_valid; i++ )
                if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
            icvSortIntPtr( int_ptr, sample_count, 0 );
            // count the categories
            for( i = 1; i < num_valid; i++ )
                c_count += *int_ptr[i] != *int_ptr[i-1];

            max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if needed
            if( cat_map->cols < total_c_count + c_count )
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;

            // compact the class indices and build the map
            prev_label = ~*pair16u32s_ptr[0].i;
            for( i = 0; i < num_valid; i++ )
                int cur_label = *pair16u32s_ptr[i].i;
                if( cur_label != prev_label )
                    c_map[++c_count] = prev_label = cur_label;
                *pair16u32s_ptr[i].u = (unsigned short)c_count;
            // replace labels for missing values with -1
            for( ; i < sample_count; i++ )
                *pair16u32s_ptr[i].u = 65535;

            // compact the class indices and build the map
            prev_label = ~*int_ptr[0];
            for( i = 0; i < num_valid; i++ )
                int cur_label = *int_ptr[i];
                if( cur_label != prev_label )
                    c_map[++c_count] = prev_label = cur_label;
                *int_ptr[i] = c_count;
            // replace labels for missing values with -1
            for( ; i < sample_count; i++ )
        else if( ci < 0 ) // process ordered variable
            udst = (unsigned short*)(buf->data.s + vi*sample_count);
            idst = buf->data.i + vi*sample_count;

            for( i = 0; i < sample_count; i++ )
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                    val = (float)idata[si*step];
                    val = fdata[si*step];

                if( fabs(val) >= ord_nan )
                    sprintf( err, "%d-th value of %d-th (ordered) "
                        "variable (=%g) is too large", i, vi, val );
                    CV_ERROR( CV_StsBadArg, err );

                udst[i] = (unsigned short)i;
                idst[i] = i; // move this above, into the if( idata ) branch

            icvSortUShAux( udst, num_valid, _fdst);
            icvSortIntAux( idst, /* or num_valid? */ sample_count, _fdst );

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    // set sample labels
    udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
    idst = buf->data.i + work_var_count*sample_count;

    for (i = 0; i < sample_count; i++)
        udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
        idst[i] = sidx ? sidx[i] : i;

        unsigned short* udst = 0;

        udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
        for( i = vi = 0; i < sample_count; i++ )
            udst[i] = (unsigned short)vi++;
            vi &= vi < cv_n ? -1 : 0;

        for( i = 0; i < sample_count; i++ )
            int a = cvRandInt(r) % sample_count;
            int b = cvRandInt(r) % sample_count;
            unsigned short unsh = (unsigned short)vi;
            CV_SWAP( udst[a], udst[b], unsh );

        idst = buf->data.i + (get_work_var_count()-1)*sample_count;
        for( i = vi = 0; i < sample_count; i++ )
            vi &= vi < cv_n ? -1 : 0;

        for( i = 0; i < sample_count; i++ )
            int a = cvRandInt(r) % sample_count;
            int b = cvRandInt(r) % sample_count;
            CV_SWAP( idst[a], idst[b], vi );

    cat_map->cols = MAX( total_c_count, 1 );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
    have_priors = is_classifier && params.priors;
    if( is_classifier )
        int m = get_num_classes();

        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
            double val = have_priors ? params.priors[i] : 1.;
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;

        // normalize weights
        cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));

    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    int maxNumThreads = 1;
    maxNumThreads = cv::getNumThreads();

    pred_float_buf.resize(maxNumThreads);
    pred_int_buf.resize(maxNumThreads);
    resp_float_buf.resize(maxNumThreads);
    resp_int_buf.resize(maxNumThreads);
    cv_lables_buf.resize(maxNumThreads);
    sample_idx_buf.resize(maxNumThreads);
    for( int ti = 0; ti < maxNumThreads; ti++ )
        pred_float_buf[ti].resize(sample_count);
        pred_int_buf[ti].resize(sample_count);
        resp_float_buf[ti].resize(sample_count);
        resp_int_buf[ti].resize(sample_count);
        cv_lables_buf[ti].resize(sample_count);
        sample_idx_buf[ti].resize(sample_count);

    cvFree( &int_ptr );
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &sample_indices );
    cvReleaseMat( &tmp_map );
void CvDTreeTrainData::do_responses_copy()
    responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
    cvCopy( responses, responses_copy);
    responses = responses_copy;
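// subsample_data() builds a new root node that references only the samples listed
// in _subsample_idx. When no index array is given, the root of the shared data is
// simply copied; otherwise each variable's per-sample records are re-packed for
// the selected subset (duplicates in _subsample_idx are handled via the
// count/offset table "co").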
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

    if( !isubsample_idx )
        // make a copy of the root node
        root = new_node( 0, 1, 0, 0 );
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;

        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int work_var_count = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
        for( i = 0; i < sample_count; i++ )
            co[i*2+1] = cur_ofs;
            cur_ofs += co[i*2];

        for( vi = 0; vi < work_var_count; vi++ )
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
                int* src_buf = get_pred_int_buf();
                const int* src = 0;
                get_cat_var_data( data_root, vi, src_buf, &src );

                unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
                    vi*sample_count + root->offset);
                for( i = 0; i < count; i++ )
                    int val = src[sidx[i]];
                    udst[i] = (unsigned short)val;
                    num_valid += val >= 0;

                int* idst = buf->data.i + root->buf_idx*buf->cols +
                    vi*sample_count + root->offset;
                for( i = 0; i < count; i++ )
                    int val = src[sidx[i]];
                    num_valid += val >= 0;

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);

                int *src_idx_buf = get_pred_int_buf();
                const int* src_idx = 0;
                float *src_val_buf = get_pred_float_buf();
                const float* src_val = 0;
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );

                unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
                    vi*sample_count + data_root->offset);
                for( i = 0; i < num_valid; i++ )
                    count_i = co[idx*2];
                    for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        udst_idx[j] = (unsigned short)cur_ofs;

                root->set_num_valid(vi, j);

                for( ; i < sample_count; i++ )
                    count_i = co[idx*2];
                    for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        udst_idx[j] = (unsigned short)cur_ofs;

                int* idst_idx = buf->data.i + root->buf_idx*buf->cols +
                    vi*sample_count + root->offset;
                for( i = 0; i < num_valid; i++ )
                    count_i = co[idx*2];
                    for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        idst_idx[j] = cur_ofs;

                root->set_num_valid(vi, j);

                for( ; i < sample_count; i++ )
                    count_i = co[idx*2];
                    for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        idst_idx[j] = cur_ofs;

        // sample indices subsampling
        int* sample_idx_src_buf = get_sample_idx_buf();
        const int* sample_idx_src = 0;
        get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);

            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
                get_work_var_count()*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];

            int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols +
                get_work_var_count()*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );
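// get_vectors() unpacks the (sub)sampled training set back into dense arrays:
// "values" receives count x var_count feature values, "missing" marks unavailable
// entries, and "responses" receives either class indices or the original category
// labels (selected by get_class_idx) for classification, or the raw float
// responses for regression.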
void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
    float* values, uchar* missing,
    float* responses, bool get_class_idx )
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    int i, vi, total = sample_count, count = total, cur_ofs = 0;

    if( _subsample_idx )
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count = subsample_idx->cols + subsample_idx->rows - 1;
        for( i = 0; i < count; i++ )
        for( i = 0; i < total; i++ )
            int count_i = co[i*2];
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;

    memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            int* src_buf = get_pred_int_buf();
            const int* src = 0;
            get_cat_var_data(data_root, vi, src_buf, &src);

            for( i = 0; i < count; i++, dst += var_count )
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));

            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            int count1 = data_root->get_num_valid(vi);
            float *src_val_buf = get_pred_float_buf();
            const float *src_val = 0;
            int* src_idx_buf = get_pred_int_buf();
            const int* src_idx = 0;
            get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);

            for( i = 0; i < count1; i++ )
                int idx = src_idx[i];
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                    cur_ofs = idx*var_count;

                float val = src_val[i];
                for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    dst[cur_ofs] = val;

    if( is_classifier )
        int* src_buf = get_resp_int_buf();
        const int* src = 0;
        get_class_labels(data_root, src_buf, &src);
        for( i = 0; i < count; i++ )
            int idx = sidx ? sidx[i] : i;
            int val = get_class_idx ? src[idx] :
                cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
            responses[i] = (float)val;

        float *_values_buf = get_resp_float_buf();
        const float* _values = 0;
        get_ord_responses(data_root, _values_buf, &_values);
        for( i = 0; i < count; i++ )
            int idx = sidx ? sidx[i] : i;
            responses[i] = _values[idx];

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
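// new_node() allocates a CvDTreeNode from the node heap and initializes its links
// and bookkeeping fields; when cross-validation pruning is enabled it also sets up
// the per-fold cv_Tn / cv_node_risk / cv_node_error arrays.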
CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->class_idx = 0;
    node->buf_idx = storage_idx;
    node->offset = offset;

    node->num_valid = (int*)cvSetNew( nv_heap );
    node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;

        node->cv_node_risk = 0;
        node->cv_node_error = 0;
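// new_split_ord()/new_split_cat() allocate a CvDTreeSplit from the split heap:
// the ordered variant stores the comparison threshold and split point, while the
// categorical variant zeroes the bit subset that will later mark which categories
// are sent to the left branch.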
CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                                               int split_point, int inversed, float quality )
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;

CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
void CvDTreeTrainData::free_node( CvDTreeNode* node )
    CvDTreeSplit* split = node->split;
    free_node_data( node );
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
    cvSetRemoveByPtr( node_heap, node );

void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
    if( node->num_valid )
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    // do not free cv_* fields, as all the cross-validation related data is released at once.

void CvDTreeTrainData::free_train_data()
    cvReleaseMat( &counts );
    cvReleaseMat( &buf );
    cvReleaseMat( &direction );
    cvReleaseMat( &split_buf );
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses_copy );
    pred_float_buf.clear();
    pred_int_buf.clear();
    resp_float_buf.clear();
    resp_int_buf.clear();
    cv_lables_buf.clear();
    sample_idx_buf.clear();

    cv_heap = nv_heap = 0;

void CvDTreeTrainData::clear()
    free_train_data();

    cvReleaseMemStorage( &tree_storage );

    cvReleaseMat( &var_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &cat_count );
    cvReleaseMat( &cat_ofs );
    cvReleaseMat( &cat_map );
    cvReleaseMat( &priors );
    cvReleaseMat( &priors_mult );

    node_heap = split_heap = 0;

    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    have_labels = have_priors = is_classifier = false;

    buf_count = buf_size = 0;

int CvDTreeTrainData::get_num_classes() const
    return is_classifier ? cat_count->data.i[cat_var_count] : 0;

int CvDTreeTrainData::get_var_type(int vi) const
    return var_type->data.i[vi];
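// get_ord_var_data() returns, for one ordered variable of a node, the sample
// indices sorted by value (via *indices) and the corresponding feature values
// (via *ord_values); the indices come from the internal 16- or 32-bit buffer and
// the values are fetched from the original train_data matrix.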
int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )
    int vidx = var_idx ? var_idx->data.i[vi] : vi;
    int node_sample_count = n->sample_count;
    int* sample_indices_buf = get_sample_idx_buf();
    const int* sample_indices = 0;
    int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);

    get_sample_indices(n, sample_indices_buf, &sample_indices);

        *indices = buf->data.i + n->buf_idx*buf->cols +
            vi*sample_count + n->offset;

        const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
            vi*sample_count + n->offset );
        for( int i = 0; i < node_sample_count; i++ )
            indices_buf[i] = short_indices[i];
        *indices = indices_buf;

    if( tflag == CV_ROW_SAMPLE )
        for( int i = 0; i < node_sample_count &&
            ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )
            int idx = (*indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);

        for( int i = 0; i < node_sample_count &&
            ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )
            int idx = (*indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);

    *ord_values = ord_values_buf;
    return 0; //TODO: return the number of non-missing values

void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
    if (is_classifier)
        get_cat_var_data( n, var_count, labels_buf, labels );

void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
    get_cat_var_data( n, get_work_var_count(), indices_buf, indices );

void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)
    int sample_count = n->sample_count;
    int* indices_buf = get_sample_idx_buf();
    const int* indices = 0;
    int r_step = responses->step/CV_ELEM_SIZE(responses->type);

    get_sample_indices(n, indices_buf, &indices);

    for( int i = 0; i < sample_count &&
        (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
        int idx = indices[i];
        values_buf[i] = *(responses->data.fl + idx * r_step);

    *values = values_buf;

void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
    get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );

int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
        *cat_values = buf->data.i + n->buf_idx*buf->cols +
            vi*sample_count + n->offset;

        const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
            vi*sample_count + n->offset);
        for( int i = 0; i < n->sample_count; i++ )
            cat_values_buf[i] = short_values[i];
        *cat_values = cat_values_buf;

    return 0; //TODO: return the number of non-missing values

int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
    int idx = n->buf_idx + 1;
    if( idx >= buf_count )
        idx = shared ? 1 : 0;
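// write_params()/read_params() serialize and restore the training-data description
// (variable counts and types, training parameters, priors, category tables) using
// the CvFileStorage persistence API.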
void CvDTreeTrainData::write_params( CvFileStorage* fs )
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    int vi, vcount = var_count;

    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( is_classifier )
        cvWriteInt( fs, "max_categories", params.max_categories );

        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    if( params.cv_folds > 1 )
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );

        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );

    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );

    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );

    if( tparams_node ) // training parameters are not necessary
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );

            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );

            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must be stored as a matrix" );
            priors_mult = cvCloneMat( priors );
    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));

        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "var_idx (if it exists) must be a valid 1d integer vector containing <var_count> elements" );

        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;

        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );

    var_type->data.i[var_count] = cat_var_count;

    ord_var_count = ~ord_var_count;
    if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );

    if( cat_var_count > 0 || is_classifier )
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        ccount = cat_var_count + is_classifier;

        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;

        for( vi = 0; vi < ccount; vi++ )
            int val = cat_count->data.i[vi];
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
                "cat_map vector length is not equal to the total number of categories in all categorical vars" );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
        sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
        max_split_size, tree_storage ));
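// Per-thread scratch buffers: each get_*_buf() accessor returns the buffer that
// belongs to the calling thread (indexed by cv::getThreadNum()), so the parallel
// split search in find_best_split() can reuse preallocated storage without locking.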
float* CvDTreeTrainData::get_pred_float_buf()
    return &pred_float_buf[cv::getThreadNum()][0];

int* CvDTreeTrainData::get_pred_int_buf()
    return &pred_int_buf[cv::getThreadNum()][0];

float* CvDTreeTrainData::get_resp_float_buf()
    return &resp_float_buf[cv::getThreadNum()][0];

int* CvDTreeTrainData::get_resp_int_buf()
    return &resp_int_buf[cv::getThreadNum()][0];

int* CvDTreeTrainData::get_cv_lables_buf()
    return &cv_lables_buf[cv::getThreadNum()][0];

int* CvDTreeTrainData::get_sample_idx_buf()
    return &sample_idx_buf[cv::getThreadNum()][0];
/////////////////////// Decision Tree /////////////////////////

CvDTree::CvDTree()
    var_importance = 0;
    default_model_name = "my_tree";

void CvDTree::clear()
    cvReleaseMat( &var_importance );
    if( !data->shared )
    pruned_tree_idx = -1;

CvDTree::~CvDTree()

const CvDTreeNode* CvDTree::get_root() const

int CvDTree::get_pruned_tree_idx() const
    return pruned_tree_idx;

CvDTreeTrainData* CvDTree::get_data()
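// The three train() overloads below differ only in where the data comes from:
// raw CvMat arrays, a CvMLData loader, or an already built (shared) CvDTreeTrainData.
// A minimal usage sketch (illustrative only; preparing the matrices is up to the
// caller and is not shown here):
//
//     CvMat* train_samples = ...;   // CV_32FC1, one row per training sample
//     CvMat* train_responses = ...; // CV_32SC1 or CV_32FC1, one entry per sample
//     CvDTree tree;
//     tree.train( train_samples, CV_ROW_SAMPLE, train_responses,
//                 0, 0, 0, 0, CvDTreeParams() );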
bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0) );

bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_response();
    const CvMat* missing = _data->get_missing();
    const CvMat* var_types = _data->get_var_types();
    const CvMat* train_sidx = _data->get_train_sample_idx();
    const CvMat* var_idx = _data->get_var_idx();

    CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
                             train_sidx, var_types, missing, _params ) );

bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

bool CvDTree::do_train( const CvMat* _subsample_idx )
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( data->params.cv_folds > 0 )
        CV_CALL( prune_cv());

    if( !data->shared )
        data->free_train_data();
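// try_split_node() recursively grows the tree: it computes the node value, checks
// the stopping criteria (depth, sample count, class purity or regression accuracy),
// finds the best split, optionally collects surrogate splits, partitions the node's
// data and recurses into the two children.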
void CvDTree::try_split_node( CvDTreeNode* node )
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
        // check if we have a "pure" node;
        // we assume that cls_count has been filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    else if( can_split )
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;

        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;

    if( !can_split || !best_split )
        data->free_node_data(node);

    quality_scale = calc_node_dir( node );
    if( data->params.use_surrogates )
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                split = find_surrogate_split_cat( node, vi );
                split = find_surrogate_split_ord( node, vi );

            // insert the split
            CvDTreeSplit* prev_split = node->split;
            split->quality = (float)(split->quality*quality_scale);

            while( prev_split->next &&
                   prev_split->next->quality > split->quality )
                prev_split = prev_split->next;
            split->next = prev_split->next;
            prev_split->next = split;

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
// Calculate the direction (left(-1), right(1), missing(0)) for each sample using
// the best split. The function returns the scale coefficient for surrogate split
// quality factors: the scale normalizes surrogate split quality relative to the
// best (primary) split quality, so a surrogate split that is identical to the
// primary one gets the maximum value (the primary split's quality) and any other
// surrogate gets a lower one. Besides that, the function computes node->maxlr,
// the minimum quality (without the above-mentioned scale) that a surrogate split
// must exceed; surrogate splits with quality below node->maxlr are discarded.
double CvDTree::calc_node_dir( CvDTreeNode* node )
    char* dir = (char*)data->direction->data.ptr;
    int i, n = node->sample_count, vi = node->split->var_idx;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
        int* labels_buf = data->get_pred_int_buf();
        const int* labels = 0;
        const int* subset = node->split->subset;
        data->get_cat_var_data( node, vi, labels_buf, &labels );
        if( !data->have_priors )
            int sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
                int idx = labels[i];
                int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
                    CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d; sum_abs += d & 1;

            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;

            const double* priors = data->priors_mult->data.db;
            double sum = 0, sum_abs = 0;
            int *responses_buf = data->get_resp_int_buf();
            const int* responses;
            data->get_class_labels(node, responses_buf, &responses);

            for( i = 0; i < n; i++ )
                int idx = labels[i];
                double w = priors[responses[i]];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;

            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;

    else // split on ordered var
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);

        float* val_buf = data->get_pred_float_buf();
        const float* val = 0;
        int* sorted_buf = data->get_pred_int_buf();
        const int* sorted = 0;
        data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);

        assert( 0 <= split_point && split_point < n1-1 );

        if( !data->have_priors )
            for( i = 0; i <= split_point; i++ )
                dir[sorted[i]] = (char)-1;
            for( ; i < n1; i++ )
                dir[sorted[i]] = (char)1;
            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;

            L = split_point-1;
            R = n1 - split_point + 1;

            const double* priors = data->priors_mult->data.db;
            int* responses_buf = data->get_resp_int_buf();
            const int* responses = 0;
            data->get_class_labels(node, responses_buf, &responses);

            for( i = 0; i <= split_point; i++ )
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)-1;

            for( ; i < n1; i++ )
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)1;

            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;

    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
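// find_best_split() evaluates every variable, in parallel when OpenCV runs with
// more than one thread (one OpenMP loop iteration per variable), keeping a
// per-thread best split; the per-thread winners are then merged into a single
// best split and the temporary split records are returned to the split heap.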
CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
    CvDTreeSplit *bestSplit = 0;
    int maxNumThreads = 1;
    maxNumThreads = cv::getNumThreads();

    vector<CvDTreeSplit*> splits(maxNumThreads);
    vector<CvDTreeSplit*> bestSplits(maxNumThreads);
    for (int i = 0; i < maxNumThreads; i++)
        splits[i] = data->new_split_cat( 0, -1.0f );
        bestSplits[i] = data->new_split_cat( 0, -1.0f );

    bool can_split = false;

#pragma omp parallel for num_threads(maxNumThreads) schedule(dynamic)
    for( vi = 0; vi < data->var_count; vi++ )
        CvDTreeSplit *res, *t;
        int threadIdx = cv::getThreadNum();
        int ci = data->get_var_type(vi);
        if( node->get_num_valid(vi) <= 1 )

        if( data->is_classifier )
            res = find_split_cat_class( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );
            res = find_split_ord_class( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );

            res = find_split_cat_reg( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );
            res = find_split_ord_reg( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );

        if( bestSplits[threadIdx]->quality < splits[threadIdx]->quality )
            CV_SWAP( bestSplits[threadIdx], splits[threadIdx], t );

    bestSplit = bestSplits[0];
    for(int i = 1; i < maxNumThreads; i++)
        if( bestSplit->quality < bestSplits[i]->quality )
            bestSplit = bestSplits[i];

    for(int i = 0; i < maxNumThreads; i++)
        cvSetRemoveByPtr( data->split_heap, splits[i] );
        if( bestSplits[i] != bestSplit )
            cvSetRemoveByPtr( data->split_heap, bestSplits[i] );
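// find_split_ord_class() searches for the best threshold on an ordered variable
// for a classification node. It sweeps the samples in sorted order, maintaining
// the (optionally prior-weighted) class counts on the left/right side, and keeps
// the threshold maximizing val = lsum2/L + rsum2/R, where lsum2/rsum2 are the
// sums of squared class counts on each side.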
CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
                                             float init_quality, CvDTreeSplit* _split )
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();

    float* values_buf = data->get_pred_float_buf();
    const float* values = 0;
    int* indices_buf = data->get_pred_int_buf();
    const int* indices = 0;
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
    int* responses_buf = data->get_resp_int_buf();
    const int* responses = 0;
    data->get_class_labels( node, responses_buf, &responses );

    const int* rc0 = data->counts->data.i;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = init_quality;
    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rc[responses[indices[i]]]--;

        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
            int idx = responses[indices[i]];
            lv = lc[idx]; rv = rc[idx];
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( values[i] + epsilon < values[i+1] )
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )

        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
            double wv = rc[i]*priors[i];

        for( i = 0; i < n1 - 1; i++ )
            int idx = responses[indices[i]];
            double p = priors[idx], p2 = p*p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( values[i] + epsilon < values[i+1] )
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )

    CvDTreeSplit* split = 0;
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
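// cluster_categories() groups the categories of a high-cardinality categorical
// variable into at most k clusters using a simple k-means-like procedure on the
// per-category class-count vectors, so the subset search in find_split_cat_class()
// stays tractable when a variable has more than max_categories categories.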
void CvDTree::cluster_categories( const int* vectors, int n, int m,
                                  int* csums, int k, int* labels )
    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
    int iters = 0, max_iters = 100;
    double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
    double *v_weights = buf, *c_weights = buf + k;
    bool modified = true;
    CvRNG* r = &data->rng;

    // assign labels randomly
    for( i = idx = 0; i < n; i++ )
        const int* v = vectors + i*m;
        labels[i] = idx++;
        idx &= idx < k ? -1 : 0;

        // compute weight of each vector
        for( j = 0; j < m; j++ )
        v_weights[i] = sum ? 1./sum : 0.;

    for( i = 0; i < n; i++ )
        int i1 = cvRandInt(r) % n;
        int i2 = cvRandInt(r) % n;
        CV_SWAP( labels[i1], labels[i2], j );

    for( iters = 0; iters <= max_iters; iters++ )
        // calculate csums
        for( i = 0; i < k; i++ )
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;

        for( i = 0; i < n; i++ )
            const int* v = vectors + i*m;
            int* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
            const int* s = csums + i*m;
            for( j = 0; j < m; j++ )
            c_weights[i] = sum ? 1./sum : 0;

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
            const int* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;

            for( idx = 0; idx < k; idx++ )
                const int* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                    double t = v[j]*alpha - s[j]*beta;
                if( min_dist2 > dist2 )
                    min_dist2 = dist2;

            if( min_idx != labels[i] )
                labels[i] = min_idx;
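// find_split_cat_class() finds the best subset-of-categories split for a
// categorical variable in a classification node. It builds the category-by-class
// count table cjk, clusters the categories (see cluster_categories above) when
// there are more than max_categories of them, and then walks candidate subsets
// (enumerated via a Gray code, or as prefixes of the sorted per-category counts),
// keeping the subset with the highest quality.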
2126 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
\r
2128 int ci = data->get_var_type(vi);
\r
2129 int n = node->sample_count;
\r
2130 int m = data->get_num_classes();
\r
2131 int _mi = data->cat_count->data.i[ci], mi = _mi;
\r
2133 int* labels_buf = data->get_pred_int_buf();
\r
2134 const int* labels = 0;
\r
2135 data->get_cat_var_data(node, vi, labels_buf, &labels);
\r
2136 int *responses_buf = data->get_resp_int_buf();
\r
2137 const int* responses = 0;
\r
2138 data->get_class_labels(node, responses_buf, &responses);
\r
2140 int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
\r
2141 int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
\r
2142 int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
\r
2143 double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
\r
2144 int* cluster_labels = 0;
\r
2145 int** int_ptr = 0;
\r
2147 double L = 0, R = 0;
\r
2148 double best_val = init_quality;
\r
2149 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
\r
2150 const double* priors = data->priors_mult->data.db;
\r
2152 // init array of counters:
\r
2153 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
\r
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )

    for( i = 0; i < n; i++ )
        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];

    if( mi > data->params.max_categories )
        mi = MIN(data->params.max_categories, n);
        cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
        cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );

    subset_n = 1 << mi;

    int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
    for( j = 0; j < mi; j++ )
        int_ptr[j] = cjk + j*2 + 1;
    icvSortIntPtr( int_ptr, mi, 0 );

    for( k = 0; k < m; k++ )
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];

    for( j = 0; j < mi; j++ )
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k]*priors[k];
        c_weights[j] = sum;
        R += c_weights[j];

    for( ; subset_i < subset_n; subset_i++ )
        double lsum2 = 0, rsum2 = 0;

        idx = (int)(int_ptr[subset_i] - cjk)/2;

        int graycode = (subset_i>>1)^subset_i;
        int diff = graycode ^ prevcode;

        // determine index of the changed bit.
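        // Consecutive Gray codes differ in exactly one bit, so diff has a single bit set.
        // Its index is found without a loop: pick the high or low 16-bit half first, then
        // convert the remaining power of two to float and read off log2 from the IEEE-754
        // exponent field ((u.i >> 23) - 127).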
\r
        idx = diff >= (1 << 16) ? 16 : 0;
        u.f = (float)(((diff >> 16) | diff) & 65535);
        idx += (u.i >> 23) - 127;
        subtract = graycode < prevcode;
        prevcode = graycode;

        crow = cjk + idx*m;
        weight = c_weights[idx];
        if( weight < FLT_EPSILON )

        for( k = 0; k < m; k++ )
            int lval = lc[k] + t;
            int rval = rc[k] - t;
            double p = priors[k], p2 = p*p;
            lsum2 += p2*lval*lval;
            rsum2 += p2*rval*rval;
            lc[k] = lval; rc[k] = rval;

        for( k = 0; k < m; k++ )
            int lval = lc[k] - t;
            int rval = rc[k] + t;
            double p = priors[k], p2 = p*p;
            lsum2 += p2*lval*lval;
            rsum2 += p2*rval*rval;
            lc[k] = lval; rc[k] = rval;

        if( L > FLT_EPSILON && R > FLT_EPSILON )
            double val = (lsum2*R + rsum2*L)/((double)L*R);
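            // val == lsum2/L + rsum2/R. Since L + R (the total node weight) is fixed,
            // maximizing this quantity is equivalent to minimizing the weighted Gini
            // impurity of the two halves, so the largest val marks the best subset.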
\r
            if( best_val < val )
                best_subset = subset_i;

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
        split = _split ? _split : data->new_split_cat( 0, -1.0f );
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
            idx = (int)(int_ptr[i] - cjk) >> 1;
            split->subset[idx >> 5] |= 1 << (idx & 31);
        for( i = 0; i < _mi; i++ )
            idx = cluster_labels ? cluster_labels[i] : i;
            if( best_subset & (1 << idx) )
                split->subset[i >> 5] |= 1 << (i & 31);

CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);

    float* values_buf = data->get_pred_float_buf();
    const float* values = 0;
    int* indices_buf = data->get_pred_int_buf();
    const int* indices = 0;
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
    float* responses_buf = data->get_resp_float_buf();
    const float* responses = 0;
    data->get_ord_responses( node, responses_buf, &responses );

    int i, best_i = -1;
    double best_val = init_quality, lsum = 0, rsum = node->value*n;
    int L = 0, R = n1;

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rsum -= responses[indices[i]];

    // find the optimal split
    for( i = 0; i < n1 - 1; i++ )
        float t = responses[indices[i]];

        if( values[i] + epsilon < values[i+1] )
            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
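            // val == lsum^2/L + rsum^2/R. Because the total sum of squared responses is
            // fixed, maximizing this is the same as minimizing the sum of squared errors
            // of the two halves around their respective means.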
\r
            if( best_val < val )

    CvDTreeSplit* split = 0;

    split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
    split->var_idx = vi;
    split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
    split->ord.split_point = best_i;
    split->inversed = 0;
    split->quality = (float)best_val;

CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];
    int* labels_buf = data->get_pred_int_buf();
    const int* labels = 0;
    float* responses_buf = data->get_resp_float_buf();
    const float* responses = 0;
    data->get_cat_var_data(node, vi, labels_buf, &labels);
    data->get_ord_responses(node, responses_buf, &responses);

    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
    int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
    double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );
    int i, L = 0, R = 0;
    double best_val = init_quality, lsum = 0, rsum = 0;
    int best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
        int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
        double s = sum[idx] + responses[i];
        int nc = counts[idx] + 1;

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
        sum[i] /= MAX(counts[i],1);
        sum_ptr[i] = sum + i;

    icvSortDblPtr( sum_ptr, mi, 0 );
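    // Categories are sorted by their mean response; for a squared-error criterion the best
    // two-way grouping of categories is then found by scanning prefixes of this ordering,
    // which reduces the search from 2^mi subsets to mi-1 candidate splits.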
\r
    // revert back to unnormalized sums
    // (there should be very little loss of accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
        int idx = (int)(sum_ptr[subset_i] - sum);
        int ni = counts[idx];

        double s = sum[idx];
        lsum += s; L += ni;
        rsum -= s; R -= ni;

        double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
        if( best_val < val )
            best_subset = subset_i;

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
        split = _split ? _split : data->new_split_cat( 0, -1.0f);
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
            int idx = (int)(sum_ptr[i] - sum);
            split->subset[idx >> 5] |= 1 << (idx & 31);

CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
    const float epsilon = FLT_EPSILON*2;
    const char* dir = (char*)data->direction->data.ptr;
    int n1 = node->get_num_valid(vi);
    float* values_buf = data->get_pred_float_buf();
    const float* values = 0;
    int* indices_buf = data->get_pred_int_buf();
    const int* indices = 0;
    data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
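    // A surrogate split is scored by how well it reproduces the primary split: its quality
    // is max(LL + RR, LR + RL), i.e. the (possibly weighted) number of samples on which it
    // agrees with the primary split, using either its direct or its inversed form. Only
    // candidates that beat node->maxlr (simply sending every sample along the primary
    // split's majority direction) are returned.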
\r
    int i, best_i = -1, best_inversed = 0;

    if( !data->have_priors )
        int LL = 0, RL = 0, LR, RR;
        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
        int sum = 0, sum_abs = 0;

        for( i = 0; i < n1; i++ )
            int d = dir[indices[i]];
            sum += d; sum_abs += d & 1;

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum) >> 1;
        LR = (sum_abs - sum) >> 1;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
            int d = dir[indices[i]];

            if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
                best_val = LL + RR;
                best_i = i; best_inversed = 0;

            if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
                best_val = RL + LR;
                best_i = i; best_inversed = 1;

        best_val = _best_val;

        double LL = 0, RL = 0, LR, RR;
        double worst_val = node->maxlr;
        double sum = 0, sum_abs = 0;
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = data->get_resp_int_buf();
        const int* responses = 0;
        data->get_class_labels(node, responses_buf, &responses);
        best_val = worst_val;

        for( i = 0; i < n1; i++ )
            int idx = indices[i];
            double w = priors[responses[idx]];
            sum += d*w; sum_abs += (d & 1)*w;

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum)*0.5;
        LR = (sum_abs - sum)*0.5;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by the primary split, and RR to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
            int idx = indices[i];
            double w = priors[responses[idx]];

            if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
                best_val = LL + RR;
                best_i = i; best_inversed = 0;

            if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
                best_val = RL + LR;
                best_i = i; best_inversed = 1;

    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;

CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count;
    int* labels_buf = data->get_pred_int_buf();
    const int* labels = 0;
    data->get_cat_var_data(node, vi, labels_buf, &labels);
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
    double best_val = 0;
    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // for each category calculate the weight of samples
    // sent to the left (lc) and to the right (rc) by the primary split
    if( !data->have_priors )
        int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
        int* _rc = _lc + mi + 1;

        for( i = -1; i < mi; i++ )
            _lc[i] = _rc[i] = 0;

        for( i = 0; i < n; i++ )
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            int sum = _lc[idx] + d;
            int sum_abs = _rc[idx] + (d & 1);
            _lc[idx] = sum; _rc[idx] = sum_abs;

        for( i = 0; i < mi; i++ )
            int sum_abs = _rc[i];
            lc[i] = (sum_abs - sum) >> 1;
            rc[i] = (sum_abs + sum) >> 1;

        const double* priors = data->priors_mult->data.db;
        int* responses_buf = data->get_resp_int_buf();
        const int* responses = 0;
        data->get_class_labels(node, responses_buf, &responses);

        for( i = 0; i < n; i++ )
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            double w = priors[responses[i]];
            double sum = lc[idx] + d*w;
            double sum_abs = rc[idx] + (d & 1)*w;
            lc[idx] = sum; rc[idx] = sum_abs;

        for( i = 0; i < mi; i++ )
            double sum = lc[i];
            double sum_abs = rc[i];
            lc[i] = (sum_abs - sum) * 0.5;
            rc[i] = (sum_abs + sum) * 0.5;

    // 2. now form the split.
    // in each category send all the samples to the same direction as the majority
    for( i = 0; i < mi; i++ )
        double lval = lc[i], rval = rc[i];
        split->subset[i >> 5] |= 1 << (i & 31);

    split->quality = (float)best_val;
    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
        cvSetRemoveByPtr( data->split_heap, split ), split = 0;

void CvDTree::calc_node_value( CvDTreeNode* node )
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    int* cv_labels_buf = data->get_cv_lables_buf();
    const int* cv_labels = 0;
    data->get_cv_labels(node, cv_labels_buf, &cv_labels);

    if( data->is_classifier )
        // in case of classification tree:
        // * node value is the label of the class that has the largest weight in the node.
        // * node risk is the weighted number of misclassified samples,
        // * j-th cross-validation fold value and risk are calculated as above,
        //   but using the samples with cv_labels(*)!=j.
        // * j-th cross-validation fold error is calculated as the weighted number of
        //   misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        int* cls_count = data->counts->data.i;
        int* responses_buf = data->get_resp_int_buf();
        const int* responses = 0;
        data->get_class_labels(node, responses_buf, &responses);
        int m = data->get_num_classes();
        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
        double max_val = -1, total_weight = 0;

        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )

        for( i = 0; i < n; i++ )
            cls_count[responses[i]]++;

        for( j = 0; j < cv_n; j++ )
            for( k = 0; k < m; k++ )
                cv_cls_count[j*m + k] = 0;

        for( i = 0; i < n; i++ )
            j = cv_labels[i]; k = responses[i];
            cv_cls_count[j*m + k]++;

        for( j = 0; j < cv_n; j++ )
            for( k = 0; k < m; k++ )
                cls_count[k] += cv_cls_count[j*m + k];

        if( data->have_priors && node->parent == 0 )
            // compute priors_mult from priors, take the sample ratio into account.
            for( k = 0; k < m; k++ )
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
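                // i.e. each sample of class k gets the weight priors[k]/cls_count[k], so
                // that the total weight contributed by class k at the root matches the
                // user-specified prior regardless of how frequent the class actually is.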
\r
        for( k = 0; k < m; k++ )

        for( k = 0; k < m; k++ )
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )

        node->class_idx = max_k;
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;
                double val = cls_count[k]*w - val_k;

                if( max_val < val )
                    max_val_k = val_k;

            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;

        // in case of regression tree:
        // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //   n is the number of samples in the node.
        // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        // * j-th cross-validation fold value and risk are calculated as above,
        //   but using the samples with cv_labels(*)!=j.
        // * j-th cross-validation fold error is calculated
        //   using samples with cv_labels(*)==j as the test subset:
        //   error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //   where node_value_j is the node value calculated
        //   as described in the previous bullet, and summation is done
        //   over the samples with cv_labels(*)==j.
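        // Everything below is derived from the running sums sum = SUM(y_i) and
        // sum2 = SUM(y_i^2), plus their per-fold counterparts cv_sum/cv_sum2/cv_count;
        // e.g. the node risk is obtained as sum2 - sum^2/n rather than by a second pass
        // over the responses.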
\r
        double sum = 0, sum2 = 0;
        float* values_buf = data->get_resp_float_buf();
        const float* values = 0;
        data->get_ord_responses(node, values_buf, &values);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        for( i = 0; i < n; i++ )
            double t = values[i];

        cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
        cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
        cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );

        for( j = 0; j < cv_n; j++ )
            cv_sum[j] = cv_sum2[j] = 0.;

        for( i = 0; i < n; i++ )
            double t = values[i];
            double s = cv_sum[j] + t;
            double s2 = cv_sum2[j] + t*t;
            int nc = cv_count[j] + 1;

        for( j = 0; j < cv_n; j++ )
            sum2 += cv_sum2[j];

        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( j = 0; j < cv_n; j++ )
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            double r = si/MAX(ci,1);
            node->cv_node_risk[j] = s2i - r*r*ci;
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[j] = INT_MAX;

void CvDTree::complete_node_dir( CvDTreeNode* node )
    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
    int nz = n - node->get_num_valid(node->split->var_idx);
    char* dir = (char*)data->direction->data.ptr;

    // try to complete direction using surrogate splits
    if( nz && data->params.use_surrogates )
        CvDTreeSplit* split = node->split->next;
        for( ; split != 0 && nz; split = split->next )
            int inversed_mask = split->inversed ? -1 : 0;
            vi = split->var_idx;

            if( data->get_var_type(vi) >= 0 ) // split on categorical var
                int* labels_buf = data->get_pred_int_buf();
                const int* labels = 0;
                data->get_cat_var_data(node, vi, labels_buf, &labels);
                const int* subset = split->subset;

                for( i = 0; i < n; i++ )
                    int idx = labels[i];
                    if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
                        int d = CV_DTREE_CAT_DIR(idx,subset);
                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
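                        // (d ^ inversed_mask) - inversed_mask is a branch-free conditional
                        // negation: it leaves d unchanged when inversed_mask == 0 and yields
                        // -d (two's complement, ~d + 1) when inversed_mask == -1.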
\r
            else // split on ordered var
                float* values_buf = data->get_pred_float_buf();
                const float* values = 0;
                int* indices_buf = data->get_pred_int_buf();
                const int* indices = 0;
                data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
                int split_point = split->ord.split_point;
                int n1 = node->get_num_valid(vi);

                assert( 0 <= split_point && split_point < n-1 );

                for( i = 0; i < n1; i++ )
                    int idx = indices[i];
                    int d = i <= split_point ? -1 : 1;
                    dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);

    // find the default direction for the rest
    for( i = nr = 0; i < n; i++ )
    d0 = nl > nr ? -1 : nr > nl;

    // make sure that every sample is directed either to the left or to the right
    for( i = 0; i < n; i++ )
        dir[i] = (char)d; // remap (-1,1) to (0,1)

void CvDTree::split_node_data( CvDTreeNode* node )
    int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();
    CvMat* buf = data->buf;
    int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));

    complete_node_dir(node);

    for( i = nl = nr = 0; i < n; i++ )
        // initialize new indices for splitting ordered variables
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
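        // branch-free select: after complete_node_dir() the direction d is 0 (left) or
        // 1 (right), so (d-1) is an all-ones mask for left samples and -d for right ones;
        // the sample's new index is thus nl in the left child or nr in the right child.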
\r
    bool split_input_data;
    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );

    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
         node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi);
        int *src_idx_buf = data->get_pred_int_buf();
        const int* src_idx = 0;
        float *src_val_buf = data->get_pred_float_buf();
        const float* src_val = 0;

        if( ci >= 0 || !split_input_data )

        data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_idx[i];

        if (data->is_buf_16u)
            unsigned short *ldst, *rdst, *ldst0, *rdst0;
            //unsigned short tl, tr;
            ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
                vi*scount + left->offset);
            rdst0 = rdst = (unsigned short*)(ldst + nl);

            for( i = 0; i < n1; i++ )
                int idx = temp_buf[i];
                idx = new_idx[idx];
                *rdst = (unsigned short)idx;
                *ldst = (unsigned short)idx;

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            for( ; i < n; i++ )
                int idx = temp_buf[i];
                idx = new_idx[idx];
                *rdst = (unsigned short)idx;
                *ldst = (unsigned short)idx;

            int *ldst0, *ldst, *rdst0, *rdst;
            ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols +
                vi*scount + left->offset;
            rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols +
                vi*scount + right->offset;

            for( i = 0; i < n1; i++ )
                int idx = temp_buf[i];
                idx = new_idx[idx];

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            for( ; i < n; i++ )
                int idx = temp_buf[i];
                idx = new_idx[idx];

    // split categorical vars, responses and cv_labels using new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;

        if( ci < 0 || (vi < data->var_count && !split_input_data) )

        int *src_lbls_buf = data->get_pred_int_buf();
        const int* src_lbls = 0;
        data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_lbls[i];

        if (data->is_buf_16u)
            unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
                vi*scount + left->offset);
            unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
                vi*scount + right->offset);

            for( i = 0; i < n; i++ )
                int idx = temp_buf[i];
                *rdst = (unsigned short)idx;
                nr1 += (idx != 65535 )&d;
                *ldst = (unsigned short)idx;

            if( vi < data->var_count )
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);

            int *ldst = buf->data.i + left->buf_idx*buf->cols +
                vi*scount + left->offset;
            int *rdst = buf->data.i + right->buf_idx*buf->cols +
                vi*scount + right->offset;

            for( i = 0; i < n; i++ )
                int idx = temp_buf[i];
                nr1 += (idx >= 0)&d;

            if( vi < data->var_count )
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);

    // split sample indices
    int *sample_idx_src_buf = data->get_sample_idx_buf();
    const int* sample_idx_src = 0;
    data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);

    for(i = 0; i < n; i++)
        temp_buf[i] = sample_idx_src[i];

    int pos = data->get_work_var_count();
    if (data->is_buf_16u)
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
            pos*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
            pos*scount + right->offset);
        for (i = 0; i < n; i++)
            unsigned short idx = (unsigned short)temp_buf[i];

        int* ldst = buf->data.i + left->buf_idx*buf->cols +
            pos*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*buf->cols +
            pos*scount + right->offset;
        for (i = 0; i < n; i++)
            int idx = temp_buf[i];

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);

float CvDTree::calc_error( CvMLData* _data, int type )
    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_response();
    const CvMat* missing = _data->get_missing();
    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
    const CvMat* var_types = _data->get_var_types();
    int* sidx = sample_idx ? sample_idx->data.i : 0;
    int r_step = CV_IS_MAT_CONT(response->type) ?
        1 : response->step / CV_ELEM_SIZE(response->type);
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    int sample_count = sample_idx ? sample_idx->cols : 0;
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
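    // The returned error is the percentage of misclassified samples for classification
    // trees and the mean squared prediction error for regression trees, computed over the
    // test or the train subset depending on the requested type.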
\r
    if ( is_classifier )
        for( int i = 0; i < sample_count; i++ )
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;

        for( int i = 0; i < sample_count; i++ )
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            float d = r - response->data.fl[si*r_step];
        err = sample_count ? err / (float)sample_count : -FLT_MAX;

void CvDTree::prune_cv()
    CvMat* err_jk = 0;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double min_err = 0, min_err_se = 0;

    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
        double min_alpha = update_tree_rnc(tree_count, -1);
        if( cut_tree(tree_count, -1, min_alpha) )

        if( ab->cols <= tree_count )
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );

        ab->data.db[tree_count] = min_alpha;

    ab->data.db[0] = 0.;

    if( tree_count > 0 )
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
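        // each tree of the main sequence is representative for the whole alpha interval
        // [alpha_i, alpha_{i+1}); following the usual CART recipe, the value used to pick
        // the matching tree in the per-fold sequences is the geometric mean of the two
        // interval ends.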
\r
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                    if( ab->data.db[tk] > min_alpha )
                    err[j*tree_count + tk] = root->tree_error;

        for( ti = 0; ti < tree_count; ti++ )
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
                min_err = sum_err;
                min_err_se = sqrt( sum_err*(n - sum_err) );
            else if( sum_err < min_err + min_err_se )

    pruned_tree_idx = min_idx;
    free_prune_data(data->params.truncate_pruned_tree != 0);

    cvReleaseMat( &err_jk );
    cvReleaseMat( &ab );
    cvReleaseMat( &temp );

double CvDTree::update_tree_rnc( int T, int fold )
    CvDTreeNode* node = root;
    double min_alpha = DBL_MAX;

        CvDTreeNode* parent;

        int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
        if( t <= T || !node->left )
            node->complexity = 1;
            node->tree_risk = node->node_risk;
            node->tree_error = 0.;
            node->tree_risk = node->cv_node_risk[fold];
            node->tree_error = node->cv_node_error[fold];
        node = node->left;

        for( parent = node->parent; parent && parent->right == node;
             node = parent, parent = parent->parent )
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
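            // standard cost-complexity pruning: alpha is the per-leaf penalty at which
            // collapsing this branch into a single leaf stops increasing the total cost
            // R(branch) + alpha*|leaves|; the smallest alpha over all internal nodes
            // tells which branch to cut to obtain the next tree in the sequence.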
\r
            min_alpha = MIN( min_alpha, parent->alpha );

        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        node = parent->right;

int CvDTree::cut_tree( int T, int fold, double min_alpha )
    CvDTreeNode* node = root;

        CvDTreeNode* parent;

        int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
        if( t <= T || !node->left )
        if( node->alpha <= min_alpha + FLT_EPSILON )
            node->cv_Tn[fold] = T;
            if( node == root )
        node = node->left;

        for( parent = node->parent; parent && parent->right == node;
             node = parent, parent = parent->parent )
            node = parent->right;

void CvDTree::free_prune_data(bool cut_tree)
    CvDTreeNode* node = root;

        CvDTreeNode* parent;

        // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
        // as we will clear the whole cross-validation heap at the end
        node->cv_node_error = node->cv_node_risk = 0;
        node = node->left;

        for( parent = node->parent; parent && parent->right == node;
             node = parent, parent = parent->parent )
            if( cut_tree && parent->Tn <= pruned_tree_idx )
                data->free_node( parent->left );
                data->free_node( parent->right );
                parent->left = parent->right = 0;
            node = parent->right;

    if( data->cv_heap )
        cvClearSet( data->cv_heap );

void CvDTree::free_tree()
    if( root && data && data->shared )
        pruned_tree_idx = INT_MIN;
        free_prune_data(true);
        data->free_node(root);

CvDTreeNode* CvDTree::predict( const CvMat* _sample,
                               const CvMat* _missing, bool preprocessed_input ) const
    CvDTreeNode* result = 0;

    CV_FUNCNAME( "CvDTree::predict" );

    int i, step, mstep = 0;
    const float* sample;
    const uchar* m = 0;
    CvDTreeNode* node = root;

        CV_ERROR( CV_StsError, "The tree has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        (_sample->cols != 1 && _sample->rows != 1) ||
        (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
        (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
        CV_ERROR( CV_StsBadArg,
                  "the input sample must be 1d floating-point vector with the same "
                  "number of elements as the total number of variables used for training" );

    sample = _sample->data.fl;
    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);

    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
        int n = data->cat_count->cols;
        catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
        for( i = 0; i < n; i++ )

        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
            !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_ERROR( CV_StsBadArg,
                      "the missing data mask must be 8-bit vector of the same size as input sample" );
        m = _missing->data.ptr;
        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);

    vtype = data->var_type->data.i;
    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
    cmap = data->cat_map ? data->cat_map->data.i : 0;
    cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
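    // cat_map stores, for every categorical variable, the sorted list of original category
    // values, and cat_ofs holds the offset of each variable's block inside cat_map. For a
    // raw (not preprocessed) input the value is located in that block below and converted
    // to a normalized 0-based index, which is cached in catbuf so the lookup happens at
    // most once per variable per prediction.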
\r
    while( node->Tn > pruned_tree_idx && node->left )
        CvDTreeSplit* split = node->split;

        for( ; !dir && split != 0; split = split->next )
            int vi = split->var_idx;
            int ci = vtype[vi];
            i = vidx ? vidx[vi] : vi;
            float val = sample[i*step];
            if( m && m[i*mstep] )
            if( ci < 0 ) // ordered
                dir = val <= split->ord.c ? -1 : 1;
            else // categorical
                if( preprocessed_input )
                    int a = c = cofs[ci];
                    int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];

                    int ival = cvRound(val);
                        CV_ERROR( CV_StsBadArg,
                                  "one of the input categorical variables is not an integer" );

                    if( ival < cmap[c] )
                    else if( ival > cmap[c] )

                    if( c < 0 || ival != cmap[c] )

                    catbuf[ci] = c -= cofs[ci];
                c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
                dir = CV_DTREE_CAT_DIR(c, split->subset);

            if( split->inversed )

        double diff = node->right->sample_count - node->left->sample_count;
        dir = diff < 0 ? -1 : 1;

        node = dir < 0 ? node->left : node->right;

const CvMat* CvDTree::get_var_importance()
    if( !var_importance )
        CvDTreeNode* node = root;
        double* importance;
        var_importance = cvCreateMat( 1, data->var_count, CV_64F );
        cvZero( var_importance );
        importance = var_importance->data.db;

        CvDTreeNode* parent;
        for( ;; node = node->left )
            CvDTreeSplit* split = node->split;

            if( !node->left || node->Tn <= pruned_tree_idx )

            for( ; split != 0; split = split->next )
                importance[split->var_idx] += split->quality;

            for( parent = node->parent; parent && parent->right == node;
                 node = parent, parent = parent->parent )
                node = parent->right;

        cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );

    return var_importance;

void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
    cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
    cvWriteInt( fs, "var", split->var_idx );
    cvWriteReal( fs, "quality", split->quality );

    ci = data->get_var_type(split->var_idx);
    if( ci >= 0 ) // split on a categorical var
        int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
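        // the list written below contains the categories whose direction disagrees with
        // default_dir (typically the smaller of the two groups), which keeps the stored
        // "in"/"not_in" representation compact.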
\r
        cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
                            "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );

        for( i = 0; i < n; i++ )
            int dir = CV_DTREE_CAT_DIR(i,split->subset);
            if( dir*default_dir < 0 )
                cvWriteInt( fs, 0, i );

        cvEndWriteStruct( fs );

        cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );

    cvEndWriteStruct( fs );

void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
    CvDTreeSplit* split;

    cvStartWriteStruct( fs, 0, CV_NODE_MAP );

    cvWriteInt( fs, "depth", node->depth );
    cvWriteInt( fs, "sample_count", node->sample_count );
    cvWriteReal( fs, "value", node->value );

    if( data->is_classifier )
        cvWriteInt( fs, "norm_class_idx", node->class_idx );

    cvWriteInt( fs, "Tn", node->Tn );
    cvWriteInt( fs, "complexity", node->complexity );
    cvWriteReal( fs, "alpha", node->alpha );
    cvWriteReal( fs, "node_risk", node->node_risk );
    cvWriteReal( fs, "tree_risk", node->tree_risk );
    cvWriteReal( fs, "tree_error", node->tree_error );

    cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );

    for( split = node->split; split != 0; split = split->next )
        write_split( fs, split );

    cvEndWriteStruct( fs );

    cvEndWriteStruct( fs );

void CvDTree::write_tree_nodes( CvFileStorage* fs )
    //CV_FUNCNAME( "CvDTree::write_tree_nodes" );

    CvDTreeNode* node = root;

    // traverse the tree and save all the nodes in depth-first order
        CvDTreeNode* parent;

        write_node( fs, node );
        node = node->left;

        for( parent = node->parent; parent && parent->right == node;
             node = parent, parent = parent->parent )
            node = parent->right;

void CvDTree::write( CvFileStorage* fs, const char* name )
    //CV_FUNCNAME( "CvDTree::write" );

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );

    get_var_importance();
    data->write_params( fs );
    if( var_importance )
        cvWrite( fs, "var_importance", var_importance );

    cvEndWriteStruct( fs );

void CvDTree::write( CvFileStorage* fs )
    //CV_FUNCNAME( "CvDTree::write" );

    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );

    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
    write_tree_nodes( fs );
    cvEndWriteStruct( fs );

CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
    CvDTreeSplit* split = 0;

    CV_FUNCNAME( "CvDTree::read_split" );

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );

    vi = cvReadIntByName( fs, fnode, "var", -1 );
    if( (unsigned)vi >= (unsigned)data->var_count )
        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );

    ci = data->get_var_type(vi);
    if( ci >= 0 ) // split on categorical var
        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
        CvSeqReader reader;
        CvFileNode* inseq;
        split = data->new_split_cat( vi, 0 );
        inseq = cvGetFileNodeByName( fs, fnode, "in" );
            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
            CV_ERROR( CV_StsParseError,
                      "Either 'in' or 'not_in' tags should be inside a categorical split data" );

        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
            val = inseq->data.i;
            if( (unsigned)val >= (unsigned)n )
                CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
            split->subset[val >> 5] |= 1 << (val & 31);

            cvStartReadSeq( inseq->data.seq, &reader );

            for( i = 0; i < reader.seq->total; i++ )
                CvFileNode* inode = (CvFileNode*)reader.ptr;
                val = inode->data.i;
                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
                split->subset[val >> 5] |= 1 << (val & 31);
                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );

        // for categorical splits we do not use inversed splits,
        // instead we invert the variable set in the split
            for( i = 0; i < (n + 31) >> 5; i++ )
                split->subset[i] ^= -1;

        CvFileNode* cmp_node;
        split = data->new_split_ord( vi, 0, 0, 0, 0 );

        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
            split->inversed = 1;

        split->ord.c = (float)cvReadReal( cmp_node );

    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );

CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
    CvDTreeNode* node = 0;

    CV_FUNCNAME( "CvDTree::read_node" );

    CvFileNode* splits;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );

    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
    depth = cvReadIntByName( fs, fnode, "depth", -1 );
    if( depth != node->depth )
        CV_ERROR( CV_StsParseError, "incorrect node depth" );

    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
    node->value = cvReadRealByName( fs, fnode, "value" );
    if( data->is_classifier )
        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );

    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );

    splits = cvGetFileNodeByName( fs, fnode, "splits" );
        CvSeqReader reader;
        CvDTreeSplit* last_split = 0;

        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
            CV_ERROR( CV_StsParseError, "splits tag must be stored as a sequence" );

        cvStartReadSeq( splits->data.seq, &reader );
        for( i = 0; i < reader.seq->total; i++ )
            CvDTreeSplit* split;
            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
                node->split = last_split = split;
                last_split = last_split->next = split;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );

void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
    CV_FUNCNAME( "CvDTree::read_tree_nodes" );

    CvSeqReader reader;
    CvDTreeNode _root;
    CvDTreeNode* parent = &_root;

    parent->left = parent->right = parent->parent = 0;

    cvStartReadSeq( fnode->data.seq, &reader );

    for( i = 0; i < reader.seq->total; i++ )
        CvDTreeNode* node;

        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
        if( !parent->left )
            parent->left = node;
            parent->right = node;

        while( parent && parent->right )
            parent = parent->parent;

        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );

    root = _root.left;

void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
    CvDTreeTrainData* _data = new CvDTreeTrainData();
    _data->read_params( fs, fnode );

    read( fs, fnode, _data );
    get_var_importance();

// a special entry point for reading weak decision trees from the tree ensembles
void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
    CV_FUNCNAME( "CvDTree::read" );

    CvFileNode* tree_nodes;

    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
        CV_ERROR( CV_StsParseError, "nodes tag is missing" );

    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
    read_tree_nodes( fs, tree_nodes );

/* End of file. */