1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 By downloading, copying, installing or using the software you agree to this license.
6 If you do not agree to this license, do not download, install,
7 copy or use the software.
10 Intel License Agreement
12 Copyright (C) 2000, Intel Corporation, all rights reserved.
13 Third party copyrights are property of their respective owners.
15 Redistribution and use in source and binary forms, with or without modification,
16 are permitted provided that the following conditions are met:
18 * Redistribution's of source code must retain the above copyright notice,
19 this list of conditions and the following disclaimer.
21 * Redistribution's in binary form must reproduce the above copyright notice,
22 this list of conditions and the following disclaimer in the documentation
23 and/or other materials provided with the distribution.
25 * The name of Intel Corporation may not be used to endorse or promote products
26 derived from this software without specific prior written permission.
28 This software is provided by the copyright holders and contributors "as is" and
29 any express or implied warranties, including, but not limited to, the implied
30 warranties of merchantability and fitness for a particular purpose are disclaimed.
31 In no event shall the Intel Corporation or contributors be liable for any direct,
32 indirect, incidental, special, exemplary, or consequential damages
33 (including, but not limited to, procurement of substitute goods or services;
34 loss of use, data, or profits; or business interruption) however caused
35 and on any theory of liability, whether in contract, strict liability,
36 or tort (including negligence or otherwise) arising in any way out of
37 the use of this software, even if advised of the possibility of such damage.
// File-scope tuning constants and comparator/sort helpers for the
// Extremely Randomized Trees (ERTrees) training code below.
// NOTE(review): this excerpt carries stale line-number prefixes and is
// missing interior lines throughout — confirm against the original file.
// Sentinel threshold: ordered-variable values with |v| >= ord_nan are
// treated as invalid/too large (see the range checks in set_data).
43 static const float ord_nan = FLT_MAX*0.5f;
// Minimum memory-storage block size (64 KiB) for tree/node storages.
44 static const int min_block_size = 1 << 16;
// Extra slack (1 KiB) added on top of computed block sizes.
45 static const int block_size_delta = 1 << 10;
// Comparator for sorting an array of int pointers by pointed-to value.
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
// Instantiates icvSortIntPtr(int** arr, int n, int) via OpenCV's qsort macro.
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
// Comparator for CvPair16u32s pairs, ordered by the int the .i member points to.
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
// Instantiates icvSortPairs(CvPair16u32s* arr, int n, int).
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
// Prepares and validates the training set for an extremely randomized tree:
// checks parameter/matrix consistency, builds the internal work buffer
// (16-bit or 32-bit depending on sample_count), compacts categorical labels
// into a contiguous category map, allocates node/split heaps, and seeds the
// root node. Mirrors CvDTreeTrainData::set_data but rejects surrogate splits
// and incremental updates, which ERTrees do not support.
// NOTE(review): many interior lines (braces, else-branches, statements) are
// missing from this excerpt — do not edit logic without the full file.
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
56     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
57     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
58     bool _shared, bool _add_labels, bool _update_data )
// Temporaries released in the cleanup section at the end of the function.
60     CvMat* sample_indices = 0;
64     CvPair16u32s* pair16u32s_ptr = 0;
65     CvDTreeTrainData* data = 0;
68     unsigned short* udst = 0;
71     CV_FUNCNAME( "CvERTreeTrainData::set_data" );
75     int sample_all = 0, r_type = 0, cv_n;
76     int total_c_count = 0;
77     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
78     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
81     const int *sidx = 0, *vidx = 0;
// ERTrees explicitly do not support surrogate splits or data updates.
83     if ( _params.use_surrogates )
84         CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
86     if( _update_data && data_root )
88         CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
96     CV_CALL( set_params( _params ));
98     // check parameter types and sizes
99     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
101     train_data = _train_data;
102     responses = _responses;
103     missing_mask = _missing_mask;
// Derive element strides for row-sample vs column-sample layout.
105     if( _tflag == CV_ROW_SAMPLE )
107         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
110             ms_step = _missing_mask->step, mv_step = 1;
114         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
117             mv_step = _missing_mask->step, ms_step = 1;
121     sample_count = sample_all;
// Optional sample subset: preprocess the index array and recount samples.
126         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
127         sidx = sample_indices->data.i;
128         sample_count = sample_indices->rows + sample_indices->cols - 1;
// Optional variable subset.
133         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
134         vidx = var_idx->data.i;
135         var_count = var_idx->rows + var_idx->cols - 1;
// Responses must be a 1-D int or float vector, one entry per training sample.
138     if( !CV_IS_MAT(_responses) ||
139         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
140         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
141         (_responses->rows != 1 && _responses->cols != 1) ||
142         _responses->rows + _responses->cols - 1 != sample_all )
143         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
144                   "floating-point vector containing as many elements as "
145                   "the total number of samples in the training data matrix" );
// Fewer than 65536 samples lets indices fit in 16-bit buffer entries.
// NOTE(review): the assignment to is_buf_16u presumably follows here — missing
// from this excerpt; confirm against the original file.
148     if ( sample_count < 65536 )
152     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
// Two extra slots locate responses and cv labels via get_var_type().
154     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
160     is_classifier = r_type == CV_VAR_CATEGORICAL;
162     // step 0. calc the number of categorical vars
// Categorical vars get indices 0,1,2,...; ordered vars get -1,-2,...
// (ord_var_count counts down and is bit-complemented below).
163     for( vi = 0; vi < var_count; vi++ )
165         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
166             cat_var_count++ : ord_var_count--;
169     ord_var_count = ~ord_var_count;
170     cv_n = params.cv_folds;
171     // set the two last elements of var_type array to be able
172     // to locate responses and cross-validation labels using
173     // the corresponding get_* functions.
174     var_type->data.i[var_count] = cat_var_count;
175     var_type->data.i[var_count+1] = cat_var_count+1;
177     // in case of single ordered predictor we need dummy cv_labels
178     // for safe split_node_data() operation
179     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
181     work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
182     buf_size = (work_var_count + 1)*sample_count;
183     shared = _shared;
184     buf_count = shared ? 2 : 1;
// Work buffer: 16-bit entries when sample_count permits, otherwise 32-bit.
188         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
189         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
193         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
194         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
// Per-category bookkeeping arrays (size clamped to >= 1 so cvCreateMat succeeds).
197     size = is_classifier ? cat_var_count+1 : cat_var_count;
198     size = !size ? 1 : size;
199     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
200     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
202     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
203     size = !size ? 1 : size;
204     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
206     // now calculate the maximum size of split,
207     // create memory storage that will keep nodes and splits of the decision tree
208     // allocate root node and the buffer for the whole training data
209     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
210         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
211     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
212     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
213     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
214     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
216     nv_size = var_count*sizeof(int);
217     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
219     temp_block_size = nv_size;
// Cross-validation: sanity-check fold count against dataset size.
223         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
224             CV_ERROR( CV_StsOutOfRange,
225                 "The many folds in cross-validation for such a small dataset" );
227         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
228         temp_block_size = MAX(temp_block_size, cv_size);
231     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
232     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
233     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
235         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
// Root node spans the whole (sub)sample.
237     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
// Scratch arrays for sorting/compacting values; _idst is only needed for the
// 16-bit buffer path (categorical vars or class responses).
244         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
245     if (is_buf_16u && (cat_var_count || is_classifier))
246         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
248     // transform the training data to convenient representation
// Loop runs one past var_count: the extra iteration processes the responses.
249     for( vi = 0; vi <= var_count; vi++ )
252         const uchar* mask = 0;
253         int m_step = 0, step;
254         const int* idata = 0;
255         const float* fdata = 0;
258         if( vi < var_count ) // analyze i-th input variable
260             int vi0 = vidx ? vidx[vi] : vi;
261             ci = get_var_type(vi);
262             step = ds_step; m_step = ms_step;
263             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
264                 idata = _train_data->data.i + vi0*dv_step;
266                 fdata = _train_data->data.fl + vi0*dv_step;
268                 mask = _missing_mask->data.ptr + vi0*mv_step;
270         else // analyze _responses
273             step = CV_IS_MAT_CONT(_responses->type) ?
274                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
275             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
276                 idata = _responses->data.i;
278                 fdata = _responses->data.fl;
281         if( (vi < var_count && ci>=0) ||
282             (vi == var_count && is_classifier) ) // process categorical variable or response
284             int c_count, prev_label;
// Destination row inside buf for this categorical variable (16u or 32s path).
288                 udst = (unsigned short*)(buf->data.s + ci*sample_count);
290                 idst = buf->data.i + ci*sample_count;
// Gather raw categorical values; INT_MAX marks a missing entry.
293             for( i = 0; i < sample_count; i++ )
295                 int val = INT_MAX, si = sidx ? sidx[i] : i;
296                 if( !mask || !mask[si*m_step] )
299                         val = idata[si*step];
302                         float t = fdata[si*step];
// Float categorical values must be exact integers within range.
306                             sprintf( err, "%d-th value of %d-th (categorical) "
307                                 "variable is not an integer", i, vi );
308                             CV_ERROR( CV_StsBadArg, err );
314                             sprintf( err, "%d-th value of %d-th (categorical) "
315                                 "variable is too large", i, vi );
316                             CV_ERROR( CV_StsBadArg, err );
// Set up pointer arrays so sorting can be done indirectly.
323                     pair16u32s_ptr[i].u = udst + i;
324                     pair16u32s_ptr[i].i = _idst + i;
329                     int_ptr[i] = idst + i;
333             c_count = num_valid > 0;
// Sort valid values so distinct categories can be counted and compacted.
337                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
338                 // count the categories
339                 for( i = 1; i < num_valid; i++ )
340                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
345                 icvSortIntPtr( int_ptr, sample_count, 0 );
346                 // count the categories
347                 for( i = 1; i < num_valid; i++ )
348                     c_count += *int_ptr[i] != *int_ptr[i-1];
352             max_c_count = MAX( max_c_count, c_count );
353             cat_count->data.i[ci] = c_count;
354             cat_ofs->data.i[ci] = total_c_count;
356             // resize cat_map, if need
// Grow by ~1.5x and copy the already-built portion of the map.
357             if( cat_map->cols < total_c_count + c_count )
360                 CV_CALL( cat_map = cvCreateMat( 1,
361                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
362                 for( i = 0; i < total_c_count; i++ )
363                     cat_map->data.i[i] = tmp_map->data.i[i];
364                 cvReleaseMat( &tmp_map );
367             c_map = cat_map->data.i + total_c_count;
368             total_c_count += c_count;
// Replace raw labels with compact 0-based category indices and record the
// original label of each category in c_map; missing values get the sentinel
// (65535 for the 16-bit buffer, -1 for the 32-bit one).
373                 // compact the class indices and build the map
374                 prev_label = ~*pair16u32s_ptr[0].i;
375                 for( i = 0; i < num_valid; i++ )
377                     int cur_label = *pair16u32s_ptr[i].i;
378                     if( cur_label != prev_label )
379                         c_map[++c_count] = prev_label = cur_label;
380                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
382                 // replace labels for missing values with 65535
383                 for( ; i < sample_count; i++ )
384                     *pair16u32s_ptr[i].u = 65535;
388                 // compact the class indices and build the map
389                 prev_label = ~*int_ptr[0];
390                 for( i = 0; i < num_valid; i++ )
392                     int cur_label = *int_ptr[i];
393                     if( cur_label != prev_label )
394                         c_map[++c_count] = prev_label = cur_label;
395                     *int_ptr[i] = c_count;
397                 // replace labels for missing values with -1
398                 for( ; i < sample_count; i++ )
402         else if( ci < 0 ) // process ordered variable
// Ordered variables are only validated here (range check against ord_nan);
// values are fetched on demand later via get_ord_var_data().
404             for( i = 0; i < sample_count; i++ )
407                 int si = sidx ? sidx[i] : i;
408                 if( !mask || !mask[si*m_step] )
411                         val = (float)idata[si*step];
413                         val = fdata[si*step];
415                     if( fabs(val) >= ord_nan )
417                         sprintf( err, "%d-th value of %d-th (ordered) "
418                             "variable (=%g) is too large", i, vi, val );
419                         CV_ERROR( CV_StsBadArg, err );
426             data_root->set_num_valid(vi, num_valid);
// Store the sample index row (last work variable) in the buffer.
431         udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
433         idst = buf->data.i + get_work_var_count()*sample_count;
435     for (i = 0; i < sample_count; i++)
438             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
440             idst[i] = sidx ? sidx[i] : i;
// Build cross-validation fold labels: assign folds round-robin, then
// randomly shuffle them with pairwise swaps.
445         unsigned short* udst = 0;
451             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
452             for( i = vi = 0; i < sample_count; i++ )
454                 udst[i] = (unsigned short)vi++;
// Wrap vi back to 0 once it reaches cv_n (branch-free reset trick).
455                 vi &= vi < cv_n ? -1 : 0;
458             for( i = 0; i < sample_count; i++ )
460                 int a = cvRandInt(r) % sample_count;
461                 int b = cvRandInt(r) % sample_count;
462                 unsigned short unsh = (unsigned short)vi;
463                 CV_SWAP( udst[a], udst[b], unsh );
468             idst = buf->data.i + (get_work_var_count()-1)*sample_count;
469             for( i = vi = 0; i < sample_count; i++ )
472                 vi &= vi < cv_n ? -1 : 0;
475             for( i = 0; i < sample_count; i++ )
477                 int a = cvRandInt(r) % sample_count;
478                 int b = cvRandInt(r) % sample_count;
479                 CV_SWAP( idst[a], idst[b], vi );
// Shrink cat_map to the categories actually used.
485     cat_map->cols = MAX( total_c_count, 1 );
487     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
488         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
489     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
491     have_priors = is_classifier && params.priors;
// Class priors: validate positivity, normalize to sum 1.
494         int m = get_num_classes();
496         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
497         for( i = 0; i < m; i++ )
499             double val = have_priors ? params.priors[i] : 1.;
501                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
502             priors->data.db[i] = val;
508         cvScale( priors, priors, 1./sum );
510         CV_CALL( priors_mult = cvCloneMat( priors ));
511         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
514     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
515     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
// Cleanup of temporaries (reached on both success and error paths).
527     cvReleaseMat( &var_type0 );
528     cvReleaseMat( &sample_indices );
529     cvReleaseMat( &tmp_map );
// Fetches the values of ordered variable vi for all samples of node n
// directly from the original train_data matrix (ERTrees do not pre-sort
// ordered values), writing them into ord_values_buf and the per-sample
// missing flags into missing_buf. Handles both row-sample and
// column-sample layouts. NOTE(review): some lines are missing from this
// excerpt (e.g. braces); confirm against the original file.
532 void CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
533                                           const float** ord_values, const int** missing, int* sample_indices_buf )
// Map the external variable index through var_idx if a subset is active.
535     int vidx = var_idx ? var_idx->data.i[vi] : vi;
536     int node_sample_count = n->sample_count;
537     // may use missing_buf as buffer for sample indices!
538     const int* sample_indices = get_sample_indices(n, sample_indices_buf ? sample_indices_buf : missing_buf);
// Element strides for the data and missing-mask matrices.
540     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
541     int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
542     if( tflag == CV_ROW_SAMPLE )
544         for( int i = 0; i < node_sample_count; i++ )
546             int idx = sample_indices[i];
547             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
// NOTE(review): assumes train_data is CV_32FC1 here — integer input is
// presumably handled elsewhere; verify in the full file.
548             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
// Column-sample layout: variable selects the row, sample selects the column.
552         for( int i = 0; i < node_sample_count; i++ )
554             int idx = sample_indices[i];
555             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
556             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
558     *ord_values = ord_values_buf;
559     *missing = missing_buf;
// Returns the original sample indices for node n. They are stored as the
// last categorical "work variable" row in the internal buffer, after the
// class labels (if classifier) and cv labels (if present).
563 const int* CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
565     return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf );
// Returns the cross-validation fold labels for node n, stored as the work
// variable immediately after the class labels (if any) in the buffer.
// NOTE(review): presumably only valid when have_labels is true — confirm.
569 const int* CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
572     return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf );
// Returns categorical values of variable vi for the samples of node n.
// With a 32-bit buffer the data is returned in place (no copy); with a
// 16-bit buffer the unsigned short values are widened into cat_values_buf.
// NOTE(review): the is_buf_16u branch structure is partially missing from
// this excerpt; confirm against the original file.
577 const int* CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
579     int ci = get_var_type( vi);
580     const int* cat_values = 0;
// 32-bit path: point straight into the node's slice of the buffer.
582         cat_values = buf->data.i + n->buf_idx*buf->cols + ci*sample_count + n->offset;
// 16-bit path: widen each unsigned short into the caller's int buffer.
584         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
585             ci*sample_count + n->offset);
586         for( int i = 0; i < n->sample_count; i++ )
587             cat_values_buf[i] = short_values[i];
588         cat_values = cat_values_buf;
// Extracts (a subset of) the training samples back into dense row-major
// arrays: feature values, a missing-value mask, and responses (either the
// compact class index or the original class label / ordered response).
// NOTE(review): interior lines are missing from this excerpt — several
// loop bodies and branches are incomplete; do not edit logic here.
593 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
594                                      float* values, uchar* missing,
595                                      float* responses, bool get_class_idx )
597     CvMat* subsample_idx = 0;
598     CvMat* subsample_co = 0;
// Scratch buffer reused for per-variable label/response extraction.
600     cv::AutoBuffer<uchar> inn_buf(sample_count*(sizeof(float) + sizeof(int)));
602     CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
606     int i, vi, total = sample_count, count = total, cur_ofs = 0;
// Optional subsample: build count/offset table (co) mapping original sample
// index -> position in the output arrays.
612         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
613         sidx = subsample_idx->data.i;
614         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
615         co = subsample_co->data.i;
616         cvZero( subsample_co );
617         count = subsample_idx->cols + subsample_idx->rows - 1;
618         for( i = 0; i < count; i++ )
620         for( i = 0; i < total; i++ )
622             int count_i = co[i*2];
625                 co[i*2+1] = cur_ofs*var_count;
// Default everything to "missing"; filled in per variable below.
632         memset( missing, 1, count*var_count );
634     for( vi = 0; vi < var_count; vi++ )
636         int ci = get_var_type(vi);
637         if( ci >= 0 ) // categorical
639             float* dst = values + vi;
640             uchar* m = missing ? missing + vi : 0;
641             int* lbls_buf = (int*)(uchar*)inn_buf;
642             const int* src = get_cat_var_data(data_root, vi, lbls_buf);
644             for( i = 0; i < count; i++, dst += var_count )
646                 int idx = sidx ? sidx[i] : i;
// Missing categorical sentinel: -1 for 32-bit buffers, 65535 for 16-bit.
651                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
// Ordered variable: fetch all values at once and translate missing flags.
658             int* mis_buf = (int*)(uchar*)inn_buf;
659             const float *dst = 0;
661             get_ord_var_data(data_root, vi, values + vi, mis_buf, &dst, &mis, 0);
662             for (int si = 0; si < total; si++)
663                 *(missing + vi + si) = mis[si] == 0 ? 0 : 1;
// Responses: class labels (optionally mapped back to the original label
// via cat_map) for classification, ordered responses for regression.
672             int* lbls_buf = (int*)(uchar*)inn_buf;
673             const int* src = get_class_labels(data_root, lbls_buf);
674             for( i = 0; i < count; i++ )
676                 int idx = sidx ? sidx[i] : i;
677                 int val = get_class_idx ? src[idx] :
678                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
679                 responses[i] = (float)val;
684             float* _values_buf = (float*)(uchar*)inn_buf;
685             int* sample_idx_buf = (int*)(_values_buf + sample_count);
686             const float* _values = get_ord_responses(data_root, _values_buf, sample_idx_buf);
687             for( i = 0; i < count; i++ )
689                 int idx = sidx ? sidx[i] : i;
690                 responses[i] = _values[idx];
// Release temporaries (reached on both success and error paths).
697     cvReleaseMat( &subsample_idx );
698     cvReleaseMat( &subsample_co );
// ERTrees train every tree on the full sample set, so subsampling is only
// supported as the trivial case: _subsample_idx must be null, and the root
// node is shallow-copied (num_valid/cv arrays shared with data_root).
// Returns the copied root; raises CV_StsError otherwise.
701 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
703     CvDTreeNode* root = 0;
705     CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
710         CV_ERROR( CV_StsError, "No training data has been set" );
712     if( !_subsample_idx )
714         // make a copy of the root node
717         root = new_node( 0, 1, 0, 0 );
// NOTE(review): 'temp' is presumably a copy of *data_root made on a line
// missing from this excerpt — confirm against the original file.
720         root->num_valid = temp.num_valid;
721         if( root->num_valid )
723             for( i = 0; i < var_count; i++ )
724                 root->num_valid[i] = data_root->num_valid[i];
726         root->cv_Tn = temp.cv_Tn;
727         root->cv_node_risk = temp.cv_node_risk;
728         root->cv_node_error = temp.cv_node_error;
731         CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
// Applies the node's chosen split to its samples, writing the left/right
// direction of each sample into data->direction, and returns the split
// quality normalized by the total (weighted) sample mass L+R. Handles
// categorical and ordered splits, each with and without class priors.
// NOTE(review): several branch bodies are missing from this excerpt.
737 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
739     char* dir = (char*)data->direction->data.ptr;
740     int i, n = node->sample_count, vi = node->split->var_idx;
// ERTree splits are never stored inverted.
743     assert( !node->split->inversed );
745     if( data->get_var_type(vi) >= 0 ) // split on categorical var
747         cv::AutoBuffer<uchar> inn_buf(n*sizeof(int)*(!data->have_priors ? 1 : 2));
748         int* labels_buf = (int*)(uchar*)inn_buf;
749         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
750         const int* subset = node->split->subset;
751         if( !data->have_priors )
// Unweighted: d is +1 (right), -1 (left) or 0 (missing); sum and sum_abs
// recover the left/right counts below.
753             int sum = 0, sum_abs = 0;
755             for( i = 0; i < n; i++ )
// Skip missing categories (-1 for 32-bit buffers, 65535 for 16-bit).
758                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
759                     CV_DTREE_CAT_DIR(idx,subset) : 0;
760                 sum += d; sum_abs += d & 1;
// (sum_abs ± sum)/2 converts the ±1 tallies into right/left counts.
764             R = (sum_abs + sum) >> 1;
765             L = (sum_abs - sum) >> 1;
// Weighted variant: same tally scaled by the prior of each sample's class.
769             const double* priors = data->priors_mult->data.db;
770             double sum = 0, sum_abs = 0;
771             int *responses_buf = labels_buf + n;
772             const int* responses = data->get_class_labels(node, responses_buf);
774             for( i = 0; i < n; i++ )
777                 double w = priors[responses[i]];
778                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
779                 sum += d*w; sum_abs += (d & 1)*w;
783             R = (sum_abs + sum) * 0.5;
784             L = (sum_abs - sum) * 0.5;
787     else // split on ordered var
789         float split_val = node->split->ord.c;
790         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(!data->have_priors ? 1 : 2) + sizeof(float)));
791         float* val_buf = (float*)(uchar*)inn_buf;
792         int* missing_buf = (int*)(val_buf + n);
793         const float* val = 0;
794         const int* missing = 0;
795         data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing, 0 );
797         if( !data->have_priors )
// Count samples strictly below the threshold as left, rest right
// (loop bodies partially missing from this excerpt).
800             for( i = 0; i < n; i++ )
806                 if ( val[i] < split_val)
// Weighted ordered split: per-sample class prior as weight.
821             const double* priors = data->priors_mult->data.db;
822             int* responses_buf = missing_buf + n;
823             const int* responses = data->get_class_labels(node, responses_buf);
825             for( i = 0; i < n; i++ )
831                 double w = priors[responses[i]];
832                 if ( val[i] < split_val)
// maxlr feeds the surrogate/missing-direction logic in the base class.
847     node->maxlr = MAX( L, R );
848     return node->split->quality/(L + R);
// Extremely-randomized split on an ordered variable for classification:
// instead of scanning all thresholds, draws ONE random threshold uniformly
// in [min,max] of the variable's valid values, then scores it with a Gini
// index (weighted by priors if present). Returns a new/updated split or 0
// when no valid threshold exists. NOTE(review): several lines (bounds
// tracking, counter updates) are missing from this excerpt.
851 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
854     const float epsilon = FLT_EPSILON*2;
// Offset used to keep the drawn threshold strictly inside (pmin, pmax).
855     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
857     int n = node->sample_count, i;
858     int m = data->get_num_classes();
860     cv::AutoBuffer<uchar> inn_buf;
862         inn_buf.allocate(n*(2*sizeof(int) + sizeof(float)));
863     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
864     float* values_buf = (float*)ext_buf;
865     int* missing_buf = (int*)(values_buf + n);
866     const float* values = 0;
867     const int* missing = 0;
868     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
869     int* responses_buf = missing_buf + n;
870     const int* responses = data->get_class_labels( node, responses_buf );
872     double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
873     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
874     bool is_find_split = false;
// Skip leading missing values to find the first valid sample.
// NOTE(review): the condition reads missing[smpi] before checking smpi < n —
// potential out-of-bounds read when all values are missing; flag upstream.
877     while ( missing[smpi] && (smpi < n) )
// Scan remaining samples for the min/max of the valid values
// (the pmin/pmax updates are missing from this excerpt).
883     for (; smpi < n; smpi++)
885         float ptemp = values[smpi];
886         int m = missing[smpi];
// A split is only possible if the value range is non-degenerate.
893     float fdiff = pmax-pmin;
896         is_find_split = true;
897         CvRNG* rng = &data->rng;
// The ERT core idea: one uniformly random threshold in (pmin, pmax).
898         split_val = pmin + cvRandReal(rng) * fdiff ;
899         if (split_val - pmin <= FLT_EPSILON)
900             split_val = pmin + split_delta;
901         if (pmax - split_val <= FLT_EPSILON)
902             split_val = pmax - split_delta;
904         // calculate Gini index
// Unweighted branch: integer class histograms on each side.
907             int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
908             int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
911             // init arrays of class instance counters on both sides of the split
912             for( i = 0; i < m; i++ )
// Partition samples by the random threshold and tally per-class counts
// (increment statements missing from this excerpt).
917             for( int si = 0; si < n; si++ )
919                 int r = responses[si];
920                 float val = values[si];
923                 if ( val < split_val )
// Gini-style score: sum of squared class counts on each side, combined as
// (sum_l^2 * R + sum_r^2 * L) / (L*R); larger is better.
934             for (int i = 0; i < m; i++)
936                 lbest_val += lc[i]*lc[i];
937                 rbest_val += rc[i]*rc[i];
939             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch: double histograms scaled by class priors.
943             double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
944             double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
947             // init arrays of class instance counters on both sides of the split
948             for( i = 0; i < m; i++ )
953             for( int si = 0; si < n; si++ )
955                 int r = responses[si];
956                 float val = values[si];
// NOTE(review): priors are indexed by sample index si here rather than by
// class r — looks like a latent bug inherited from the original; verify.
958                 double p = priors[si];
960                 if ( val < split_val )
971             for (int i = 0; i < m; i++)
973                 lbest_val += lc[i]*lc[i];
974                 rbest_val += rc[i]*rc[i];
976             best_val = (lbest_val*R + rbest_val*L) / (L*R);
981     CvDTreeSplit* split = 0;
// Package the result only if a usable threshold was found.
984         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
986         split->ord.c = (float)split_val;
// split_point = -1 marks a value-threshold split (no sorted-order index).
987         split->ord.split_point = -1;
989         split->quality = (float)best_val;
// Extremely-randomized split on a categorical variable for classification:
// draws a random non-trivial subset of the variable's observed categories
// (via a shuffled 0/1 mask) and scores the induced partition with a Gini
// index (prior-weighted variant when priors are set). Returns the split or
// 0 when fewer than two categories are present. NOTE(review): several
// counter-update lines are missing from this excerpt.
994 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
997     int ci = data->get_var_type(vi);
998     int n = node->sample_count;
999     int cm = data->get_num_classes();
// vm = number of categories this variable can take.
1000     int vm = data->cat_count->data.i[ci];
1001     double best_val = init_quality;
1002     CvDTreeSplit *split = 0;
1006     cv::AutoBuffer<int> inn_buf;
1008         inn_buf.allocate(2*n);
1009     int* ext_buf = _ext_buf ? (int*)_ext_buf : (int*)inn_buf;
1011     const int* labels = data->get_cat_var_data( node, vi, ext_buf );
1012     const int* responses = data->get_class_labels( node, ext_buf + n );
1014     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1016     // create random class mask
// Mark categories actually present in this node; valid_cidx later remaps
// category index -> dense index into the random mask.
1017     int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1018     for (int i = 0; i < vm; i++)
1022     for (int si = 0; si < n; si++)
// Skip samples whose category is the missing sentinel.
1025         if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1030     int valid_ccount = 0;
1031     for (int i = 0; i < vm; i++)
1032         if (valid_cidx[i] >= 0)
1034             valid_cidx[i] = valid_ccount;
// At least two observed categories are required for a split.
1037     if (valid_ccount > 1)
1039         CvRNG* rng = forest->get_rng();
// Random size of the "left" category subset, in [1, valid_ccount-1].
1040         int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1042         CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1044         memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
// Set the first l_cval_count mask entries to 1, then shuffle the mask.
1045         cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1046         cvSet( &submask, cvScalar(1) );
1047         for (int i = 0; i < valid_ccount; i++)
1050             int i1 = cvRandInt( rng ) % valid_ccount;
1051             int i2 = cvRandInt( rng ) % valid_ccount;
1052             CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1055         split = _split ? _split : data->new_split_cat( 0, -1.0f );
1056         split->var_idx = vi;
// subset is a bitset over categories; clear it before filling below.
1057         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1059         // calculate Gini index
1060         double lbest_val = 0, rbest_val = 0;
// Unweighted branch: integer per-class histograms on each side.
1063             int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
1064             int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
1066             // init arrays of class instance counters on both sides of the split
1067             for(int i = 0; i < cm; i++ )
1072             for( int si = 0; si < n; si++ )
1074                 int r = responses[si];
1075                 int var_class_idx = labels[si];
1076                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1078                 int mask_class_idx = valid_cidx[var_class_idx];
// Mask==1 sends this category left and records it in the split subset.
1079                 if (var_class_mask->data.ptr[mask_class_idx])
1083                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
// Combine: (sum_l^2 * R + sum_r^2 * L) / (L*R); larger is better.
1091             for (int i = 0; i < cm; i++)
1093                 lbest_val += lc[i]*lc[i];
1094                 rbest_val += rc[i]*rc[i];
1096             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch: same structure with prior-scaled double histograms.
1100             double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
1101             double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
1102             double L = 0, R = 0;
1103             // init arrays of class instance counters on both sides of the split
1104             for(int i = 0; i < cm; i++ )
1109             for( int si = 0; si < n; si++ )
1111                 int r = responses[si];
1112                 int var_class_idx = labels[si];
1113                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
// NOTE(review): priors indexed by sample index si, not class r — looks like
// a latent bug inherited from the original; verify against upstream.
1115                 double p = priors[si];
1116                 int mask_class_idx = valid_cidx[var_class_idx];
1118                 if (var_class_mask->data.ptr[mask_class_idx])
1122                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1130             for (int i = 0; i < cm; i++)
1132                 lbest_val += lc[i]*lc[i];
1133                 rbest_val += rc[i]*rc[i];
1135             best_val = (lbest_val*R + rbest_val*L) / (L*R);
1137         split->quality = (float)best_val;
1139         cvReleaseMat(&var_class_mask);
// Extremely-randomized split on an ordered variable for regression: draws
// one random threshold uniformly in [min,max] of the valid values and
// scores it with the variance-reduction criterion
// (lsum^2 * R + rsum^2 * L) / (L*R). Returns the split or 0 when no valid
// threshold exists. NOTE(review): bound-tracking and accumulation lines
// are missing from this excerpt.
1146 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1149     const float epsilon = FLT_EPSILON*2;
// Offset keeping the drawn threshold strictly inside (pmin, pmax).
1150     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1151     int n = node->sample_count;
1152     cv::AutoBuffer<uchar> inn_buf;
1154         inn_buf.allocate(n*(2*sizeof(int) + 2*sizeof(float)));
1155     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
1156     float* values_buf = (float*)ext_buf;
1157     int* missing_buf = (int*)(values_buf + n);
1158     const float* values = 0;
1159     const int* missing = 0;
1160     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1161     float* responses_buf = (float*)(missing_buf + n);
1162     int* sample_indices_buf = (int*)(responses_buf + n);
1163     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1165     double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1168     bool is_find_split = false;
// Skip leading missing values.
// NOTE(review): missing[smpi] is read before the smpi < n bound check —
// potential out-of-bounds read when all values are missing; flag upstream.
1171     while ( missing[smpi] && (smpi < n) )
1176     pmin = values[smpi];
// Scan for min/max of the valid values (update lines missing here).
1178     for (; smpi < n; smpi++)
1180         float ptemp = values[smpi];
1181         int m = missing[smpi];
1188     float fdiff = pmax-pmin;
1189     if (fdiff > epsilon)
1191         is_find_split = true;
1192         CvRNG* rng = &data->rng;
// One uniformly random threshold in (pmin, pmax) — the ERT core idea.
1193         split_val = pmin + cvRandReal(rng) * fdiff ;
1194         if (split_val - pmin <= FLT_EPSILON)
1195             split_val = pmin + split_delta;
1196         if (pmax - split_val <= FLT_EPSILON)
1197             split_val = pmax - split_delta;
// Accumulate left/right response sums and counts
// (the actual accumulation statements are missing from this excerpt).
1199         for (int si = 0; si < n; si++)
1201             float r = responses[si];
1202             float val = values[si];
1203             int m = missing[si];
1205             if (val < split_val)
// Variance-reduction score; larger is better.
1216         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1219     CvDTreeSplit* split = 0;
1222         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1223         split->var_idx = vi;
1224         split->ord.c = (float)split_val;
// split_point = -1 marks a value-threshold split (no sorted-order index).
1225         split->ord.split_point = -1;
1226         split->inversed = 0;
1227         split->quality = (float)best_val;
// Extremely-randomized split on a categorical variable for regression:
// draws a random non-trivial subset of the observed categories (shuffled
// 0/1 mask) and scores the partition with the variance-reduction criterion
// (lsum^2 * R + rsum^2 * L) / (L*R). Returns the split or 0 when fewer
// than two categories are present. NOTE(review): several accumulation
// lines are missing from this excerpt.
1232 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1235     int ci = data->get_var_type(vi);
1236     int n = node->sample_count;
// vm = number of categories this variable can take.
1237     int vm = data->cat_count->data.i[ci];
1238     double best_val = init_quality;
1239     CvDTreeSplit *split = 0;
1240     float lsum = 0, rsum = 0;
1244     int base_size = vm*sizeof(int);
1245     cv::AutoBuffer<uchar> inn_buf(base_size);
1247         inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
1248     uchar* base_buf = (uchar*)inn_buf;
1249     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1250     int* labels_buf = (int*)ext_buf;
1251     const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1252     float* responses_buf = (float*)(labels_buf + n);
1253     int* sample_indices_buf = (int*)(responses_buf + n);
1254     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1256     // create random class mask
// valid_cidx remaps category index -> dense index into the random mask;
// only categories actually present in this node remain valid.
1257     int *valid_cidx = (int*)base_buf;
1258     for (int i = 0; i < vm; i++)
1262     for (int si = 0; si < n; si++)
// Skip samples whose category is the missing sentinel.
1265         if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1270     int valid_ccount = 0;
1271     for (int i = 0; i < vm; i++)
1272         if (valid_cidx[i] >= 0)
1274             valid_cidx[i] = valid_ccount;
// Need at least two observed categories to form a split.
1277     if (valid_ccount > 1)
1279         CvRNG* rng = forest->get_rng();
// Random size of the "left" category subset, in [1, valid_ccount-1].
1280         int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1282         CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1284         memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
// Set the first l_cval_count entries to 1, then shuffle the mask.
1285         cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1286         cvSet( &submask, cvScalar(1) );
1287         for (int i = 0; i < valid_ccount; i++)
1290             int i1 = cvRandInt( rng ) % valid_ccount;
1291             int i2 = cvRandInt( rng ) % valid_ccount;
1292             CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1295         split = _split ? _split : data->new_split_cat( 0, -1.0f);
1296         split->var_idx = vi;
// subset is a bitset over categories; clear it before filling below.
1297         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// Partition samples by the random category mask, accumulating left/right
// response sums (accumulation statements missing from this excerpt).
1300         for( int si = 0; si < n; si++ )
1302             float r = responses[si];
1303             int var_class_idx = labels[si];
1304             if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1306             int mask_class_idx = valid_cidx[var_class_idx];
// Mask==1 sends this category left and records it in the split subset.
1307             if (var_class_mask->data.ptr[mask_class_idx])
1311                 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
// Variance-reduction score; larger is better.
1319         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1321         split->quality = (float)best_val;
1323         cvReleaseMat(&var_class_mask);
// Distribute the samples of `node` between two newly created children once a
// split has been chosen.  Direction flags (data->direction, filled in by
// complete_node_dir) say which child each sample goes to; the packed training
// buffer data->buf stores per-variable sample indices either as 16-bit
// (data->is_buf_16u) or 32-bit integers, and both layouts are handled below.
// NOTE(review): this extract is missing several original lines (braces, the
// nl/nr counting loop body, and the 16u/32s relocation loop bodies), so the
// comments describe only what is visible here.
1330 void CvForestERTree::split_node_data( CvDTreeNode* node )
1332 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1333 char* dir = (char*)data->direction->data.ptr;
1334 CvDTreeNode *left = 0, *right = 0;
// Children share the parent's buffer row; get_child_buf_idx picks the row.
1335 int new_buf_idx = data->get_child_buf_idx( node );
1336 CvMat* buf = data->buf;
// Stack-allocated scratch for relocating n indices (no heap churn per node).
1337 int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));
1339 complete_node_dir(node);
// Count left/right sample totals from the direction flags.
// NOTE(review): the loop body (orig. lines 1342-1347) is elided here.
1341 for( i = nl = nr = 0; i < n; i++ )
1348 bool split_input_data;
// Left child occupies [offset, offset+nl), right child [offset+nl, offset+n).
1349 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1350 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
// Recurse further only if depth allows it and at least one child is big enough.
1352 split_input_data = node->depth + 1 < data->params.max_depth &&
1353 (node->left->sample_count > data->params.min_sample_count ||
1354 node->right->sample_count > data->params.min_sample_count);
// Shared scratch: n floats (values) + n ints (missing flags / labels).
1356 cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)+sizeof(float)));
1357 // split ordered vars
1358 for( vi = 0; vi < data->var_count; vi++ )
// Negative var type marks an ordered variable; skip categorical ones here.
1360 int ci = data->get_var_type(vi);
1361 if (ci >= 0) continue;
1363 int n1 = node->get_num_valid(vi), nr1 = 0;
1364 float* values_buf = (float*)(uchar*)inn_buf;
1365 int* missing_buf = (int*)(values_buf + n);
1366 const float* values = 0;
1367 const int* missing = 0;
1368 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
// Count valid (non-missing) samples sent right; valid-left = n1 - nr1.
1370 for( i = 0; i < n; i++ )
1371 nr1 += (!missing[i] & dir[i]);
1372 left->set_num_valid(vi, n1 - nr1);
1373 right->set_num_valid(vi, nr1);
1375 // split categorical vars, responses and cv_labels using new_idx relocation table
1376 for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1378 int ci = data->get_var_type(vi);
1379 if (ci < 0) continue;
1381 int n1 = node->get_num_valid(vi), nr1 = 0;
1382 const int* src_lbls = data->get_cat_var_data(node, vi, (int*)(uchar*)inn_buf);
// Snapshot source labels before overwriting the parent's region in-place.
1384 for(i = 0; i < n; i++)
1385 temp_buf[i] = src_lbls[i];
// 16-bit buffer layout: 65535 is the "missing" sentinel (see nr1 update below).
1387 if (data->is_buf_16u)
1389 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
1390 ci*scount + left->offset);
1391 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
1392 ci*scount + right->offset);
// NOTE(review): part of this relocation loop (orig. 1398-1410) is elided.
1394 for( i = 0; i < n; i++ )
1397 int idx = temp_buf[i];
1400 *rdst = (unsigned short)idx;
1402 nr1 += (idx != 65535);
1406 *ldst = (unsigned short)idx;
// Valid-sample counters only make sense for real predictor variables,
// not for the appended responses/cv_labels columns.
1411 if( vi < data->var_count )
1413 left->set_num_valid(vi, n1 - nr1);
1414 right->set_num_valid(vi, nr1);
// 32-bit buffer variant of the same relocation.
1419 int *ldst = buf->data.i + left->buf_idx*buf->cols +
1420 ci*scount + left->offset;
1421 int *rdst = buf->data.i + right->buf_idx*buf->cols +
1422 ci*scount + right->offset;
// NOTE(review): the loop body (orig. 1428-1441) is elided here.
1424 for( i = 0; i < n; i++ )
1427 int idx = temp_buf[i];
1442 if( vi < data->var_count )
1444 left->set_num_valid(vi, n1 - nr1);
1445 right->set_num_valid(vi, nr1);
1450 // split sample indices
1451 int *sample_idx_src_buf = (int*)(uchar*)inn_buf;
1452 const int* sample_idx_src = 0;
// Sample indices are only relocated if the children will actually be split.
1453 if (split_input_data)
1455 sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
1457 for(i = 0; i < n; i++)
1458 temp_buf[i] = sample_idx_src[i];
// Sample indices live in the column right after all work variables.
1460 int pos = data->get_work_var_count();
1462 if (data->is_buf_16u)
1464 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
1465 pos*scount + left->offset);
1466 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
1467 pos*scount + right->offset);
// NOTE(review): the loop body (orig. 1470-1486) is elided here.
1469 for (i = 0; i < n; i++)
1472 unsigned short idx = (unsigned short)temp_buf[i];
1487 int* ldst = buf->data.i + left->buf_idx*buf->cols +
1488 pos*scount + left->offset;
1489 int* rdst = buf->data.i + right->buf_idx*buf->cols +
1490 pos*scount + right->offset;
// NOTE(review): the loop body (orig. 1492-1508) is elided here.
1491 for (i = 0; i < n; i++)
1494 int idx = temp_buf[i];
1509 // deallocate the parent node data that is not needed anymore
1510 data->free_node_data(node);
// Default constructor.  NOTE(review): the body is elided in this extract;
// presumably it is empty and all setup happens in the CvRTrees base class —
// confirm against the full source.
1513 CvERTrees::CvERTrees()
// Destructor.  NOTE(review): the body is elided in this extract; presumably
// empty, with resource cleanup delegated to the CvRTrees base destructor —
// confirm against the full source.
1517 CvERTrees::~CvERTrees()
// Train an Extremely Randomized Trees forest on C-API (CvMat) inputs.
// Prepares the per-tree parameters, builds the shared CvERTreeTrainData,
// resolves nactive_vars, builds the active-variable mask, and then grows the
// forest.  Returns the result of grow_forest().
// NOTE(review): the CV_FUNCNAME/__BEGIN__/__END__ scaffolding and the final
// return are elided in this extract.
1521 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1522 const CvMat* _responses, const CvMat* _var_idx,
1523 const CvMat* _sample_idx, const CvMat* _var_type,
1524 const CvMat* _missing_mask, CvRTParams params )
1526 bool result = false;
1528 CV_FUNCNAME("CvERTrees::train");
// Per-tree parameters; pruning (use_1se_rule's companion flag) is forced off
// by passing `false` — forests rely on averaging, not per-tree pruning.
1534 CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1535 params.regression_accuracy, params.use_surrogates, params.max_categories,
1536 params.cv_folds, params.use_1se_rule, false, params.priors );
// ER-trees use their own train-data class (whole sample set per tree).
1538 data = new CvERTreeTrainData();
1539 CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1540 _sample_idx, _var_type, _missing_mask, tree_params, true));
1542 var_count = data->var_count;
// Clamp/derive the number of active variables per node:
// 0 means the conventional default sqrt(var_count); negative is an error.
1543 if( params.nactive_vars > var_count )
1544 params.nactive_vars = var_count;
1545 else if( params.nactive_vars == 0 )
1546 params.nactive_vars = (int)sqrt((double)var_count);
1547 else if( params.nactive_vars < 0 )
1548 CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1550 // Create mask of active variables at the tree nodes
1551 CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1552 if( params.calc_var_importance )
1554 CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));
1555 cvZero(var_importance);
1557 { // initialize active variables mask
// First nactive_vars entries set to 1 (active), the rest zeroed.
1558 CvMat submask1, submask2;
1559 CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
1560 cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1561 cvSet( &submask1, cvScalar(1) );
1562 if( params.nactive_vars < active_var_mask->cols )
1564 cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1565 cvZero( &submask2 );
// Grow the forest until the termination criteria are met.
1569 CV_CALL(result = grow_forest( params.term_crit ));
// Convenience overload: trains from a CvMLData container by delegating to the
// base-class CvRTrees::train(CvMLData*, CvRTParams), which in turn calls the
// virtual pieces overridden by this class.
// NOTE(review): the __BEGIN__/__END__ scaffolding and the `return result;`
// line are elided in this extract.
1578 bool CvERTrees::train( CvMLData* data, CvRTParams params)
1580 bool result = false;
1582 CV_FUNCNAME( "CvERTrees::train" );
1586 CV_CALL( result = CvRTrees::train( data, params) );
// Grow the forest tree by tree until term_crit is satisfied: either the
// maximum tree count (max_iter) is reached or, when the criteria type allows
// it, the out-of-bag (OOB) error drops below epsilon.  Optionally estimates
// per-variable importance by measuring the accuracy loss after randomly
// permuting each variable's values.
// NOTE(review): several original lines are elided in this extract (local
// declarations such as `votes`, `max_loc`, `rng`, `m`, loop braces, and the
// function epilogue), so comments describe only the visible statements.
1593 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1595 bool result = false;
1597 CvMat* sample_idx_for_tree = 0;
1599 CV_FUNCNAME("CvERTrees::grow_forest");
1602 const int max_ntrees = term_crit.max_iter;
1603 const double max_oob_err = term_crit.epsilon;
1605 const int dims = data->var_count;
// Used to normalize regression residuals when scoring correct responses.
1606 float maximal_response = 0;
1608 CvMat* oob_sample_votes = 0;
1609 CvMat* oob_responses = 0;
// Scratch copies of the whole training set, used for OOB evaluation and
// for the permutation-based variable-importance estimate.
1611 float* oob_samples_perm_ptr= 0;
1613 float* samples_ptr = 0;
1614 uchar* missing_ptr = 0;
1615 float* true_resp_ptr = 0;
// OOB bookkeeping is only needed when an OOB-error stop criterion is active
// or when variable importance was requested.
1616 bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1618 // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1619 // oob_num_of_predictions[i] = number of summands
1620 // (number of predictions for the i-th sample)
1621 // initialize these variable to avoid warning C4701
1622 CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1623 CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1625 nsamples = data->sample_count;
1626 nclasses = data->get_num_classes();
1628 if ( is_oob_or_vimportance )
// Classification: one vote counter per (sample, class).
1630 if( data->is_classifier )
1632 CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1633 cvZero(oob_sample_votes);
1637 // oob_responses[0,i] = oob_predictions_sum[i]
1638 // = sum of predicted values for the i-th sample
1639 // oob_responses[1,i] = oob_num_of_predictions[i]
1640 // = number of summands (number of predictions for the i-th sample)
1641 CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1642 cvZero(oob_responses);
1643 cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1644 cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
1647 CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1648 CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1649 CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1650 CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));
// Pull the raw samples, missing-value mask and true responses out of `data`.
1652 CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
1654 double minval, maxval;
1655 CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1656 cvMinMaxLoc( &responses, &minval, &maxval );
// Largest response magnitude; used as the residual normalization scale.
1657 maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
// Allocate the tree array up-front for the maximum possible forest size.
1661 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1662 memset( trees, 0, sizeof(trees[0])*max_ntrees );
// ER-trees use every sample for every tree (no bootstrap): identity index set.
1664 CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
1666 for (int i = 0; i < nsamples; i++)
1667 sample_idx_for_tree->data.i[i] = i;
// ---- main loop: build one tree per iteration -------------------------------
1669 while( ntrees < max_ntrees )
1671 int i, oob_samples_count = 0;
1672 double ncorrect_responses = 0; // used for estimation of variable importance
1673 CvForestTree* tree = 0;
1675 trees[ntrees] = new CvForestERTree();
1676 tree = (CvForestERTree*)trees[ntrees];
1677 CV_CALL(tree->train( data, 0, this ));
1679 if ( is_oob_or_vimportance )
1681 CvMat sample, missing;
1682 // form array of OOB samples indices and get these samples
1683 sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
1684 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
// Walk all samples; the headers advance through the packed sample arrays.
1687 for( i = 0; i < nsamples; i++,
1688 sample.data.fl += dims, missing.data.ptr += dims )
1690 CvDTreeNode* predicted_node = 0;
1692 // predict oob samples
1693 if( !predicted_node )
1694 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1696 if( !data->is_classifier ) //regression
// Accumulate the running OOB mean prediction per sample and its
// squared error against the true response.
1698 double avg_resp, resp = predicted_node->value;
1699 oob_predictions_sum.data.fl[i] += (float)resp;
1700 oob_num_of_predictions.data.fl[i] += 1;
1702 // compute oob error
1703 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1704 avg_resp -= true_resp_ptr[i];
1705 oob_error += avg_resp*avg_resp;
// Gaussian-weighted "correctness" score of the normalized residual.
1706 resp = (resp - true_resp_ptr[i])/maximal_response;
1707 ncorrect_responses += exp( -resp*resp );
1709 else //classification
// Register this tree's vote, then score the majority class so far.
1715 cvGetRow(oob_sample_votes, &votes, i);
1716 votes.data.i[predicted_node->class_idx]++;
1718 // compute oob error
1719 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
// Map the winning class index back to the original class label.
1721 prdct_resp = data->cat_map->data.i[max_loc.x];
1722 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1724 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1726 oob_samples_count++;
1728 if( oob_samples_count > 0 )
1729 oob_error /= (double)oob_samples_count;
1731 // estimate variable importance
1732 if( var_importance && oob_samples_count > 0 )
1736 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1737 for( m = 0; m < dims; m++ )
1739 double ncorrect_responses_permuted = 0;
1740 // randomly permute values of the m-th variable in the oob samples
1741 float* mth_var_ptr = oob_samples_perm_ptr + m;
1743 for( i = 0; i < nsamples; i++ )
// Random transposition of two samples' values of variable m.
1748 i1 = cvRandInt( &rng ) % nsamples;
1749 i2 = cvRandInt( &rng ) % nsamples;
1750 CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1752 // turn values of (m-1)-th variable, that were permuted
1753 // at the previous iteration, untouched
1755 oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1758 // predict "permuted" cases and calculate the number of votes for the
1759 // correct class in the variable-m-permuted oob data
1760 sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1761 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1762 for( i = 0; i < nsamples; i++,
1763 sample.data.fl += dims, missing.data.ptr += dims )
1765 double predct_resp, true_resp;
1767 predct_resp = tree->predict(&sample, &missing, true)->value;
1768 true_resp = true_resp_ptr[i];
1769 if( data->is_classifier )
1770 ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
// Regression counterpart: same Gaussian-weighted residual scoring.
1773 true_resp = (true_resp - predct_resp)/maximal_response;
1774 ncorrect_responses_permuted += exp( -true_resp*true_resp );
// Importance of variable m grows with the accuracy lost by permuting it.
1777 var_importance->data.fl[m] += (float)(ncorrect_responses
1778 - ncorrect_responses_permuted);
// Early stop on OOB error when the criteria are not iteration-count-only.
1783 if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
1786 if( var_importance )
// Clamp negative importances to zero, then L1-normalize to sum to 1.
1788 for ( int vi = 0; vi < var_importance->cols; vi++ )
1789 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1790 var_importance->data.fl[vi] : 0;
1791 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
// Cleanup of all scratch buffers (cvFree/cvReleaseMat are null-safe).
1796 cvFree( &oob_samples_perm_ptr );
1797 cvFree( &samples_ptr );
1798 cvFree( &missing_ptr );
1799 cvFree( &true_resp_ptr );
1801 cvReleaseMat( &sample_idx_for_tree );
1803 cvReleaseMat( &oob_sample_votes );
1804 cvReleaseMat( &oob_responses );
1813 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1814 const Mat& _responses, const Mat& _var_idx,
1815 const Mat& _sample_idx, const Mat& _var_type,
1816 const Mat& _missing_mask, CvRTParams params )
1818 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1819 sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1820 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1821 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1822 mmask.data.ptr ? &mmask : 0, params);