1 /*M///////////////////////////////////////////////////////////////////////////////////////
\r
3 IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
\r
5 By downloading, copying, installing or using the software you agree to this license.
\r
6 If you do not agree to this license, do not download, install,
\r
7 copy or use the software.
\r
10 Intel License Agreement
\r
12 Copyright (C) 2000, Intel Corporation, all rights reserved.
\r
13 Third party copyrights are property of their respective owners.
\r
15 Redistribution and use in source and binary forms, with or without modification,
\r
16 are permitted provided that the following conditions are met:
\r
18 * Redistribution's of source code must retain the above copyright notice,
\r
19 this list of conditions and the following disclaimer.
\r
21 * Redistribution's in binary form must reproduce the above copyright notice,
\r
22 this list of conditions and the following disclaimer in the documentation
\r
23 and/or other materials provided with the distribution.
\r
25 * The name of Intel Corporation may not be used to endorse or promote products
\r
26 derived from this software without specific prior written permission.
\r
28 This software is provided by the copyright holders and contributors "as is" and
\r
29 any express or implied warranties, including, but not limited to, the implied
\r
30 warranties of merchantability and fitness for a particular purpose are disclaimed.
\r
31 In no event shall the Intel Corporation or contributors be liable for any direct,
\r
32 indirect, incidental, special, exemplary, or consequential damages
\r
33 (including, but not limited to, procurement of substitute goods or services;
\r
34 loss of use, data, or profits; or business interruption) however caused
\r
35 and on any theory of liability, whether in contract, strict liability,
\r
36 or tort (including negligence or otherwise) arising in any way out of
\r
37 the use of this software, even if advised of the possibility of such damage.
\r
// NOTE(review): this file is a lossy extraction -- every other line is a
// stray "\r" and the original source line number is fused into each line.
// Code tokens below are intentionally left byte-identical.
// Sentinel for missing/overflowing ordered-variable values: any value with
// fabs(val) >= ord_nan is rejected as "too large" by the code below.
43 static const float ord_nan = FLT_MAX*0.5f;
\r
// Sizing constants for the cvCreateMemStorage pools used for tree nodes,
// splits and temporary per-node data.
44 static const int min_block_size = 1 << 16;
\r
45 static const int block_size_delta = 1 << 10;
\r
// Comparator + qsort instantiation: sorts an array of int* by pointee value.
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
\r
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
\r
// Comparator + qsort instantiation: sorts (ushort* u, int* i) pairs by the
// int key -- used for the 16-bit (is_buf_16u) buffer representation.
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
\r
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
\r
// CvERTreeTrainData::set_data
//
// Validates the training set and converts it into the internal buffer
// representation used by the Extremely Randomized Trees learner:
//   - rejects surrogate splits and incremental data updates (unsupported);
//   - normalizes sample/variable index arrays and the variable-type mask;
//   - picks a 16-bit (is_buf_16u, sample_count < 65536) or 32-bit work
//     buffer, then compacts each categorical variable's labels into dense
//     class indices by sorting (missing values become 65535 / -1);
//   - allocates node/split/cross-validation heaps, the cat_map/cat_count/
//     cat_ofs category tables, and per-thread scratch buffers.
//
// NOTE(review): the extraction dropped brace/blank lines and some
// statements (embedded line numbers jump); code tokens are left verbatim.
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
\r
56 const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
\r
57 const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
\r
58 bool _shared, bool _add_labels, bool _update_data )
\r
60 CvMat* sample_indices = 0;
\r
61 CvMat* var_type0 = 0;
\r
64 CvPair16u32s* pair16u32s_ptr = 0;
\r
65 CvDTreeTrainData* data = 0;
\r
68 unsigned short* udst = 0;
\r
71 CV_FUNCNAME( "CvERTreeTrainData::set_data" );
\r
75 int sample_all = 0, r_type = 0, cv_n;
\r
76 int total_c_count = 0;
\r
77 int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
\r
78 int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
\r
81 const int *sidx = 0, *vidx = 0;
\r
// ERTrees restrictions: no surrogate splits, no incremental updates.
83 if ( _params.use_surrogates )
\r
84 CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
\r
86 if( _update_data && data_root )
\r
88 CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
\r
96 CV_CALL( set_params( _params ));
\r
98 // check parameter types and sizes
\r
99 CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
\r
101 train_data = _train_data;
\r
102 responses = _responses;
\r
103 missing_mask = _missing_mask;
\r
// Derive element strides for row-major vs column-major sample layout.
105 if( _tflag == CV_ROW_SAMPLE )
\r
107 ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
\r
109 if( _missing_mask )
\r
110 ms_step = _missing_mask->step, mv_step = 1;
\r
114 dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
\r
116 if( _missing_mask )
\r
117 mv_step = _missing_mask->step, ms_step = 1;
\r
121 sample_count = sample_all;
\r
122 var_count = var_all;
\r
126 CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
\r
127 sidx = sample_indices->data.i;
\r
128 sample_count = sample_indices->rows + sample_indices->cols - 1;
\r
133 CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
\r
134 vidx = var_idx->data.i;
\r
135 var_count = var_idx->rows + var_idx->cols - 1;
\r
138 if( !CV_IS_MAT(_responses) ||
\r
139 (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
\r
140 CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
\r
141 (_responses->rows != 1 && _responses->cols != 1) ||
\r
142 _responses->rows + _responses->cols - 1 != sample_all )
\r
143 CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
\r
144 "floating-point vector containing as many elements as "
\r
145 "the total number of samples in the training data matrix" );
\r
// Small sample counts let category labels fit into a 16-bit buffer.
147 is_buf_16u = false;
\r
148 if ( sample_count < 65536 )
\r
149 is_buf_16u = true;
\r
152 CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
\r
154 CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
\r
158 ord_var_count = -1;
\r
160 is_classifier = r_type == CV_VAR_CATEGORICAL;
\r
162 // step 0. calc the number of categorical vars
\r
// var_type[vi] >= 0 is a categorical index; negative encodes ordered vars.
163 for( vi = 0; vi < var_count; vi++ )
\r
165 var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
\r
166 cat_var_count++ : ord_var_count--;
\r
169 ord_var_count = ~ord_var_count;
\r
170 cv_n = params.cv_folds;
\r
171 // set the two last elements of var_type array to be able
\r
172 // to locate responses and cross-validation labels using
\r
173 // the corresponding get_* functions.
\r
174 var_type->data.i[var_count] = cat_var_count;
\r
175 var_type->data.i[var_count+1] = cat_var_count+1;
\r
177 // in case of single ordered predictor we need dummy cv_labels
\r
178 // for safe split_node_data() operation
\r
179 have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
\r
181 work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
\r
182 buf_size = (work_var_count + 1)*sample_count;
\r
184 buf_count = shared ? 2 : 1;
\r
// 16-bit vs 32-bit work buffer, plus matching sort-helper arrays.
188 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
\r
189 CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
\r
193 CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
\r
194 CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
\r
197 size = is_classifier ? cat_var_count+1 : cat_var_count;
\r
198 size = !size ? 1 : size;
\r
199 CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
\r
200 CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
\r
202 size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
\r
203 size = !size ? 1 : size;
\r
204 CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
\r
206 // now calculate the maximum size of split,
\r
207 // create memory storage that will keep nodes and splits of the decision tree
\r
208 // allocate root node and the buffer for the whole training data
\r
209 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
\r
210 (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
\r
211 tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
\r
212 tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
\r
213 CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
\r
214 CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
\r
216 nv_size = var_count*sizeof(int);
\r
217 nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
\r
219 temp_block_size = nv_size;
\r
223 if( sample_count < cv_n*MAX(params.min_sample_count,10) )
\r
224 CV_ERROR( CV_StsOutOfRange,
\r
225 "The many folds in cross-validation for such a small dataset" );
\r
227 cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
\r
228 temp_block_size = MAX(temp_block_size, cv_size);
\r
231 temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
\r
232 CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
\r
233 CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
\r
235 CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
\r
237 CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
\r
244 _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
\r
245 if (is_buf_16u && (cat_var_count || is_classifier))
\r
246 _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
\r
248 // transform the training data to convenient representation
\r
// vi == var_count handles the response column itself.
249 for( vi = 0; vi <= var_count; vi++ )
\r
252 const uchar* mask = 0;
\r
253 int m_step = 0, step;
\r
254 const int* idata = 0;
\r
255 const float* fdata = 0;
\r
258 if( vi < var_count ) // analyze i-th input variable
\r
260 int vi0 = vidx ? vidx[vi] : vi;
\r
261 ci = get_var_type(vi);
\r
262 step = ds_step; m_step = ms_step;
\r
263 if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
\r
264 idata = _train_data->data.i + vi0*dv_step;
\r
266 fdata = _train_data->data.fl + vi0*dv_step;
\r
267 if( _missing_mask )
\r
268 mask = _missing_mask->data.ptr + vi0*mv_step;
\r
270 else // analyze _responses
\r
272 ci = cat_var_count;
\r
273 step = CV_IS_MAT_CONT(_responses->type) ?
\r
274 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
\r
275 if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
\r
276 idata = _responses->data.i;
\r
278 fdata = _responses->data.fl;
\r
281 if( (vi < var_count && ci>=0) ||
\r
282 (vi == var_count && is_classifier) ) // process categorical variable or response
\r
284 int c_count, prev_label;
\r
288 udst = (unsigned short*)(buf->data.s + ci*sample_count);
\r
290 idst = buf->data.i + ci*sample_count;
\r
// Gather raw labels; INT_MAX marks a missing/invalid value.
293 for( i = 0; i < sample_count; i++ )
\r
295 int val = INT_MAX, si = sidx ? sidx[i] : i;
\r
296 if( !mask || !mask[si*m_step] )
\r
299 val = idata[si*step];
\r
302 float t = fdata[si*step];
\r
306 sprintf( err, "%d-th value of %d-th (categorical) "
\r
307 "variable is not an integer", i, vi );
\r
308 CV_ERROR( CV_StsBadArg, err );
\r
312 if( val == INT_MAX )
\r
314 sprintf( err, "%d-th value of %d-th (categorical) "
\r
315 "variable is too large", i, vi );
\r
316 CV_ERROR( CV_StsBadArg, err );
\r
323 pair16u32s_ptr[i].u = udst + i;
\r
324 pair16u32s_ptr[i].i = _idst + i;
\r
329 int_ptr[i] = idst + i;
\r
333 c_count = num_valid > 0;
\r
// Sort labels so distinct categories become contiguous runs.
337 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
\r
338 // count the categories
\r
339 for( i = 1; i < num_valid; i++ )
\r
340 if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
\r
345 icvSortIntPtr( int_ptr, sample_count, 0 );
\r
346 // count the categories
\r
347 for( i = 1; i < num_valid; i++ )
\r
348 c_count += *int_ptr[i] != *int_ptr[i-1];
\r
352 max_c_count = MAX( max_c_count, c_count );
\r
353 cat_count->data.i[ci] = c_count;
\r
354 cat_ofs->data.i[ci] = total_c_count;
\r
356 // resize cat_map, if need
\r
357 if( cat_map->cols < total_c_count + c_count )
\r
360 CV_CALL( cat_map = cvCreateMat( 1,
\r
361 MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
\r
362 for( i = 0; i < total_c_count; i++ )
\r
363 cat_map->data.i[i] = tmp_map->data.i[i];
\r
364 cvReleaseMat( &tmp_map );
\r
367 c_map = cat_map->data.i + total_c_count;
\r
368 total_c_count += c_count;
\r
373 // compact the class indices and build the map
\r
374 prev_label = ~*pair16u32s_ptr[0].i;
\r
375 for( i = 0; i < num_valid; i++ )
\r
377 int cur_label = *pair16u32s_ptr[i].i;
\r
378 if( cur_label != prev_label )
\r
379 c_map[++c_count] = prev_label = cur_label;
\r
380 *pair16u32s_ptr[i].u = (unsigned short)c_count;
\r
382 // replace labels for missing values with 65535
\r
383 for( ; i < sample_count; i++ )
\r
384 *pair16u32s_ptr[i].u = 65535;
\r
388 // compact the class indices and build the map
\r
389 prev_label = ~*int_ptr[0];
\r
390 for( i = 0; i < num_valid; i++ )
\r
392 int cur_label = *int_ptr[i];
\r
393 if( cur_label != prev_label )
\r
394 c_map[++c_count] = prev_label = cur_label;
\r
395 *int_ptr[i] = c_count;
\r
397 // replace labels for missing values with -1
\r
398 for( ; i < sample_count; i++ )
\r
402 else if( ci < 0 ) // process ordered variable
\r
404 for( i = 0; i < sample_count; i++ )
\r
406 float val = ord_nan;
\r
407 int si = sidx ? sidx[i] : i;
\r
408 if( !mask || !mask[si*m_step] )
\r
411 val = (float)idata[si*step];
\r
413 val = fdata[si*step];
\r
415 if( fabs(val) >= ord_nan )
\r
417 sprintf( err, "%d-th value of %d-th (ordered) "
\r
418 "variable (=%g) is too large", i, vi, val );
\r
419 CV_ERROR( CV_StsBadArg, err );
\r
425 if( vi < var_count )
\r
426 data_root->set_num_valid(vi, num_valid);
\r
429 // set sample labels
\r
431 udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
\r
433 idst = buf->data.i + get_work_var_count()*sample_count;
\r
435 for (i = 0; i < sample_count; i++)
\r
438 udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
\r
440 idst[i] = sidx ? sidx[i] : i;
\r
// Cross-validation fold labels: assign folds cyclically, then shuffle.
445 unsigned short* udst = 0;
\r
451 udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
\r
452 for( i = vi = 0; i < sample_count; i++ )
\r
454 udst[i] = (unsigned short)vi++;
\r
455 vi &= vi < cv_n ? -1 : 0;
\r
458 for( i = 0; i < sample_count; i++ )
\r
460 int a = cvRandInt(r) % sample_count;
\r
461 int b = cvRandInt(r) % sample_count;
\r
462 unsigned short unsh = (unsigned short)vi;
\r
463 CV_SWAP( udst[a], udst[b], unsh );
\r
468 idst = buf->data.i + (get_work_var_count()-1)*sample_count;
\r
469 for( i = vi = 0; i < sample_count; i++ )
\r
472 vi &= vi < cv_n ? -1 : 0;
\r
475 for( i = 0; i < sample_count; i++ )
\r
477 int a = cvRandInt(r) % sample_count;
\r
478 int b = cvRandInt(r) % sample_count;
\r
479 CV_SWAP( idst[a], idst[b], vi );
\r
485 cat_map->cols = MAX( total_c_count, 1 );
\r
487 max_split_size = cvAlign(sizeof(CvDTreeSplit) +
\r
488 (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
\r
489 CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
\r
491 have_priors = is_classifier && params.priors;
\r
492 if( is_classifier )
\r
494 int m = get_num_classes();
\r
496 CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
\r
497 for( i = 0; i < m; i++ )
\r
499 double val = have_priors ? params.priors[i] : 1.;
\r
501 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
\r
502 priors->data.db[i] = val;
\r
506 // normalize weights
\r
508 cvScale( priors, priors, 1./sum );
\r
510 CV_CALL( priors_mult = cvCloneMat( priors ));
\r
511 CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
\r
514 CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
\r
515 CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
\r
// Per-thread scratch buffers so get_*_buf() calls are reentrant.
518 int maxNumThreads = 1;
\r
520 maxNumThreads = cv::getNumThreads();
\r
522 pred_float_buf.resize(maxNumThreads);
\r
523 pred_int_buf.resize(maxNumThreads);
\r
524 resp_float_buf.resize(maxNumThreads);
\r
525 resp_int_buf.resize(maxNumThreads);
\r
526 cv_lables_buf.resize(maxNumThreads);
\r
527 sample_idx_buf.resize(maxNumThreads);
\r
528 for( int ti = 0; ti < maxNumThreads; ti++ )
\r
530 pred_float_buf[ti].resize(sample_count);
\r
531 pred_int_buf[ti].resize(sample_count);
\r
532 resp_float_buf[ti].resize(sample_count);
\r
533 resp_int_buf[ti].resize(sample_count);
\r
534 cv_lables_buf[ti].resize(sample_count);
\r
535 sample_idx_buf[ti].resize(sample_count);
\r
// Cleanup of temporaries (presumably the __END__ section of CV_FUNCNAME).
548 cvFree( &int_ptr );
\r
549 cvReleaseMat( &var_type0 );
\r
550 cvReleaseMat( &sample_indices );
\r
551 cvReleaseMat( &tmp_map );
\r
// CvERTreeTrainData::get_ord_var_data
//
// Copies the vi-th ordered variable's values (and a 0/non-0 missing flag)
// for every sample of node `n` into the caller-supplied buffers, resolving
// var_idx indirection and the row/column sample layout (tflag).
// On return *ord_values/*missing alias the filled buffers.
554 int CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, const float** ord_values, const int** missing )
\r
556 int vidx = var_idx ? var_idx->data.i[vi] : vi;
\r
557 int node_sample_count = n->sample_count;
\r
558 int* sample_indices_buf = get_sample_idx_buf();
\r
559 const int* sample_indices = 0;
\r
561 get_sample_indices(n, sample_indices_buf, &sample_indices);
\r
563 int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
\r
564 int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
\r
// Row-major: sample index selects the row, variable index the column.
565 if( tflag == CV_ROW_SAMPLE )
\r
567 for( int i = 0; i < node_sample_count; i++ )
\r
569 int idx = sample_indices[i];
\r
570 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
\r
571 ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
\r
// Column-major: variable index selects the row instead.
575 for( int i = 0; i < node_sample_count; i++ )
\r
577 int idx = sample_indices[i];
\r
578 missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
\r
579 ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
\r
581 *ord_values = ord_values_buf;
\r
582 *missing = missing_buf;
\r
583 return 0; //TODO: return the number of non-missing values
\r
// Fetches the per-node sample-index column: it is stored in `buf` after the
// categorical columns, the (optional) response column and the (optional)
// cv-labels column, hence the var_count + is_classifier + have_labels offset.
587 void CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
\r
589 get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf, indices );
\r
// Fetches the cross-validation fold labels column, stored right after the
// categorical columns and the (optional) response column.
593 void CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
\r
596 get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf, labels );
\r
// CvERTreeTrainData::get_cat_var_data
//
// Returns (via *cat_values) the vi-th categorical column for node `n`.
// In 32-bit buffer mode the pointer aliases `buf` directly; in 16-bit mode
// (is_buf_16u -- presumably the branch the short_values path belongs to)
// the values are widened into cat_values_buf first.
600 int CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
\r
602 int ci = get_var_type( vi);
\r
604 *cat_values = buf->data.i + n->buf_idx*buf->cols +
\r
605 ci*sample_count + n->offset;
\r
607 const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
\r
608 ci*sample_count + n->offset);
\r
609 for( int i = 0; i < n->sample_count; i++ )
\r
610 cat_values_buf[i] = short_values[i];
\r
611 *cat_values = cat_values_buf;
\r
614 return 0; //TODO: return the number of non-missing values
\r
// CvERTreeTrainData::get_vectors
//
// Materializes (a subset of) the training samples as dense row vectors:
// fills `values` (count x var_count, row-major), an optional `missing`
// byte mask, and `responses` (class index / mapped class label for
// classification, raw value for regression).
617 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
\r
618 float* values, uchar* missing,
\r
619 float* responses, bool get_class_idx )
\r
621 CvMat* subsample_idx = 0;
\r
622 CvMat* subsample_co = 0;
\r
624 CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
\r
628 int i, vi, total = sample_count, count = total, cur_ofs = 0;
\r
// Optional subsample: build per-sample counts/offsets in subsample_co.
632 if( _subsample_idx )
\r
634 CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
\r
635 sidx = subsample_idx->data.i;
\r
636 CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
\r
637 co = subsample_co->data.i;
\r
638 cvZero( subsample_co );
\r
639 count = subsample_idx->cols + subsample_idx->rows - 1;
\r
640 for( i = 0; i < count; i++ )
\r
642 for( i = 0; i < total; i++ )
\r
644 int count_i = co[i*2];
\r
647 co[i*2+1] = cur_ofs*var_count;
\r
648 cur_ofs += count_i;
\r
654 memset( missing, 1, count*var_count );
\r
656 for( vi = 0; vi < var_count; vi++ )
\r
658 int ci = get_var_type(vi);
\r
659 if( ci >= 0 ) // categorical
\r
661 float* dst = values + vi;
\r
662 uchar* m = missing ? missing + vi : 0;
\r
663 int* src_buf = get_pred_int_buf();
\r
664 const int* src = 0;
\r
665 get_cat_var_data(data_root, vi, src_buf, &src);
\r
667 for( i = 0; i < count; i++, dst += var_count )
\r
669 int idx = sidx ? sidx[i] : i;
\r
670 int val = src[idx];
\r
// Missing categorical label is -1 (32-bit buffer) or 65535 (16-bit buffer).
674 *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
\r
681 float* dst_buf = values + vi;
\r
682 int* m_buf = get_pred_int_buf();
\r
683 const float *dst = 0;
\r
685 get_ord_var_data(data_root, vi, dst_buf, m_buf, &dst, &m);
\r
686 for (int si = 0; si < total; si++)
\r
// NOTE(review): `missing + vi + si` looks inconsistent with the
// `missing + vi + si*var_count` stride used elsewhere -- verify.
687 *(missing + vi + si) = m[si] == 0 ? 0 : 1;
\r
694 if( is_classifier )
\r
696 int* src_buf = get_resp_int_buf();
\r
697 const int* src = 0;
\r
698 get_class_labels(data_root, src_buf, &src);
\r
699 for( i = 0; i < count; i++ )
\r
701 int idx = sidx ? sidx[i] : i;
\r
// Either the compact class index or the original label via cat_map.
702 int val = get_class_idx ? src[idx] :
\r
703 cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
\r
704 responses[i] = (float)val;
\r
709 float *_values_buf = get_resp_float_buf();
\r
710 const float* _values = 0;
\r
711 get_ord_responses(data_root, _values_buf, &_values);
\r
712 for( i = 0; i < count; i++ )
\r
714 int idx = sidx ? sidx[i] : i;
\r
715 responses[i] = _values[idx];
\r
722 cvReleaseMat( &subsample_idx );
\r
723 cvReleaseMat( &subsample_co );
\r
// CvERTreeTrainData::subsample_data
//
// Extra-trees train every tree on the full sample set, so only a null
// _subsample_idx is accepted: the root node is shallow-copied (preserving
// the freshly allocated num_valid/cv_* arrays of the new node) and
// returned; any non-null index array is an error.
726 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
\r
728 CvDTreeNode* root = 0;
\r
730 CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
\r
735 CV_ERROR( CV_StsError, "No training data has been set" );
\r
737 if( !_subsample_idx )
\r
739 // make a copy of the root node
\r
742 root = new_node( 0, 1, 0, 0 );
\r
// Struct copy, then restore the new node's own array pointers (held in
// `temp`, declared on a line missing from this extraction).
744 *root = *data_root;
\r
745 root->num_valid = temp.num_valid;
\r
746 if( root->num_valid )
\r
748 for( i = 0; i < var_count; i++ )
\r
749 root->num_valid[i] = data_root->num_valid[i];
\r
751 root->cv_Tn = temp.cv_Tn;
\r
752 root->cv_node_risk = temp.cv_node_risk;
\r
753 root->cv_node_error = temp.cv_node_error;
\r
756 CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
\r
// CvForestERTree::calc_node_dir
//
// Assigns each sample of `node` to the left/right side of the node's
// split (categorical subset test or ordered threshold), accumulating the
// (optionally prior-weighted) left/right counts L and R, and returns the
// split quality normalized by L + R.  Samples with missing values get
// direction 0.  (L, R and the dir[] assignments are on lines dropped by
// this extraction.)
762 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
\r
764 char* dir = (char*)data->direction->data.ptr;
\r
765 int i, n = node->sample_count, vi = node->split->var_idx;
\r
768 assert( !node->split->inversed );
\r
770 if( data->get_var_type(vi) >= 0 ) // split on categorical var
\r
772 int* labels_buf = data->get_pred_int_buf();
\r
773 const int* labels = 0;
\r
774 const int* subset = node->split->subset;
\r
775 data->get_cat_var_data( node, vi, labels_buf, &labels );
\r
776 if( !data->have_priors )
\r
778 int sum = 0, sum_abs = 0;
\r
// d is +1 (left), -1 (right) or 0 (missing); sum_abs counts non-missing.
780 for( i = 0; i < n; i++ )
\r
782 int idx = labels[i];
\r
783 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
\r
784 CV_DTREE_CAT_DIR(idx,subset) : 0;
\r
785 sum += d; sum_abs += d & 1;
\r
789 R = (sum_abs + sum) >> 1;
\r
790 L = (sum_abs - sum) >> 1;
\r
794 const double* priors = data->priors_mult->data.db;
\r
795 double sum = 0, sum_abs = 0;
\r
796 int *responses_buf = data->get_resp_int_buf();
\r
797 const int* responses;
\r
798 data->get_class_labels(node, responses_buf, &responses);
\r
800 for( i = 0; i < n; i++ )
\r
802 int idx = labels[i];
\r
803 double w = priors[responses[i]];
\r
// NOTE(review): unlike the unweighted branch above, this missing-value
// test does not special-case the 16-bit 65535 sentinel -- verify.
804 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
\r
805 sum += d*w; sum_abs += (d & 1)*w;
\r
809 R = (sum_abs + sum) * 0.5;
\r
810 L = (sum_abs - sum) * 0.5;
\r
813 else // split on ordered var
\r
815 float split_val = node->split->ord.c;
\r
816 float* val_buf = data->get_pred_float_buf();
\r
817 const float* val = 0;
\r
818 int* missing_buf = data->get_pred_int_buf();
\r
819 const int* missing = 0;
\r
820 data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing );
\r
822 if( !data->have_priors )
\r
825 for( i = 0; i < n; i++ )
\r
831 if ( val[i] < split_val)
\r
846 const double* priors = data->priors_mult->data.db;
\r
847 int* responses_buf = data->get_resp_int_buf();
\r
848 const int* responses = 0;
\r
849 data->get_class_labels(node, responses_buf, &responses);
\r
851 for( i = 0; i < n; i++ )
\r
857 double w = priors[responses[i]];
\r
858 if ( val[i] < split_val)
\r
873 node->maxlr = MAX( L, R );
\r
874 return node->split->quality/(L + R);
\r
// CvForestERTree::find_split_ord_class
//
// Extremely-randomized split on an ordered variable for classification:
// finds [pmin, pmax] over non-missing values, draws ONE random threshold
// uniformly in that range (the ERT idea -- no exhaustive threshold scan),
// evaluates it with a Gini-style score (unweighted or prior-weighted),
// and returns a CvDTreeSplit, or 0 when the variable is constant.
877 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
\r
879 const float epsilon = FLT_EPSILON*2;
\r
880 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
\r
882 int n = node->sample_count;
\r
883 int m = data->get_num_classes();
\r
885 float* values_buf = data->get_pred_float_buf();
\r
886 const float* values = 0;
\r
887 int* missing_buf = data->get_pred_int_buf();
\r
888 const int* missing = 0;
\r
889 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
\r
890 int* responses_buf = data->get_resp_int_buf();
\r
891 const int* responses = 0;
\r
892 data->get_class_labels( node, responses_buf, &responses );
\r
894 double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
\r
898 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
\r
900 bool is_find_split = false;
\r
// Skip leading missing values.
// NOTE(review): missing[smpi] is read before the smpi < n bound check;
// the operands should arguably be swapped -- verify against upstream.
904 while ( missing[smpi] && (smpi < n) )
\r
908 pmin = values[smpi];
\r
// Scan the remaining values for the min/max of the non-missing range.
910 for (; smpi < n; smpi++)
\r
912 float ptemp = values[smpi];
\r
913 int m = missing[smpi];
\r
920 float fdiff = pmax-pmin;
\r
921 if (fdiff > epsilon)
\r
923 is_find_split = true;
\r
924 CvRNG* rng = &data->rng;
\r
// Random threshold in (pmin, pmax), nudged off the endpoints.
925 split_val = pmin + cvRandReal(rng) * fdiff ;
\r
926 if (split_val - pmin <= FLT_EPSILON)
\r
927 split_val = pmin + split_delta;
\r
928 if (pmax - split_val <= FLT_EPSILON)
\r
929 split_val = pmax - split_delta;
\r
931 // calculate Gini index
\r
934 int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
\r
935 int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
\r
938 // init arrays of class instance counters on both sides of the split
\r
939 for( i = 0; i < m; i++ )
\r
944 for( int si = 0; si < n; si++ )
\r
946 int r = responses[si];
\r
947 float val = values[si];
\r
948 int m = missing[si];
\r
950 if ( val < split_val )
\r
961 for (int i = 0; i < m; i++)
\r
963 lbest_val += lc[i]*lc[i];
\r
964 rbest_val += rc[i]*rc[i];
\r
966 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
\r
// Prior-weighted variant of the same Gini computation.
970 double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
\r
971 double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
\r
972 double L = 0, R = 0;
\r
974 // init arrays of class instance counters on both sides of the split
\r
975 for( i = 0; i < m; i++ )
\r
980 for( int si = 0; si < n; si++ )
\r
982 int r = responses[si];
\r
983 float val = values[si];
\r
984 int m = missing[si];
\r
// NOTE(review): priors[si] indexes by sample, not by class -- a class
// prior lookup would be priors[r]; verify against upstream OpenCV.
985 double p = priors[si];
\r
987 if ( val < split_val )
\r
998 for (int i = 0; i < m; i++)
\r
1000 lbest_val += lc[i]*lc[i];
\r
1001 rbest_val += rc[i]*rc[i];
\r
1003 best_val = (lbest_val*R + rbest_val*L) / (L*R);
\r
1008 CvDTreeSplit* split = 0;
\r
1009 if( is_find_split )
\r
1011 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
\r
1012 split->var_idx = vi;
1013 split->ord.c = (float)split_val;
1014 split->ord.split_point = -1;
1015 split->inversed = 0;
1016 split->quality = (float)best_val;
// CvForestERTree::find_split_cat_class
//
// Extremely-randomized split on a categorical variable for classification:
// collects the category values actually present at the node, draws a
// random non-trivial subset of them (random size, then a random shuffle of
// the 0/1 mask), scores that single subset with a Gini-style criterion
// (unweighted or prior-weighted) and returns the resulting subset split,
// or 0 when fewer than two categories are present.
1021 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
\r
1023 int ci = data->get_var_type(vi);
\r
1024 int n = node->sample_count;
\r
1025 int cm = data->get_num_classes();
\r
1026 int vm = data->cat_count->data.i[ci];
\r
1027 double best_val = init_quality;
\r
1028 CvDTreeSplit *split = 0;
\r
1032 int* labels_buf = data->get_pred_int_buf();
\r
1033 const int* labels = 0;
\r
1034 data->get_cat_var_data( node, vi, labels_buf, &labels );
\r
1036 int* responses_buf = data->get_resp_int_buf();
\r
1037 const int* responses = 0;
\r
1038 data->get_class_labels( node, responses_buf, &responses );
\r
1040 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
\r
1042 // create random class mask
\r
// valid_cidx[c] maps a raw category value to a compact index over the
// categories actually present at this node (-1 = absent / missing).
1043 int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
\r
1044 for (int i = 0; i < vm; i++)
\r
1046 valid_cidx[i] = -1;
\r
1048 for (int si = 0; si < n; si++)
\r
1050 int c = labels[si];
\r
1051 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
\r
1056 int valid_ccount = 0;
\r
1057 for (int i = 0; i < vm; i++)
\r
1058 if (valid_cidx[i] >= 0)
\r
1060 valid_cidx[i] = valid_ccount;
\r
1063 if (valid_ccount > 1)
\r
1065 CvRNG* rng = forest->get_rng();
\r
// Random size of the "left" subset: 1 .. valid_ccount-1 (never all/none).
1066 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
\r
1068 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
\r
1070 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
\r
1071 cvGetCols( var_class_mask, &submask, 0, l_cval_count );
\r
1072 cvSet( &submask, cvScalar(1) );
\r
// Fisher-Yates-style random swaps to shuffle the 0/1 mask.
1073 for (int i = 0; i < valid_ccount; i++)
\r
1076 int i1 = cvRandInt( rng ) % valid_ccount;
\r
1077 int i2 = cvRandInt( rng ) % valid_ccount;
\r
1078 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
\r
1081 split = _split ? _split : data->new_split_cat( 0, -1.0f );
\r
1082 split->var_idx = vi;
\r
1083 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
\r
1085 // calculate Gini index
\r
1086 double lbest_val = 0, rbest_val = 0;
\r
1089 int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
\r
1090 int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
\r
1092 // init arrays of class instance counters on both sides of the split
\r
1093 for(int i = 0; i < cm; i++ )
\r
1098 for( int si = 0; si < n; si++ )
\r
1100 int r = responses[si];
\r
1101 int var_class_idx = labels[si];
\r
1102 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
\r
1104 int mask_class_idx = valid_cidx[var_class_idx];
\r
// Masked categories go left and are recorded in the split's subset bits.
1105 if (var_class_mask->data.ptr[mask_class_idx])
\r
1109 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
\r
1117 for (int i = 0; i < cm; i++)
\r
1119 lbest_val += lc[i]*lc[i];
\r
1120 rbest_val += rc[i]*rc[i];
\r
1122 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
\r
// Prior-weighted variant of the same Gini computation.
1126 double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
\r
1127 double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
\r
1128 double L = 0, R = 0;
\r
1129 // init arrays of class instance counters on both sides of the split
\r
1130 for(int i = 0; i < cm; i++ )
\r
1135 for( int si = 0; si < n; si++ )
\r
1137 int r = responses[si];
\r
1138 int var_class_idx = labels[si];
\r
1139 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
\r
// NOTE(review): priors[si] indexes by sample, not priors[r] by class --
// verify against upstream OpenCV.
1141 double p = priors[si];
\r
1142 int mask_class_idx = valid_cidx[var_class_idx];
\r
1144 if (var_class_mask->data.ptr[mask_class_idx])
\r
1148 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
\r
1156 for (int i = 0; i < cm; i++)
\r
1158 lbest_val += lc[i]*lc[i];
\r
1159 rbest_val += rc[i]*rc[i];
\r
1161 best_val = (lbest_val*R + rbest_val*L) / (L*R);
\r
1163 split->quality = (float)best_val;
\r
1165 cvReleaseMat(&var_class_mask);
\r
// CvForestERTree::find_split_ord_reg
//
// Extremely-randomized split on an ordered variable for regression:
// computes [pmin, pmax] over non-missing values, draws ONE random
// threshold in that range, accumulates left/right response sums and
// scores the split as (lsum^2*R + rsum^2*L)/(L*R); returns the split or
// 0 when the variable is constant at this node.
1172 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
\r
1174 const float epsilon = FLT_EPSILON*2;
\r
1175 const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
\r
1176 int n = node->sample_count;
\r
1177 float* values_buf = data->get_pred_float_buf();
\r
1178 const float* values = 0;
\r
1179 int* missing_buf = data->get_pred_int_buf();
\r
1180 const int* missing = 0;
\r
1181 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
\r
1182 float* responses_buf = data->get_resp_float_buf();
\r
1183 const float* responses = 0;
\r
1184 data->get_ord_responses( node, responses_buf, &responses );
\r
1186 double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
\r
1189 bool is_find_split = false;
\r
// Skip leading missing values.
// NOTE(review): missing[smpi] is read before the smpi < n bound check;
// the operands should arguably be swapped -- verify against upstream.
1192 while ( missing[smpi] && (smpi < n) )
\r
1197 pmin = values[smpi];
\r
1199 for (; smpi < n; smpi++)
\r
1201 float ptemp = values[smpi];
\r
1202 int m = missing[smpi];
\r
1204 if ( ptemp < pmin)
\r
1206 if ( ptemp > pmax)
\r
1209 float fdiff = pmax-pmin;
\r
1210 if (fdiff > epsilon)
\r
1212 is_find_split = true;
\r
1213 CvRNG* rng = &data->rng;
\r
// Random threshold in (pmin, pmax), nudged off the endpoints.
1214 split_val = pmin + cvRandReal(rng) * fdiff ;
\r
1215 if (split_val - pmin <= FLT_EPSILON)
\r
1216 split_val = pmin + split_delta;
\r
1217 if (pmax - split_val <= FLT_EPSILON)
\r
1218 split_val = pmax - split_delta;
\r
1220 for (int si = 0; si < n; si++)
\r
1222 float r = responses[si];
\r
1223 float val = values[si];
\r
1224 int m = missing[si];
\r
1226 if (val < split_val)
\r
1237 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
\r
1240 CvDTreeSplit* split = 0;
\r
1241 if( is_find_split )
\r
1243 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
\r
1244 split->var_idx = vi;
1245 split->ord.c = (float)split_val;
1246 split->ord.split_point = -1;
1247 split->inversed = 0;
1248 split->quality = (float)best_val;
// NOTE(review): fragmentary listing -- alternate source lines (braces, loop
// bodies) are missing; comments describe only the visible statements.
//
// Finds a split on a categorical variable for a regression node. A random
// subset of the valid categories is sent left (ERT-style random split);
// the subset is encoded as a bitmask in split->subset.
1253 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )

1255 int ci = data->get_var_type(vi);

1256 int n = node->sample_count;

// vm = number of categories declared for this variable.
1257 int vm = data->cat_count->data.i[ci];

1258 double best_val = init_quality;

1259 CvDTreeSplit *split = 0;

1260 float lsum = 0, rsum = 0;

1264 int* labels_buf = data->get_pred_int_buf();

1265 const int* labels = 0;

1266 data->get_cat_var_data( node, vi, labels_buf, &labels );

1268 float* responses_buf = data->get_resp_float_buf();

1269 const float* responses = 0;

1270 data->get_ord_responses( node, responses_buf, &responses );

1272 // create random class mask

// valid_cidx maps category -> compact index among categories actually present;
// -1 marks categories unseen at this node. Stack-allocated, no free needed.
1273 int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));

1274 for (int i = 0; i < vm; i++)

1276 valid_cidx[i] = -1;

1278 for (int si = 0; si < n; si++)

1280 int c = labels[si];

// 65535 is the missing-label sentinel for 16-bit buffers; negative for 32-bit.
1281 if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )

1286 int valid_ccount = 0;

// Compact the present categories into indices 0..valid_ccount-1.
1287 for (int i = 0; i < vm; i++)

1288 if (valid_cidx[i] >= 0)

1290 valid_cidx[i] = valid_ccount;

// A split is only meaningful with at least two distinct categories present.
1293 if (valid_ccount > 1)

1295 CvRNG* rng = forest->get_rng();

// Randomly choose how many categories go to the left child (1..valid_ccount-1).
1296 int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);

1298 CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );

1300 memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));

// Set the first l_cval_count entries to 1, then shuffle the mask by random swaps.
1301 cvGetCols( var_class_mask, &submask, 0, l_cval_count );

1302 cvSet( &submask, cvScalar(1) );

1303 for (int i = 0; i < valid_ccount; i++)

1306 int i1 = cvRandInt( rng ) % valid_ccount;

1307 int i2 = cvRandInt( rng ) % valid_ccount;

1308 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );

1311 split = _split ? _split : data->new_split_cat( 0, -1.0f);

1312 split->var_idx = vi;

// Clear the category bitmask (one bit per possible category).
1313 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));

// Accumulate left/right response sums and record left-going categories
// (the accumulation statements are among the missing lines).
1316 for( int si = 0; si < n; si++ )

1318 float r = responses[si];

1319 int var_class_idx = labels[si];

1320 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )

1322 int mask_class_idx = valid_cidx[var_class_idx];

1323 if (var_class_mask->data.ptr[mask_class_idx])

1327 split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);

// Same regression quality measure as the ordered-variable split.
1335 best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);

1337 split->quality = (float)best_val;

1339 cvReleaseMat(&var_class_mask);
\r
1346 //void CvForestERTree::complete_node_dir( CvDTreeNode* node )
\r
1348 // int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
\r
1349 // int nz = n - node->get_num_valid(node->split->var_idx);
\r
1350 // char* dir = (char*)data->direction->data.ptr;
\r
1352 // // try to complete direction using surrogate splits
\r
1353 // if( nz && data->params.use_surrogates )
\r
1355 // CvDTreeSplit* split = node->split->next;
\r
1356 // for( ; split != 0 && nz; split = split->next )
\r
1358 // int inversed_mask = split->inversed ? -1 : 0;
\r
1359 // vi = split->var_idx;
\r
1361 // if( data->get_var_type(vi) >= 0 ) // split on categorical var
\r
1363 // int* labels_buf = data->pred_int_buf;
\r
1364 // const int* labels = 0;
\r
1365 // data->get_cat_var_data(node, vi, labels_buf, &labels);
\r
1366 // const int* subset = split->subset;
\r
1368 // for( i = 0; i < n; i++ )
\r
1370 // int idx = labels[i];
\r
1371 // if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
\r
1374 // int d = CV_DTREE_CAT_DIR(idx,subset);
\r
1375 // dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
\r
1381 // else // split on ordered var
\r
1383 // float* values_buf = data->pred_float_buf;
\r
1384 // const float* values = 0;
\r
1385 // uchar* missing_buf = (uchar*)data->pred_int_buf;
\r
1386 // const uchar* missing = 0;
\r
1387 // data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
\r
1388 // float split_val = node->split->ord.c;
\r
1390 // for( i = 0; i < n; i++ )
\r
1392 // if( !dir[i] && !missing[i])
\r
1394 // int d = values[i] <= split_val ? -1 : 1;
\r
1395 // dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
\r
1404 // // find the default direction for the rest
\r
1407 // for( i = nr = 0; i < n; i++ )
\r
1408 // nr += dir[i] > 0;
\r
1409 // nl = n - nr - nz;
\r
1410 // d0 = nl > nr ? -1 : nr > nl;
\r
1413 // // make sure that every sample is directed either to the left or to the right
\r
1414 // for( i = 0; i < n; i++ )
\r
1416 // int d = dir[i];
\r
1421 // d = d1, d1 = -d1;
\r
1424 // dir[i] = (char)d; // remap (-1,1) to (0,1)
\r
// NOTE(review): fragmentary listing -- alternate source lines are missing
// (several loop bodies, braces and else-branches are absent). Comments
// describe only the visible statements.
//
// Distributes the node's sample data between the newly created left and right
// children according to the per-sample direction array filled by
// complete_node_dir(), then releases the parent's per-node buffers.
1428 void CvForestERTree::split_node_data( CvDTreeNode* node )

1430 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;

// dir[i] in {0,1}: which child sample i goes to (set by complete_node_dir).
1431 char* dir = (char*)data->direction->data.ptr;

1432 CvDTreeNode *left = 0, *right = 0;

1433 int new_buf_idx = data->get_child_buf_idx( node );

1434 CvMat* buf = data->buf;

// Scratch copy of one variable's column; stack-allocated, no free needed.
1435 int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));

1437 complete_node_dir(node);

// Count left (nl) and right (nr) samples (counting statements not visible).
1439 for( i = nl = nr = 0; i < n; i++ )

1446 bool split_input_data;

// Children share the parent's buffer row: left at node->offset, right just after.
1447 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );

1448 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );

// Only physically relocate the data if at least one child may be split further.
1450 split_input_data = node->depth + 1 < data->params.max_depth &&

1451 (node->left->sample_count > data->params.min_sample_count ||

1452 node->right->sample_count > data->params.min_sample_count);

1454 // split ordered vars

1455 for( vi = 0; vi < data->var_count; vi++ )

1457 int ci = data->get_var_type(vi);

// ci >= 0 means categorical; ordered vars only in this loop.
1458 if (ci >= 0) continue;

1460 int n1 = node->get_num_valid(vi), nr1 = 0;

1462 float* values_buf = data->get_pred_float_buf();

1463 const float* values = 0;

1464 int* missing_buf = data->get_pred_int_buf();

1465 const int* missing = 0;

1466 data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );

// Count valid samples that went right. NOTE(review): uses bitwise '&' on the
// (!missing[i]) boolean and the dir byte -- relies on dir being 0/1; confirm.
1468 for( i = 0; i < n; i++ )

1469 nr1 += (!missing[i] & dir[i]);

1470 left->set_num_valid(vi, n1 - nr1);

1471 right->set_num_valid(vi, nr1);

1473 // split categorical vars, responses and cv_labels using new_idx relocation table

1474 for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )

1476 int ci = data->get_var_type(vi);

1477 if (ci < 0) continue;

1479 int n1 = node->get_num_valid(vi), nr1 = 0;

1481 int *src_lbls_buf = data->get_pred_int_buf();

1482 const int* src_lbls = 0;

1483 data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);

// Snapshot the column before overwriting the shared buffer in place.
1485 for(i = 0; i < n; i++)

1486 temp_buf[i] = src_lbls[i];

// 16-bit buffer layout: labels stored as unsigned short, 65535 = missing.
1488 if (data->is_buf_16u)

1490 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +

1491 ci*scount + left->offset);

1492 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +

1493 ci*scount + right->offset);

1495 for( i = 0; i < n; i++ )

1498 int idx = temp_buf[i];

1501 *rdst = (unsigned short)idx;

1503 nr1 += (idx != 65535);

1507 *ldst = (unsigned short)idx;

// Valid-counts are only tracked for predictor variables, not responses/labels.
1512 if( vi < data->var_count )

1514 left->set_num_valid(vi, n1 - nr1);

1515 right->set_num_valid(vi, nr1);

// 32-bit buffer layout: labels stored as int, negative = missing.
1520 int *ldst = buf->data.i + left->buf_idx*buf->cols +

1521 ci*scount + left->offset;

1522 int *rdst = buf->data.i + right->buf_idx*buf->cols +

1523 ci*scount + right->offset;

1525 for( i = 0; i < n; i++ )

1528 int idx = temp_buf[i];

1533 nr1 += (idx >= 0);

1543 if( vi < data->var_count )

1545 left->set_num_valid(vi, n1 - nr1);

1546 right->set_num_valid(vi, nr1);

1552 // split sample indices

1553 int *sample_idx_src_buf = data->get_sample_idx_buf();

1554 const int* sample_idx_src = 0;

1555 if (split_input_data)

1557 data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);

1559 for(i = 0; i < n; i++)

1560 temp_buf[i] = sample_idx_src[i];

// Sample indices live in the column right after the work variables.
1562 int pos = data->get_work_var_count();

1564 if (data->is_buf_16u)

1566 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +

1567 pos*scount + left->offset);

1568 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +

1569 pos*scount + right->offset);

1571 for (i = 0; i < n; i++)

1574 unsigned short idx = (unsigned short)temp_buf[i];

1589 int* ldst = buf->data.i + left->buf_idx*buf->cols +

1590 pos*scount + left->offset;

1591 int* rdst = buf->data.i + right->buf_idx*buf->cols +

1592 pos*scount + right->offset;

1593 for (i = 0; i < n; i++)

1596 int idx = temp_buf[i];

1611 // deallocate the parent node data that is not needed anymore

1612 data->free_node_data(node);
\r
// Default constructor. Body not visible in this fragment of the listing.
1615 CvERTrees::CvERTrees()
\r
// Destructor. Body not visible in this fragment of the listing; presumably
// cleanup is handled by the CvRTrees base class -- TODO confirm.
1619 CvERTrees::~CvERTrees()
\r
// NOTE(review): fragmentary listing -- __BEGIN__/__END__ macro lines, braces
// and the final return are among the missing lines. Comments describe only
// the visible statements.
//
// Trains an Extremely Randomized Trees forest on raw CvMat inputs: builds the
// shared CvERTreeTrainData, clamps/derives nactive_vars, prepares the active
// variable mask and (optionally) the variable-importance vector, then grows
// the forest. Returns true on success.
1623 bool CvERTrees::train( const CvMat* _train_data, int _tflag,

1624 const CvMat* _responses, const CvMat* _var_idx,

1625 const CvMat* _sample_idx, const CvMat* _var_type,

1626 const CvMat* _missing_mask, CvRTParams params )

1628 bool result = false;

1630 CV_FUNCNAME("CvERTrees::train");

1632 int var_count = 0;

// Per-tree parameters; pruning-related flag is forced to false for ERT.
1636 CvDTreeParams tree_params( params.max_depth, params.min_sample_count,

1637 params.regression_accuracy, params.use_surrogates, params.max_categories,

1638 params.cv_folds, params.use_1se_rule, false, params.priors );

1640 data = new CvERTreeTrainData();

1641 CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,

1642 _sample_idx, _var_type, _missing_mask, tree_params, true));

1644 var_count = data->var_count;

// nactive_vars: clamp to var_count; 0 means "use sqrt(var_count)" (the usual
// random-forest default); negative is rejected.
1645 if( params.nactive_vars > var_count )

1646 params.nactive_vars = var_count;

1647 else if( params.nactive_vars == 0 )

1648 params.nactive_vars = (int)sqrt((double)var_count);

1649 else if( params.nactive_vars < 0 )

1650 CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );

1652 // Create mask of active variables at the tree nodes

1653 CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));

1654 if( params.calc_var_importance )

1656 CV_CALL(var_importance = cvCreateMat( 1, var_count, CV_32FC1 ));

1657 cvZero(var_importance);

1659 { // initialize active variables mask

1660 CvMat submask1, submask2;

// First nactive_vars entries set to 1 (active), the rest zeroed.
1661 cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );

1662 cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );

1663 cvSet( &submask1, cvScalar(1) );

1664 cvZero( &submask2 );

1667 CV_CALL(result = grow_forest( params.term_crit ));
\r
// Convenience overload: trains from a CvMLData container by delegating to
// CvRTrees::train (which re-enters the CvMat overload above through the
// virtual machinery). Braces/return are among the missing listing lines.
1676 bool CvERTrees::train( CvMLData* data, CvRTParams params)

1678 bool result = false;

1680 CV_FUNCNAME( "CvERTrees::train" );

1684 CV_CALL( result = CvRTrees::train( data, params) );
\r
1691 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
\r
1693 bool result = false;
\r
1695 CvMat* sample_idx_for_tree = 0;
\r
1697 CV_FUNCNAME("CvERTrees::grow_forest");
\r
1700 const int max_ntrees = term_crit.max_iter;
\r
1701 const double max_oob_err = term_crit.epsilon;
\r
1703 const int dims = data->var_count;
\r
1704 float maximal_response = 0;
\r
1706 CvMat* oob_sample_votes = 0;
\r
1707 CvMat* oob_responses = 0;
\r
1709 float* oob_samples_perm_ptr= 0;
\r
1711 float* samples_ptr = 0;
\r
1712 uchar* missing_ptr = 0;
\r
1713 float* true_resp_ptr = 0;
\r
1714 bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
\r
1716 // oob_predictions_sum[i] = sum of predicted values for the i-th sample
\r
1717 // oob_num_of_predictions[i] = number of summands
\r
1718 // (number of predictions for the i-th sample)
\r
1719 // initialize these variable to avoid warning C4701
\r
1720 CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
\r
1721 CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
\r
1723 nsamples = data->sample_count;
\r
1724 nclasses = data->get_num_classes();
\r
1726 if ( is_oob_or_vimportance )
\r
1728 if( data->is_classifier )
\r
1730 CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
\r
1731 cvZero(oob_sample_votes);
\r
1735 // oob_responses[0,i] = oob_predictions_sum[i]
\r
1736 // = sum of predicted values for the i-th sample
\r
1737 // oob_responses[1,i] = oob_num_of_predictions[i]
\r
1738 // = number of summands (number of predictions for the i-th sample)
\r
1739 CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
\r
1740 cvZero(oob_responses);
\r
1741 cvGetRow( oob_responses, &oob_predictions_sum, 0 );
\r
1742 cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
\r
1745 CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
\r
1746 CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
\r
1747 CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
\r
1748 CV_CALL(true_resp_ptr = (float*)cvAlloc( sizeof(float)*nsamples ));
\r
1750 CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
\r
1752 double minval, maxval;
\r
1753 CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
\r
1754 cvMinMaxLoc( &responses, &minval, &maxval );
\r
1755 maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
\r
1759 trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
\r
1760 memset( trees, 0, sizeof(trees[0])*max_ntrees );
\r
1762 CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
\r
1764 for (int i = 0; i < nsamples; i++)
\r
1765 sample_idx_for_tree->data.i[i] = i;
\r
1767 while( ntrees < max_ntrees )
\r
1769 int i, oob_samples_count = 0;
\r
1770 double ncorrect_responses = 0; // used for estimation of variable importance
\r
1771 CvForestTree* tree = 0;
\r
1773 trees[ntrees] = new CvForestERTree();
\r
1774 tree = (CvForestERTree*)trees[ntrees];
\r
1775 CV_CALL(tree->train( data, 0, this ));
\r
1777 if ( is_oob_or_vimportance )
\r
1779 CvMat sample, missing;
\r
1780 // form array of OOB samples indices and get these samples
\r
1781 sample = cvMat( 1, dims, CV_32FC1, samples_ptr );
\r
1782 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
\r
1785 for( i = 0; i < nsamples; i++,
\r
1786 sample.data.fl += dims, missing.data.ptr += dims )
\r
1788 CvDTreeNode* predicted_node = 0;
\r
1790 // predict oob samples
\r
1791 if( !predicted_node )
\r
1792 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
\r
1794 if( !data->is_classifier ) //regression
\r
1796 double avg_resp, resp = predicted_node->value;
\r
1797 oob_predictions_sum.data.fl[i] += (float)resp;
\r
1798 oob_num_of_predictions.data.fl[i] += 1;
\r
1800 // compute oob error
\r
1801 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
\r
1802 avg_resp -= true_resp_ptr[i];
\r
1803 oob_error += avg_resp*avg_resp;
\r
1804 resp = (resp - true_resp_ptr[i])/maximal_response;
\r
1805 ncorrect_responses += exp( -resp*resp );
\r
1807 else //classification
\r
1809 double prdct_resp;
\r
1813 cvGetRow(oob_sample_votes, &votes, i);
\r
1814 votes.data.i[predicted_node->class_idx]++;
\r
1816 // compute oob error
\r
1817 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
\r
1819 prdct_resp = data->cat_map->data.i[max_loc.x];
\r
1820 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
\r
1822 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
\r
1824 oob_samples_count++;
\r
1826 if( oob_samples_count > 0 )
\r
1827 oob_error /= (double)oob_samples_count;
\r
1829 // estimate variable importance
\r
1830 if( var_importance && oob_samples_count > 0 )
\r
1834 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
\r
1835 for( m = 0; m < dims; m++ )
\r
1837 double ncorrect_responses_permuted = 0;
\r
1838 // randomly permute values of the m-th variable in the oob samples
\r
1839 float* mth_var_ptr = oob_samples_perm_ptr + m;
\r
1841 for( i = 0; i < nsamples; i++ )
\r
1846 i1 = cvRandInt( &rng ) % nsamples;
\r
1847 i2 = cvRandInt( &rng ) % nsamples;
\r
1848 CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
\r
1850 // turn values of (m-1)-th variable, that were permuted
\r
1851 // at the previous iteration, untouched
\r
1853 oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
\r
1856 // predict "permuted" cases and calculate the number of votes for the
\r
1857 // correct class in the variable-m-permuted oob data
\r
1858 sample = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
\r
1859 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
\r
1860 for( i = 0; i < nsamples; i++,
\r
1861 sample.data.fl += dims, missing.data.ptr += dims )
\r
1863 double predct_resp, true_resp;
\r
1865 predct_resp = tree->predict(&sample, &missing, true)->value;
\r
1866 true_resp = true_resp_ptr[i];
\r
1867 if( data->is_classifier )
\r
1868 ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
\r
1871 true_resp = (true_resp - predct_resp)/maximal_response;
\r
1872 ncorrect_responses_permuted += exp( -true_resp*true_resp );
\r
1875 var_importance->data.fl[m] += (float)(ncorrect_responses
\r
1876 - ncorrect_responses_permuted);
\r
1881 if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
\r
1884 if( var_importance )
\r
1886 for ( int vi = 0; vi < var_importance->cols; vi++ )
\r
1887 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
\r
1888 var_importance->data.fl[vi] : 0;
\r
1889 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
\r
1894 cvFree( &oob_samples_perm_ptr );
\r
1895 cvFree( &samples_ptr );
\r
1896 cvFree( &missing_ptr );
\r
1897 cvFree( &true_resp_ptr );
\r
1899 cvReleaseMat( &sample_idx_for_tree );
\r
1901 cvReleaseMat( &oob_sample_votes );
\r
1902 cvReleaseMat( &oob_responses );
\r