1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 By downloading, copying, installing or using the software you agree to this license.
6 If you do not agree to this license, do not download, install,
7 copy or use the software.
10 Intel License Agreement
12 Copyright (C) 2000, Intel Corporation, all rights reserved.
13 Third party copyrights are property of their respective owners.
15 Redistribution and use in source and binary forms, with or without modification,
16 are permitted provided that the following conditions are met:
18 * Redistribution's of source code must retain the above copyright notice,
19 this list of conditions and the following disclaimer.
21 * Redistribution's in binary form must reproduce the above copyright notice,
22 this list of conditions and the following disclaimer in the documentation
23 and/or other materials provided with the distribution.
25 * The name of Intel Corporation may not be used to endorse or promote products
26 derived from this software without specific prior written permission.
28 This software is provided by the copyright holders and contributors "as is" and
29 any express or implied warranties, including, but not limited to, the implied
30 warranties of merchantability and fitness for a particular purpose are disclaimed.
31 In no event shall the Intel Corporation or contributors be liable for any direct,
32 indirect, incidental, special, exemplary, or consequential damages
33 (including, but not limited to, procurement of substitute goods or services;
34 loss of use, data, or profits; or business interruption) however caused
35 and on any theory of liability, whether in contract, strict liability,
36 or tort (including negligence or otherwise) arising in any way out of
37 the use of this software, even if advised of the possibility of such damage.
// File-scope constants and sorter instantiations for the Extremely
// Randomized Trees (CvERTrees) training-data code below.
// NOTE(review): this dump embeds original line numbers and omits some lines;
// code tokens are kept verbatim.

// Sentinel threshold: an ordered value with fabs(val) >= ord_nan is rejected
// as "too large" during data conversion (see the ordered-variable check below).
47 static const float ord_nan = FLT_MAX*0.5f;
// Minimum block size (64 KB) for the node/split memory storages.
48 static const int min_block_size = 1 << 16;
// Slack (1 KB) added on top of the computed storage block size.
49 static const int block_size_delta = 1 << 10;

// Quicksort over an array of int pointers, ordered by the pointed-to values;
// used to compact categorical labels when the work buffer is 32-bit.
51 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
52 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )

// Quicksort over (unsigned short*, int*) pairs ordered by the int pointee;
// used to compact categorical labels when the 16-bit work buffer is active.
54 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
55 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
// Builds the internal training-data representation for extremely randomized
// trees: validates the inputs, allocates the shared work buffer ("buf"),
// compacts categorical variables/responses into contiguous class indices,
// creates the node/split heaps and the root node, and sizes the per-thread
// scratch buffers.
// Parameters mirror CvDTreeTrainData::set_data; surrogate splits and
// incremental updates are explicitly rejected for ERTrees.
// NOTE(review): this excerpt is missing interleaved source lines (braces,
// some declarations, else-branches); all visible tokens are kept verbatim.
59 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
60 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
61 const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
62 bool _shared, bool _add_labels, bool _update_data )
// Locals released/owned below; zero-initialized so cleanup is always safe.
64 CvMat* sample_indices = 0;
68 CvPair16u32s* pair16u32s_ptr = 0;
69 CvDTreeTrainData* data = 0;
72 unsigned short* udst = 0;
75 CV_FUNCNAME( "CvERTreeTrainData::set_data" );
79 int sample_all = 0, r_type = 0, cv_n;
80 int total_c_count = 0;
81 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
82 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
85 const int *sidx = 0, *vidx = 0;
// ERTrees-specific restrictions (unlike the generic CvDTreeTrainData).
87 if ( _params.use_surrogates )
88 CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
90 if( _update_data && data_root )
92 CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
100 CV_CALL( set_params( _params ));
102 // check parameter types and sizes
103 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
105 train_data = _train_data;
106 responses = _responses;
107 missing_mask = _missing_mask;
// Derive element strides for both layouts: row-per-sample vs column-per-sample.
109 if( _tflag == CV_ROW_SAMPLE )
111 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
114 ms_step = _missing_mask->step, mv_step = 1;
118 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
121 mv_step = _missing_mask->step, ms_step = 1;
125 sample_count = sample_all;
// Optional sample/variable subsetting via index arrays.
130 CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
131 sidx = sample_indices->data.i;
132 sample_count = sample_indices->rows + sample_indices->cols - 1;
137 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
138 vidx = var_idx->data.i;
139 var_count = var_idx->rows + var_idx->cols - 1;
// Responses must be a 1 x N (or N x 1) CV_32SC1/CV_32FC1 vector over ALL samples.
142 if( !CV_IS_MAT(_responses) ||
143 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
144 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
145 (_responses->rows != 1 && _responses->cols != 1) ||
146 _responses->rows + _responses->cols - 1 != sample_all )
147 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
148 "floating-point vector containing as many elements as "
149 "the total number of samples in the training data matrix" );
// Small datasets can use a compact 16-bit work buffer (is_buf_16u set in a gap).
152 if ( sample_count < 65536 )
156 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
158 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
164 is_classifier = r_type == CV_VAR_CATEGORICAL;
166 // step 0. calc the number of categorical vars
// Categorical vars get indices 0,1,2,... ; ordered vars get -1,-2,... (via
// the post-decrement), recovered below with ord_var_count = ~ord_var_count.
167 for( vi = 0; vi < var_count; vi++ )
169 var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
170 cat_var_count++ : ord_var_count--;
173 ord_var_count = ~ord_var_count;
174 cv_n = params.cv_folds;
175 // set the two last elements of var_type array to be able
176 // to locate responses and cross-validation labels using
177 // the corresponding get_* functions.
178 var_type->data.i[var_count] = cat_var_count;
179 var_type->data.i[var_count+1] = cat_var_count+1;
181 // in case of single ordered predictor we need dummy cv_labels
182 // for safe split_node_data() operation
183 have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
// Work buffer layout: one row per categorical var + class labels (classifier)
// + cv labels (optional) + one trailing row of sample indices.
185 work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
186 buf_size = (work_var_count + 1)*sample_count;
188 buf_count = shared ? 2 : 1;
// 16-bit branch: compact buffer plus (u,i) pairs used for label compaction.
192 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
193 CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
// 32-bit branch.
197 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
198 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
// Category bookkeeping matrices (at least 1 column even with no cat vars).
201 size = is_classifier ? cat_var_count+1 : cat_var_count;
202 size = !size ? 1 : size;
203 CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
204 CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
206 size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
207 size = !size ? 1 : size;
208 CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
210 // now calculate the maximum size of split,
211 // create memory storage that will keep nodes and splits of the decision tree
212 // allocate root node and the buffer for the whole training data
213 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
214 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
215 tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
216 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
217 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
218 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
220 nv_size = var_count*sizeof(int);
221 nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
223 temp_block_size = nv_size;
// Sanity check for cross-validation fold count vs dataset size.
// NOTE(review): the message reads oddly ("The many folds...", presumably
// "Too many folds..."); it is a runtime string, left untouched here.
227 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
228 CV_ERROR( CV_StsOutOfRange,
229 "The many folds in cross-validation for such a small dataset" );
231 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
232 temp_block_size = MAX(temp_block_size, cv_size);
235 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
236 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
237 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
239 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
241 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
// Scratch arrays for the 16-bit label-compaction path (_fdst/_idst declared
// in a gap of this dump).
248 _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
249 if (is_buf_16u && (cat_var_count || is_classifier))
250 _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
252 // transform the training data to convenient representation
// vi == var_count is the response "variable"; vi < var_count are predictors.
253 for( vi = 0; vi <= var_count; vi++ )
256 const uchar* mask = 0;
257 int m_step = 0, step;
258 const int* idata = 0;
259 const float* fdata = 0;
262 if( vi < var_count ) // analyze i-th input variable
264 int vi0 = vidx ? vidx[vi] : vi;
265 ci = get_var_type(vi);
266 step = ds_step; m_step = ms_step;
267 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
268 idata = _train_data->data.i + vi0*dv_step;
270 fdata = _train_data->data.fl + vi0*dv_step;
272 mask = _missing_mask->data.ptr + vi0*mv_step;
274 else // analyze _responses
277 step = CV_IS_MAT_CONT(_responses->type) ?
278 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
279 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
280 idata = _responses->data.i;
282 fdata = _responses->data.fl;
285 if( (vi < var_count && ci>=0) ||
286 (vi == var_count && is_classifier) ) // process categorical variable or response
288 int c_count, prev_label;
// Destination row inside the work buffer for this categorical variable.
292 udst = (unsigned short*)(buf->data.s + ci*sample_count);
294 idst = buf->data.i + ci*sample_count;
// Read raw labels; missing values stay at INT_MAX so they sort last.
297 for( i = 0; i < sample_count; i++ )
299 int val = INT_MAX, si = sidx ? sidx[i] : i;
300 if( !mask || !mask[si*m_step] )
303 val = idata[si*step];
306 float t = fdata[si*step];
// Float-typed categorical values must round-trip exactly to an int.
310 sprintf( err, "%d-th value of %d-th (categorical) "
311 "variable is not an integer", i, vi );
312 CV_ERROR( CV_StsBadArg, err );
318 sprintf( err, "%d-th value of %d-th (categorical) "
319 "variable is too large", i, vi );
320 CV_ERROR( CV_StsBadArg, err );
327 pair16u32s_ptr[i].u = udst + i;
328 pair16u32s_ptr[i].i = _idst + i;
333 int_ptr[i] = idst + i;
337 c_count = num_valid > 0;
// Sort labels so duplicates become adjacent, then count distinct categories.
341 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
342 // count the categories
343 for( i = 1; i < num_valid; i++ )
344 if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
349 icvSortIntPtr( int_ptr, sample_count, 0 );
350 // count the categories
351 for( i = 1; i < num_valid; i++ )
352 c_count += *int_ptr[i] != *int_ptr[i-1];
356 max_c_count = MAX( max_c_count, c_count );
357 cat_count->data.i[ci] = c_count;
358 cat_ofs->data.i[ci] = total_c_count;
360 // resize cat_map, if need
// Grow-by-1.5x reallocation that preserves existing entries (tmp_map holds
// the old matrix; its assignment is in a gap of this dump).
361 if( cat_map->cols < total_c_count + c_count )
364 CV_CALL( cat_map = cvCreateMat( 1,
365 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
366 for( i = 0; i < total_c_count; i++ )
367 cat_map->data.i[i] = tmp_map->data.i[i];
368 cvReleaseMat( &tmp_map );
371 c_map = cat_map->data.i + total_c_count;
372 total_c_count += c_count;
377 // compact the class indices and build the map
// 16-bit path: rewrite each label as its dense category index via the
// sorted pairs; c_map records original label per category.
378 prev_label = ~*pair16u32s_ptr[0].i;
379 for( i = 0; i < num_valid; i++ )
381 int cur_label = *pair16u32s_ptr[i].i;
382 if( cur_label != prev_label )
383 c_map[++c_count] = prev_label = cur_label;
384 *pair16u32s_ptr[i].u = (unsigned short)c_count;
386 // replace labels for missing values with 65535
387 for( ; i < sample_count; i++ )
388 *pair16u32s_ptr[i].u = 65535;
392 // compact the class indices and build the map
// 32-bit path: same compaction through int pointers.
393 prev_label = ~*int_ptr[0];
394 for( i = 0; i < num_valid; i++ )
396 int cur_label = *int_ptr[i];
397 if( cur_label != prev_label )
398 c_map[++c_count] = prev_label = cur_label;
399 *int_ptr[i] = c_count;
401 // replace labels for missing values with -1
402 for( ; i < sample_count; i++ )
406 else if( ci < 0 ) // process ordered variable
// Ordered variables are only validated here (values stay in train_data);
// num_valid counts non-missing, in-range entries.
408 for( i = 0; i < sample_count; i++ )
411 int si = sidx ? sidx[i] : i;
412 if( !mask || !mask[si*m_step] )
415 val = (float)idata[si*step];
417 val = fdata[si*step];
419 if( fabs(val) >= ord_nan )
421 sprintf( err, "%d-th value of %d-th (ordered) "
422 "variable (=%g) is too large", i, vi, val );
423 CV_ERROR( CV_StsBadArg, err );
430 data_root->set_num_valid(vi, num_valid);
// Store the sample indices as the last work-variable row of the buffer.
435 udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
437 idst = buf->data.i + get_work_var_count()*sample_count;
439 for (i = 0; i < sample_count; i++)
442 udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
444 idst[i] = sidx ? sidx[i] : i;
// Cross-validation labels: assign fold ids 0..cv_n-1 cyclically, then
// randomly shuffle them (Fisher-Yates-like pairwise swaps).
449 unsigned short* udst = 0;
455 udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
456 for( i = vi = 0; i < sample_count; i++ )
458 udst[i] = (unsigned short)vi++;
// Wrap vi back to 0 once it reaches cv_n (branchless masking idiom).
459 vi &= vi < cv_n ? -1 : 0;
462 for( i = 0; i < sample_count; i++ )
464 int a = cvRandInt(r) % sample_count;
465 int b = cvRandInt(r) % sample_count;
466 unsigned short unsh = (unsigned short)vi;
467 CV_SWAP( udst[a], udst[b], unsh );
// 32-bit equivalent of the fold assignment + shuffle above.
472 idst = buf->data.i + (get_work_var_count()-1)*sample_count;
473 for( i = vi = 0; i < sample_count; i++ )
476 vi &= vi < cv_n ? -1 : 0;
479 for( i = 0; i < sample_count; i++ )
481 int a = cvRandInt(r) % sample_count;
482 int b = cvRandInt(r) % sample_count;
483 CV_SWAP( idst[a], idst[b], vi );
// Shrink cat_map to the categories actually used.
489 cat_map->cols = MAX( total_c_count, 1 );
// Split size now depends on the widest categorical variable, not sample count.
491 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
492 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
493 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
495 have_priors = is_classifier && params.priors;
// Class priors: user-supplied or uniform; normalized to sum to 1.
498 int m = get_num_classes();
500 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
501 for( i = 0; i < m; i++ )
503 double val = have_priors ? params.priors[i] : 1.;
505 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
506 priors->data.db[i] = val;
512 cvScale( priors, priors, 1./sum );
514 CV_CALL( priors_mult = cvCloneMat( priors ));
515 CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
// Per-sample direction byte and split scratch used during node splitting.
518 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
519 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
// Per-thread scratch buffers (OpenMP build uses one buffer per processor).
522 int maxNumThreads = 1;
524 maxNumThreads = omp_get_num_procs();
526 pred_float_buf.resize(maxNumThreads);
527 pred_int_buf.resize(maxNumThreads);
528 resp_float_buf.resize(maxNumThreads);
529 resp_int_buf.resize(maxNumThreads);
530 cv_lables_buf.resize(maxNumThreads);
531 sample_idx_buf.resize(maxNumThreads);
532 for( int ti = 0; ti < maxNumThreads; ti++ )
534 pred_float_buf[ti].resize(sample_count);
535 pred_int_buf[ti].resize(sample_count);
536 resp_float_buf[ti].resize(sample_count);
537 resp_int_buf[ti].resize(sample_count);
538 cv_lables_buf[ti].resize(sample_count);
539 sample_idx_buf[ti].resize(sample_count);
// Cleanup of temporaries (reached on both success and error paths).
553 cvReleaseMat( &var_type0 );
554 cvReleaseMat( &sample_indices );
555 cvReleaseMat( &tmp_map );
558 int CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, const float** ord_values, const int** missing )
560 int vidx = var_idx ? var_idx->data.i[vi] : vi;
561 int node_sample_count = n->sample_count;
562 int* sample_indices_buf = get_sample_idx_buf();
563 const int* sample_indices = 0;
565 get_sample_indices(n, sample_indices_buf, &sample_indices);
567 int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
568 int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
569 if( tflag == CV_ROW_SAMPLE )
571 for( int i = 0; i < node_sample_count; i++ )
573 int idx = sample_indices[i];
574 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
575 ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
579 for( int i = 0; i < node_sample_count; i++ )
581 int idx = sample_indices[i];
582 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
583 ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
585 *ord_values = ord_values_buf;
586 *missing = missing_buf;
587 return 0; //TODO: return the number of non-missing values
591 void CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
593 get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf, indices );
597 void CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
600 get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf, labels );
604 int CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
606 int ci = get_var_type( vi);
608 *cat_values = buf->data.i + n->buf_idx*buf->cols +
609 ci*sample_count + n->offset;
611 const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
612 ci*sample_count + n->offset);
613 for( int i = 0; i < n->sample_count; i++ )
614 cat_values_buf[i] = short_values[i];
615 *cat_values = cat_values_buf;
618 return 0; //TODO: return the number of non-missing values
// Extracts (a subset of) the training samples back out of the internal
// representation into caller-provided dense arrays: values (count x
// var_count, row-major), optional missing flags, and optional responses
// (class indices / original labels / ordered values).
// NOTE(review): this excerpt is missing interleaved source lines (braces,
// declarations such as `const int* src`); visible tokens kept verbatim.
621 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
622 float* values, uchar* missing,
623 float* responses, bool get_class_idx )
625 CvMat* subsample_idx = 0;
626 CvMat* subsample_co = 0;
628 CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
632 int i, vi, total = sample_count, count = total, cur_ofs = 0;
// When a subsample is requested, build per-sample counts/offsets (co) so
// samples can be located after the data were regrouped by node.
638 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
639 sidx = subsample_idx->data.i;
640 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
641 co = subsample_co->data.i;
642 cvZero( subsample_co );
643 count = subsample_idx->cols + subsample_idx->rows - 1;
644 for( i = 0; i < count; i++ )
646 for( i = 0; i < total; i++ )
648 int count_i = co[i*2];
651 co[i*2+1] = cur_ofs*var_count;
// Default all entries to "missing"; filled in per variable below.
658 memset( missing, 1, count*var_count );
660 for( vi = 0; vi < var_count; vi++ )
662 int ci = get_var_type(vi);
663 if( ci >= 0 ) // categorical
665 float* dst = values + vi;
666 uchar* m = missing ? missing + vi : 0;
667 int* src_buf = get_pred_int_buf();
669 get_cat_var_data(data_root, vi, src_buf, &src);
// Copy category indices; a sample is missing when its compact label is the
// sentinel (-1 for 32-bit buffers, 65535 for 16-bit buffers).
671 for( i = 0; i < count; i++, dst += var_count )
673 int idx = sidx ? sidx[i] : i;
678 *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
// Ordered variable: values are written straight into the output row slot.
685 float* dst_buf = values + vi;
686 int* m_buf = get_pred_int_buf();
687 const float *dst = 0;
689 get_ord_var_data(data_root, vi, dst_buf, m_buf, &dst, &m);
// NOTE(review): `missing + vi + si` strides by 1 over samples here, unlike
// the `+= var_count` stride used for categorical vars — looks suspicious;
// verify against the upstream source.
690 for (int si = 0; si < total; si++)
691 *(missing + vi + si) = m[si] == 0 ? 0 : 1;
// Responses: classification returns class indices or original labels
// (through cat_map/cat_ofs), regression returns the float responses.
700 int* src_buf = get_resp_int_buf();
702 get_class_labels(data_root, src_buf, &src);
703 for( i = 0; i < count; i++ )
705 int idx = sidx ? sidx[i] : i;
706 int val = get_class_idx ? src[idx] :
707 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
708 responses[i] = (float)val;
713 float *_values_buf = get_resp_float_buf();
714 const float* _values = 0;
715 get_ord_responses(data_root, _values_buf, &_values);
716 for( i = 0; i < count; i++ )
718 int idx = sidx ? sidx[i] : i;
719 responses[i] = _values[idx];
// Cleanup (reached on both success and error paths).
726 cvReleaseMat( &subsample_idx );
727 cvReleaseMat( &subsample_co );
// Returns the root node for training one tree. Extra-trees do not bootstrap:
// a non-null _subsample_idx is an error; with a null index the data root is
// shallow-copied (num_valid / cv arrays duplicated from the stack temp).
// NOTE(review): missing lines in this dump include the `CvDTreeNode temp`
// declaration and the copy of *data_root into it; tokens kept verbatim.
730 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
732 CvDTreeNode* root = 0;
734 CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
// Guard: set_data must have been called first.
739 CV_ERROR( CV_StsError, "No training data has been set" );
741 if( !_subsample_idx )
743 // make a copy of the root node
746 root = new_node( 0, 1, 0, 0 );
// Re-attach the per-variable valid counts (new_node gave root fresh arrays;
// temp preserves data_root's pointers across the struct copy).
749 root->num_valid = temp.num_valid;
750 if( root->num_valid )
752 for( i = 0; i < var_count; i++ )
753 root->num_valid[i] = data_root->num_valid[i];
755 root->cv_Tn = temp.cv_Tn;
756 root->cv_node_risk = temp.cv_node_risk;
757 root->cv_node_error = temp.cv_node_error;
760 CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
// Computes, for the node's chosen split, the weighted number of samples sent
// left (L) and right (R) — with class priors as weights when available —
// and returns the normalized split quality. The per-sample direction bytes
// live in data->direction (assignments sit on lines missing from this dump).
// NOTE(review): excerpt is missing braces/`else` lines and the L/R
// declarations; visible tokens kept verbatim.
766 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
768 char* dir = (char*)data->direction->data.ptr;
769 int i, n = node->sample_count, vi = node->split->var_idx;
// ERT splits are never stored inverted.
772 assert( !node->split->inversed );
774 if( data->get_var_type(vi) >= 0 ) // split on categorical var
776 int* labels_buf = data->get_pred_int_buf();
777 const int* labels = 0;
778 const int* subset = node->split->subset;
779 data->get_cat_var_data( node, vi, labels_buf, &labels );
780 if( !data->have_priors )
782 int sum = 0, sum_abs = 0;
// d is -1 (left), +1 (right) or 0 (missing: sentinel -1 / 65535 label);
// sum accumulates net direction, sum_abs counts non-missing (d & 1 == 1
// for d == +/-1).
784 for( i = 0; i < n; i++ )
787 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
788 CV_DTREE_CAT_DIR(idx,subset) : 0;
789 sum += d; sum_abs += d & 1;
// Recover left/right counts from (sum, sum_abs).
793 R = (sum_abs + sum) >> 1;
794 L = (sum_abs - sum) >> 1;
// Weighted variant: each sample contributes its class prior.
798 const double* priors = data->priors_mult->data.db;
799 double sum = 0, sum_abs = 0;
800 int *responses_buf = data->get_resp_int_buf();
801 const int* responses;
802 data->get_class_labels(node, responses_buf, &responses);
804 for( i = 0; i < n; i++ )
807 double w = priors[responses[i]];
808 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
809 sum += d*w; sum_abs += (d & 1)*w;
813 R = (sum_abs + sum) * 0.5;
814 L = (sum_abs - sum) * 0.5;
817 else // split on ordered var
819 float split_val = node->split->ord.c;
820 float* val_buf = data->get_pred_float_buf();
821 const float* val = 0;
822 int* missing_buf = data->get_pred_int_buf();
823 const int* missing = 0;
824 data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing );
826 if( !data->have_priors )
// Unweighted: count samples on each side of split_val (bodies partly in
// gaps of this dump).
829 for( i = 0; i < n; i++ )
835 if ( val[i] < split_val)
// Weighted by class priors.
850 const double* priors = data->priors_mult->data.db;
851 int* responses_buf = data->get_resp_int_buf();
852 const int* responses = 0;
853 data->get_class_labels(node, responses_buf, &responses);
855 for( i = 0; i < n; i++ )
861 double w = priors[responses[i]];
862 if ( val[i] < split_val)
// maxlr is used by the surrogate-split machinery of the base class.
877 node->maxlr = MAX( L, R );
878 return node->split->quality/(L + R);
// ERT split search on an ordered variable, classification case. Unlike
// regular decision trees, extremely randomized trees draw ONE random
// threshold uniformly in [pmin, pmax] of the variable's values (line 929)
// and score only that cut with a Gini-style criterion
// (sum of squared class counts, combined as (l*R + r*L)/(L*R)).
// Returns a new/reused CvDTreeSplit, or 0 if no usable split exists.
// NOTE(review): excerpt is missing braces and several declarations
// (i, smpi, pmin, pmax, L, R, lc/rc updates); visible tokens kept verbatim.
881 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
883 const float epsilon = FLT_EPSILON*2;
884 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
886 int n = node->sample_count;
887 int m = data->get_num_classes();
889 float* values_buf = data->get_pred_float_buf();
890 const float* values = 0;
891 int* missing_buf = data->get_pred_int_buf();
892 const int* missing = 0;
893 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
894 int* responses_buf = data->get_resp_int_buf();
895 const int* responses = 0;
896 data->get_class_labels( node, responses_buf, &responses );
898 double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
902 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
904 bool is_find_split = false;
// Skip leading missing values to find the first valid sample.
// NOTE(review): missing[smpi] is read BEFORE the (smpi < n) bound check —
// out-of-bounds read when all values are missing; conditions should be
// swapped.
908 while ( missing[smpi] && (smpi < n) )
// Scan remaining samples for the min/max of the variable (updates to
// pmin/pmax sit on missing lines).
914 for (; smpi < n; smpi++)
916 float ptemp = values[smpi];
// NOTE(review): local `int m` shadows the class count `m` declared above.
917 int m = missing[smpi];
924 float fdiff = pmax-pmin;
// Only split if the value range is non-degenerate; pick a random threshold
// strictly inside (pmin, pmax).
927 is_find_split = true;
928 CvRNG* rng = &data->rng;
929 split_val = pmin + cvRandReal(rng) * fdiff ;
930 if (split_val - pmin <= FLT_EPSILON)
931 split_val = pmin + split_delta;
932 if (pmax - split_val <= FLT_EPSILON)
933 split_val = pmax - split_delta;
935 // calculate Gini index
// Unweighted branch: integer per-class counters on each side.
938 int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
939 int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
942 // init arrays of class instance counters on both sides of the split
943 for( i = 0; i < m; i++ )
948 for( int si = 0; si < n; si++ )
950 int r = responses[si];
951 float val = values[si];
954 if ( val < split_val )
965 for (int i = 0; i < m; i++)
967 lbest_val += lc[i]*lc[i];
968 rbest_val += rc[i]*rc[i];
970 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch: double counters scaled by class priors.
974 double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
975 double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
978 // init arrays of class instance counters on both sides of the split
979 for( i = 0; i < m; i++ )
984 for( int si = 0; si < n; si++ )
986 int r = responses[si];
987 float val = values[si];
// NOTE(review): priors has one entry per CLASS, but is indexed by the
// sample index `si` here — almost certainly should be priors[r]
// (cf. calc_node_dir, which uses priors[responses[i]]). Likely bug.
989 double p = priors[si];
991 if ( val < split_val )
1002 for (int i = 0; i < m; i++)
1004 lbest_val += lc[i]*lc[i];
1005 rbest_val += rc[i]*rc[i];
1007 best_val = (lbest_val*R + rbest_val*L) / (L*R);
// Package the result only when a usable threshold was found.
1012 CvDTreeSplit* split = 0;
1015 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1016 split->var_idx = vi;
1017 split->ord.c = (float)split_val;
1018 split->ord.split_point = -1;
1019 split->inversed = 0;
1020 split->quality = (float)best_val;
// ERT split search on a categorical variable, classification case. A random
// nonempty proper subset of the categories present at the node is sent left
// (mask built at lines 1069-1082) and the resulting 2-way partition is
// scored with the same Gini-style criterion as the ordered case.
// Returns a new/reused CvDTreeSplit, or 0 when fewer than 2 categories occur.
// NOTE(review): excerpt is missing braces and several declarations (L, R in
// the unweighted branch, counter updates, `submask`, `temp`); visible tokens
// kept verbatim.
1025 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1027 int ci = data->get_var_type(vi);
1028 int n = node->sample_count;
1029 int cm = data->get_num_classes();
1030 int vm = data->cat_count->data.i[ci];
1031 double best_val = init_quality;
1032 CvDTreeSplit *split = 0;
1036 int* labels_buf = data->get_pred_int_buf();
1037 const int* labels = 0;
1038 data->get_cat_var_data( node, vi, labels_buf, &labels );
1040 int* responses_buf = data->get_resp_int_buf();
1041 const int* responses = 0;
1042 data->get_class_labels( node, responses_buf, &responses );
1044 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1046 // create random class mask
// valid_cidx maps each of the vm categories to a dense index among those
// actually present at this node (-1 for absent / missing-sentinel labels).
1047 int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1048 for (int i = 0; i < vm; i++)
1052 for (int si = 0; si < n; si++)
// Skip missing labels (sentinel 65535 for 16-bit buffers, <0 for 32-bit).
1055 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1060 int valid_ccount = 0;
1061 for (int i = 0; i < vm; i++)
1062 if (valid_cidx[i] >= 0)
1064 valid_cidx[i] = valid_ccount;
// Need at least 2 present categories to form a split.
1067 if (valid_ccount > 1)
// Choose how many categories go left (1..valid_ccount-1), set that many
// mask bytes, then shuffle the mask with random pairwise swaps.
1069 CvRNG* rng = forest->get_rng();
1070 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1072 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1074 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1075 cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1076 cvSet( &submask, cvScalar(1) );
1077 for (int i = 0; i < valid_ccount; i++)
1080 int i1 = cvRandInt( rng ) % valid_ccount;
1081 int i2 = cvRandInt( rng ) % valid_ccount;
1082 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1085 split = _split ? _split : data->new_split_cat( 0, -1.0f );
1086 split->var_idx = vi;
1087 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1089 // calculate Gini index
1090 double lbest_val = 0, rbest_val = 0;
// Unweighted branch: integer per-class counters.
1093 int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
1094 int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
1096 // init arrays of class instance counters on both sides of the split
1097 for(int i = 0; i < cm; i++ )
1102 for( int si = 0; si < n; si++ )
1104 int r = responses[si];
1105 int var_class_idx = labels[si];
1106 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1108 int mask_class_idx = valid_cidx[var_class_idx];
// Masked categories go left; record them in the split's bitset subset.
1109 if (var_class_mask->data.ptr[mask_class_idx])
1113 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1121 for (int i = 0; i < cm; i++)
1123 lbest_val += lc[i]*lc[i];
1124 rbest_val += rc[i]*rc[i];
1126 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
// Weighted branch: double counters scaled by class priors.
1130 double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
1131 double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
1132 double L = 0, R = 0;
1133 // init arrays of class instance counters on both sides of the split
1134 for(int i = 0; i < cm; i++ )
1139 for( int si = 0; si < n; si++ )
1141 int r = responses[si];
1142 int var_class_idx = labels[si];
1143 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
// NOTE(review): priors is per-CLASS but indexed by sample index `si` —
// likely should be priors[r] (cf. calc_node_dir). Same issue as in
// find_split_ord_class.
1145 double p = priors[si];
1146 int mask_class_idx = valid_cidx[var_class_idx];
1148 if (var_class_mask->data.ptr[mask_class_idx])
1152 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1160 for (int i = 0; i < cm; i++)
1162 lbest_val += lc[i]*lc[i];
1163 rbest_val += rc[i]*rc[i];
1165 best_val = (lbest_val*R + rbest_val*L) / (L*R);
1167 split->quality = (float)best_val;
// Temporary mask matrix released on all paths through here.
1169 cvReleaseMat(&var_class_mask);
// ERT split search on an ordered variable, regression case. Draws one random
// threshold in [pmin, pmax] (line 1218) and scores it with the variance-style
// criterion (lsum^2*R + rsum^2*L)/(L*R), where lsum/rsum are the response
// sums and L/R the sample counts on each side.
// Returns a new/reused CvDTreeSplit, or 0 if the value range is degenerate.
// NOTE(review): excerpt is missing braces and several declarations (smpi,
// pmax, L, R, sum updates); visible tokens kept verbatim.
1176 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1178 const float epsilon = FLT_EPSILON*2;
1179 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1180 int n = node->sample_count;
1181 float* values_buf = data->get_pred_float_buf();
1182 const float* values = 0;
1183 int* missing_buf = data->get_pred_int_buf();
1184 const int* missing = 0;
1185 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
1186 float* responses_buf = data->get_resp_float_buf();
1187 const float* responses = 0;
1188 data->get_ord_responses( node, responses_buf, &responses );
1190 double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1193 bool is_find_split = false;
// Skip leading missing values.
// NOTE(review): missing[smpi] is read BEFORE the (smpi < n) bound check —
// out-of-bounds read when all values are missing; conditions should be
// swapped (same issue as find_split_ord_class).
1196 while ( missing[smpi] && (smpi < n) )
1201 pmin = values[smpi];
// Scan for the min/max of the non-missing values (updates in gaps).
1203 for (; smpi < n; smpi++)
1205 float ptemp = values[smpi];
1206 int m = missing[smpi];
1213 float fdiff = pmax-pmin;
1214 if (fdiff > epsilon)
// Random threshold strictly inside (pmin, pmax).
1216 is_find_split = true;
1217 CvRNG* rng = &data->rng;
1218 split_val = pmin + cvRandReal(rng) * fdiff ;
1219 if (split_val - pmin <= FLT_EPSILON)
1220 split_val = pmin + split_delta;
1221 if (pmax - split_val <= FLT_EPSILON)
1222 split_val = pmax - split_delta;
// Accumulate response sums / counts on both sides (updates in gaps).
1224 for (int si = 0; si < n; si++)
1226 float r = responses[si];
1227 float val = values[si];
1228 int m = missing[si];
1230 if (val < split_val)
1241 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
// Package the result only when a usable threshold was found.
1244 CvDTreeSplit* split = 0;
1247 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1248 split->var_idx = vi;
1249 split->ord.c = (float)split_val;
1250 split->ord.split_point = -1;
1251 split->inversed = 0;
1252 split->quality = (float)best_val;
// ERT split search on a categorical variable, regression case. Builds the
// same random left-subset mask as the classification variant, then scores
// the partition with the regression criterion
// (lsum^2*R + rsum^2*L)/(L*R) over response sums and counts.
// Returns a new/reused CvDTreeSplit, or 0 when fewer than 2 categories occur.
// NOTE(review): excerpt is missing braces and several declarations (L, R,
// `submask`, `temp`, accumulator updates); visible tokens kept verbatim.
1257 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1259 int ci = data->get_var_type(vi);
1260 int n = node->sample_count;
1261 int vm = data->cat_count->data.i[ci];
1262 double best_val = init_quality;
1263 CvDTreeSplit *split = 0;
1264 float lsum = 0, rsum = 0;
1268 int* labels_buf = data->get_pred_int_buf();
1269 const int* labels = 0;
1270 data->get_cat_var_data( node, vi, labels_buf, &labels );
1272 float* responses_buf = data->get_resp_float_buf();
1273 const float* responses = 0;
1274 data->get_ord_responses( node, responses_buf, &responses );
1276 // create random class mask
// Dense re-indexing of the categories actually present at this node.
1277 int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1278 for (int i = 0; i < vm; i++)
1282 for (int si = 0; si < n; si++)
// Skip missing labels (sentinel 65535 for 16-bit buffers, <0 for 32-bit).
1285 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1290 int valid_ccount = 0;
1291 for (int i = 0; i < vm; i++)
1292 if (valid_cidx[i] >= 0)
1294 valid_cidx[i] = valid_ccount;
// Need at least 2 present categories to form a split.
1297 if (valid_ccount > 1)
// Random subset of size 1..valid_ccount-1 goes left; mask then shuffled.
1299 CvRNG* rng = forest->get_rng();
1300 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1302 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1304 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1305 cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1306 cvSet( &submask, cvScalar(1) );
1307 for (int i = 0; i < valid_ccount; i++)
1310 int i1 = cvRandInt( rng ) % valid_ccount;
1311 int i2 = cvRandInt( rng ) % valid_ccount;
1312 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1315 split = _split ? _split : data->new_split_cat( 0, -1.0f);
1316 split->var_idx = vi;
1317 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// Accumulate response sums / counts per side (updates in gaps); masked
// categories are recorded in the split's bitset subset.
1320 for( int si = 0; si < n; si++ )
1322 float r = responses[si];
1323 int var_class_idx = labels[si];
1324 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1326 int mask_class_idx = valid_cidx[var_class_idx];
1327 if (var_class_mask->data.ptr[mask_class_idx])
1331 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1339 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1341 split->quality = (float)best_val;
// Temporary mask matrix released before returning.
1343 cvReleaseMat(&var_class_mask);
1350 //void CvForestERTree::complete_node_dir( CvDTreeNode* node )
1352 // int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
1353 // int nz = n - node->get_num_valid(node->split->var_idx);
1354 // char* dir = (char*)data->direction->data.ptr;
1356 // // try to complete direction using surrogate splits
1357 // if( nz && data->params.use_surrogates )
1359 // CvDTreeSplit* split = node->split->next;
1360 // for( ; split != 0 && nz; split = split->next )
1362 // int inversed_mask = split->inversed ? -1 : 0;
1363 // vi = split->var_idx;
1365 // if( data->get_var_type(vi) >= 0 ) // split on categorical var
1367 // int* labels_buf = data->pred_int_buf;
1368 // const int* labels = 0;
1369 // data->get_cat_var_data(node, vi, labels_buf, &labels);
1370 // const int* subset = split->subset;
1372 // for( i = 0; i < n; i++ )
1374 // int idx = labels[i];
1375 // if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
1378 // int d = CV_DTREE_CAT_DIR(idx,subset);
1379 // dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
1385 // else // split on ordered var
1387 // float* values_buf = data->pred_float_buf;
1388 // const float* values = 0;
1389 // uchar* missing_buf = (uchar*)data->pred_int_buf;
1390 // const uchar* missing = 0;
1391 // data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
1392 // float split_val = node->split->ord.c;
1394 // for( i = 0; i < n; i++ )
1396 // if( !dir[i] && !missing[i])
1398 // int d = values[i] <= split_val ? -1 : 1;
1399 // dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
1408 // // find the default direction for the rest
1411 // for( i = nr = 0; i < n; i++ )
1412 // nr += dir[i] > 0;
1413 // nl = n - nr - nz;
1414 // d0 = nl > nr ? -1 : nr > nl;
1417 // // make sure that every sample is directed either to the left or to the right
1418 // for( i = 0; i < n; i++ )
1425 // d = d1, d1 = -d1;
1428 // dir[i] = (char)d; // remap (-1,1) to (0,1)
1432 void CvForestERTree::split_node_data( CvDTreeNode* node )
// Partitions the samples of 'node' into two child nodes according to the
// per-sample directions produced by complete_node_dir(): creates the left and
// right children, updates their per-variable valid-sample counts, and
// relocates the per-sample buffers (categorical labels, sample indices) into
// each child's region of the shared training buffer.
// NOTE(review): this is an elided listing (the leading number on each line is
// the original file line number, and several interior lines — braces, loop
// bodies — are missing from this view); comments describe only visible code.
1434 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1435 char* dir = (char*)data->direction->data.ptr;
1436 CvDTreeNode *left = 0, *right = 0;
1437 int new_buf_idx = data->get_child_buf_idx( node );
1438 CvMat* buf = data->buf;
// scratch array for one value per sample; stack-allocated, freed on return
1439 int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));
// fill in any missing directions (surrogates / default direction)
1441 complete_node_dir(node);
// count samples going left (nl) and right (nr)
1443 for( i = nl = nr = 0; i < n; i++ )
1450 bool split_input_data;
// children share the parent's buffer: left at node->offset, right just after
1451 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1452 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
// only relocate the per-sample data if at least one child may be split further
1454 split_input_data = node->depth + 1 < data->params.max_depth &&
1455 (node->left->sample_count > data->params.min_sample_count ||
1456 node->right->sample_count > data->params.min_sample_count);
1458 // split ordered vars
1459 for( vi = 0; vi < data->var_count; vi++ )
1461 int ci = data->get_var_type(vi);
// ci >= 0 means categorical; ordered variables only in this loop
1462 if (ci >= 0) continue;
1464 int n1 = node->get_num_valid(vi), nr1 = 0;
1466 float* values_buf = data->get_pred_float_buf();
1467 const float* values = 0;
1468 int* missing_buf = data->get_pred_int_buf();
1469 const int* missing = 0;
1470 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
// nr1 = number of valid samples sent to the right child
1472 for( i = 0; i < n; i++ )
1473 nr1 += (!missing[i] & dir[i]);
1474 left->set_num_valid(vi, n1 - nr1);
1475 right->set_num_valid(vi, nr1);
1477 // split categorical vars, responses and cv_labels using new_idx relocation table
1478 for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1480 int ci = data->get_var_type(vi);
1481 if (ci < 0) continue;
1483 int n1 = node->get_num_valid(vi), nr1 = 0;
1485 int *src_lbls_buf = data->get_pred_int_buf();
1486 const int* src_lbls = 0;
1487 data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);
// snapshot the labels before overwriting the buffer region in place
1489 for(i = 0; i < n; i++)
1490 temp_buf[i] = src_lbls[i];
// 16-bit buffer layout: labels stored as unsigned short, 65535 == missing
1492 if (data->is_buf_16u)
1494 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
1495 ci*scount + left->offset);
1496 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
1497 ci*scount + right->offset);
// scatter each sample's label to the left or right child's region
1499 for( i = 0; i < n; i++ )
1502 int idx = temp_buf[i];
1505 *rdst = (unsigned short)idx;
1507 nr1 += (idx != 65535);
1511 *ldst = (unsigned short)idx;
// valid counts are tracked only for actual variables, not responses/labels
1516 if( vi < data->var_count )
1518 left->set_num_valid(vi, n1 - nr1);
1519 right->set_num_valid(vi, nr1);
// 32-bit buffer layout (same scatter as above, int storage)
1524 int *ldst = buf->data.i + left->buf_idx*buf->cols +
1525 ci*scount + left->offset;
1526 int *rdst = buf->data.i + right->buf_idx*buf->cols +
1527 ci*scount + right->offset;
1529 for( i = 0; i < n; i++ )
1532 int idx = temp_buf[i];
1547 if( vi < data->var_count )
1549 left->set_num_valid(vi, n1 - nr1);
1550 right->set_num_valid(vi, nr1);
1556 // split sample indices
1557 int *sample_idx_src_buf = data->get_sample_idx_buf();
1558 const int* sample_idx_src = 0;
// skipped entirely when neither child will be split further
1559 if (split_input_data)
1561 data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);
1563 for(i = 0; i < n; i++)
1564 temp_buf[i] = sample_idx_src[i];
// sample indices live in the row just past the work variables
1566 int pos = data->get_work_var_count();
1568 if (data->is_buf_16u)
1570 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
1571 pos*scount + left->offset);
1572 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
1573 pos*scount + right->offset);
1575 for (i = 0; i < n; i++)
1578 unsigned short idx = (unsigned short)temp_buf[i];
1593 int* ldst = buf->data.i + left->buf_idx*buf->cols +
1594 pos*scount + left->offset;
1595 int* rdst = buf->data.i + right->buf_idx*buf->cols +
1596 pos*scount + right->offset;
1597 for (i = 0; i < n; i++)
1600 int idx = temp_buf[i];
1615 // deallocate the parent node data that is not needed anymore
1616 data->free_node_data(node);
// Constructor/destructor of the Extremely Randomized Trees ensemble.
// Bodies are elided in this listing; only the signatures are visible here.
1619 CvERTrees::CvERTrees()
1623 CvERTrees::~CvERTrees()
// Trains the ERT forest on raw CvMat data: prepares the shared training data
// (CvERTreeTrainData), normalizes params.nactive_vars, builds the active-
// variable mask, then grows the forest via grow_forest().
// Returns true on success. Raises a CV_ERROR for a negative nactive_vars.
// NOTE(review): elided listing — the leading numbers are original file line
// numbers and some interior lines (e.g. __BEGIN__/__END__) are not visible.
1627 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1628 const CvMat* _responses, const CvMat* _var_idx,
1629 const CvMat* _sample_idx, const CvMat* _var_type,
1630 const CvMat* _missing_mask, CvRTParams params )
1632 bool result = false;
1634 CV_FUNCNAME("CvERTrees::train");
// per-tree parameters; truncate-pruning disabled (the 'false' argument)
1640 CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1641 params.regression_accuracy, params.use_surrogates, params.max_categories,
1642 params.cv_folds, params.use_1se_rule, false, params.priors );
1644 data = new CvERTreeTrainData();
1645 CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1646 _sample_idx, _var_type, _missing_mask, tree_params, true));
1648 var_count = data->var_count;
// clamp/expand nactive_vars: 0 means "use sqrt(var_count)" (standard RF rule)
1649 if( params.nactive_vars > var_count )
1650 params.nactive_vars = var_count;
1651 else if( params.nactive_vars == 0 )
1652 params.nactive_vars = (int)sqrt((double)var_count);
1653 else if( params.nactive_vars < 0 )
1654 CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1656 // Create mask of active variables at the tree nodes
1657 CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1658 if( params.calc_var_importance )
1660 CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));
1661 cvZero(var_importance);
1663 { // initialize active variables mask
1664 CvMat submask1, submask2;
1665 CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
// first nactive_vars entries set to 1 (active), the remainder zeroed
1666 cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1667 cvSet( &submask1, cvScalar(1) );
1668 if( params.nactive_vars < active_var_mask->cols )
1670 cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1671 cvZero( &submask2 );
1675 CV_CALL(result = grow_forest( params.term_crit ));
// Convenience overload: trains from a CvMLData container by delegating to the
// base-class CvRTrees::train (which in turn reaches the CvMat overload above).
// NOTE(review): elided listing — return statement not visible in this view.
1684 bool CvERTrees::train( CvMLData* data, CvRTParams params)
1686 bool result = false;
1688 CV_FUNCNAME( "CvERTrees::train" );
1692 CV_CALL( result = CvRTrees::train( data, params) );
// Grows the forest one tree at a time until term_crit is met: either
// max_iter trees have been built, or (when the criterion includes epsilon)
// the out-of-bag (OOB) error falls below max_oob_err. Optionally estimates
// per-variable importance by permuting each variable across samples and
// measuring the drop in prediction quality.
// NOTE(review): elided listing — the leading numbers are original file line
// numbers; some interior lines (loop braces, else-branches, returns) are not
// visible here, so comments describe only the visible code.
1699 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1701 bool result = false;
1703 CvMat* sample_idx_for_tree = 0;
1705 CV_FUNCNAME("CvERTrees::grow_forest");
1708 const int max_ntrees = term_crit.max_iter;
1709 const double max_oob_err = term_crit.epsilon;
1711 const int dims = data->var_count;
1712 float maximal_response = 0;
1714 CvMat* oob_sample_votes = 0;
1715 CvMat* oob_responses = 0;
1717 float* oob_samples_perm_ptr= 0;
1719 float* samples_ptr = 0;
1720 uchar* missing_ptr = 0;
1721 float* true_resp_ptr = 0;
// OOB bookkeeping is needed if the termination criterion uses the OOB error
// or if variable importance was requested
1722 bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1724 // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1725 // oob_num_of_predictions[i] = number of summands
1726 // (number of predictions for the i-th sample)
1727 // initialize these variables to avoid warning C4701
1728 CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1729 CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1731 nsamples = data->sample_count;
1732 nclasses = data->get_num_classes();
1734 if ( is_oob_or_vimportance )
// classification: accumulate one vote histogram row per sample
1736 if( data->is_classifier )
1738 CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1739 cvZero(oob_sample_votes);
1743 // oob_responses[0,i] = oob_predictions_sum[i]
1744 // = sum of predicted values for the i-th sample
1745 // oob_responses[1,i] = oob_num_of_predictions[i]
1746 // = number of summands (number of predictions for the i-th sample)
1747 CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1748 cvZero(oob_responses);
1749 cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1750 cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
// dense copies of all samples / missing masks / true responses, used for
// OOB prediction and for the permutation-based importance estimate
1753 CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1754 CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1755 CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1756 CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));
1758 CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
1760 double minval, maxval;
1761 CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1762 cvMinMaxLoc( &responses, &minval, &maxval );
// normalization factor for regression residuals in the importance estimate
1763 maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
1767 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1768 memset( trees, 0, sizeof(trees[0])*max_ntrees );
1770 CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
// ERT trains every tree on all samples (identity index set, no bootstrap)
1772 for (int i = 0; i < nsamples; i++)
1773 sample_idx_for_tree->data.i[i] = i;
1775 while( ntrees < max_ntrees )
1777 int i, oob_samples_count = 0;
1778 double ncorrect_responses = 0; // used for estimation of variable importance
1779 CvForestTree* tree = 0;
1781 trees[ntrees] = new CvForestERTree();
1782 tree = (CvForestERTree*)trees[ntrees];
1783 CV_CALL(tree->train( data, 0, this ));
1785 if ( is_oob_or_vimportance )
1787 CvMat sample, missing;
1788 // form array of OOB samples indices and get these samples
1789 sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
1790 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
// walk all samples, advancing the 1-row views through the dense buffers
1793 for( i = 0; i < nsamples; i++,
1794 sample.data.fl += dims, missing.data.ptr += dims )
1796 CvDTreeNode* predicted_node = 0;
1798 // predict oob samples
1799 if( !predicted_node )
1800 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1802 if( !data->is_classifier ) //regression
1804 double avg_resp, resp = predicted_node->value;
1805 oob_predictions_sum.data.fl[i] += (float)resp;
1806 oob_num_of_predictions.data.fl[i] += 1;
1808 // compute oob error
// running mean of all predictions for sample i vs. its true response
1809 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1810 avg_resp -= true_resp_ptr[i];
1811 oob_error += avg_resp*avg_resp;
// Gaussian-kernel score of the normalized residual (importance estimate)
1812 resp = (resp - true_resp_ptr[i])/maximal_response;
1813 ncorrect_responses += exp( -resp*resp );
1815 else //classification
1821 cvGetRow(oob_sample_votes, &votes, i);
1822 votes.data.i[predicted_node->class_idx]++;
1824 // compute oob error
// majority vote so far decides the predicted class label
1825 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
1827 prdct_resp = data->cat_map->data.i[max_loc.x];
1828 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1830 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1832 oob_samples_count++;
1834 if( oob_samples_count > 0 )
1835 oob_error /= (double)oob_samples_count;
1837 // estimate variable importance
1838 if( var_importance && oob_samples_count > 0 )
1842 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1843 for( m = 0; m < dims; m++ )
1845 double ncorrect_responses_permuted = 0;
1846 // randomly permute values of the m-th variable in the oob samples
1847 float* mth_var_ptr = oob_samples_perm_ptr + m;
1849 for( i = 0; i < nsamples; i++ )
1854 i1 = cvRandInt( &rng ) % nsamples;
1855 i2 = cvRandInt( &rng ) % nsamples;
1856 CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1858 // restore the values of the (m-1)-th variable, which were permuted
1859 // in the previous iteration
1861 oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1864 // predict "permuted" cases and calculate the number of votes for the
1865 // correct class in the variable-m-permuted oob data
1866 sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1867 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1868 for( i = 0; i < nsamples; i++,
1869 sample.data.fl += dims, missing.data.ptr += dims )
1871 double predct_resp, true_resp;
1873 predct_resp = tree->predict(&sample, &missing, true)->value;
1874 true_resp = true_resp_ptr[i];
1875 if( data->is_classifier )
1876 ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
1879 true_resp = (true_resp - predct_resp)/maximal_response;
1880 ncorrect_responses_permuted += exp( -true_resp*true_resp );
// importance of variable m = accuracy drop caused by permuting it
1883 var_importance->data.fl[m] += (float)(ncorrect_responses
1884 - ncorrect_responses_permuted);
// early stop when the OOB error target is reached (epsilon-based criterion)
1889 if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
1892 if( var_importance )
// clip negatives to zero, then L1-normalize so importances sum to 1
1894 for ( int vi = 0; vi < var_importance->cols; vi++ )
1895 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1896 var_importance->data.fl[vi] : 0;
1897 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
// cleanup of all temporary buffers and matrices
1902 cvFree( &oob_samples_perm_ptr );
1903 cvFree( &samples_ptr );
1904 cvFree( &missing_ptr );
1905 cvFree( &true_resp_ptr );
1907 cvReleaseMat( &sample_idx_for_tree );
1909 cvReleaseMat( &oob_sample_votes );
1910 cvReleaseMat( &oob_responses );
1919 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1920 const Mat& _responses, const Mat& _var_idx,
1921 const Mat& _sample_idx, const Mat& _var_type,
1922 const Mat& _missing_mask, CvRTParams params )
1924 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1925 sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1926 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1927 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1928 mmask.data.ptr ? &mmask : 0, params);