]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mlertrees.cpp
1c2eb48a37cf20c827f22ace30176735026b7dea
[opencv.git] / opencv / src / ml / mlertrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 \r
3   IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 \r
5   By downloading, copying, installing or using the software you agree to this license.\r
6   If you do not agree to this license, do not download, install,\r
7   copy or use the software.\r
8 \r
9 \r
10                         Intel License Agreement\r
11 \r
12  Copyright (C) 2000, Intel Corporation, all rights reserved.\r
13  Third party copyrights are property of their respective owners.\r
14 \r
15  Redistribution and use in source and binary forms, with or without modification,\r
16  are permitted provided that the following conditions are met:\r
17 \r
18    * Redistribution's of source code must retain the above copyright notice,\r
19      this list of conditions and the following disclaimer.\r
20 \r
21    * Redistribution's in binary form must reproduce the above copyright notice,\r
22      this list of conditions and the following disclaimer in the documentation\r
23      and/or other materials provided with the distribution.\r
24 \r
25    * The name of Intel Corporation may not be used to endorse or promote products\r
26      derived from this software without specific prior written permission.\r
27 \r
28  This software is provided by the copyright holders and contributors "as is" and\r
29  any express or implied warranties, including, but not limited to, the implied\r
30  warranties of merchantability and fitness for a particular purpose are disclaimed.\r
31  In no event shall the Intel Corporation or contributors be liable for any direct,\r
32  indirect, incidental, special, exemplary, or consequential damages\r
33  (including, but not limited to, procurement of substitute goods or services;\r
34  loss of use, data, or profits; or business interruption) however caused\r
35  and on any theory of liability, whether in contract, strict liability,\r
36  or tort (including negligence or otherwise) arising in any way out of\r
37  the use of this software, even if advised of the possibility of such damage.\r
38 \r
39 M*/\r
40 \r
41 #include "_ml.h"\r
42 \r
43 static const float ord_nan = FLT_MAX*0.5f;\r
44 static const int min_block_size = 1 << 16;\r
45 static const int block_size_delta = 1 << 10;\r
46 \r
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))\r
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )\r
49 \r
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))\r
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )\r
52 \r
53 ///\r
54 \r
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
56     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
57     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
58     bool _shared, bool _add_labels, bool _update_data )\r
59 {\r
60     CvMat* sample_indices = 0;\r
61     CvMat* var_type0 = 0;\r
62     CvMat* tmp_map = 0;\r
63     int** int_ptr = 0;\r
64     CvPair16u32s* pair16u32s_ptr = 0;\r
65     CvDTreeTrainData* data = 0;\r
66     float *_fdst = 0;\r
67     int *_idst = 0;\r
68     unsigned short* udst = 0;\r
69     int* idst = 0;\r
70 \r
71     CV_FUNCNAME( "CvERTreeTrainData::set_data" );\r
72 \r
73     __BEGIN__;\r
74     \r
75     int sample_all = 0, r_type = 0, cv_n;\r
76     int total_c_count = 0;\r
77     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
78     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step\r
79     int vi, i, size;\r
80     char err[100];\r
81     const int *sidx = 0, *vidx = 0;\r
82     \r
83     if ( _params.use_surrogates )\r
84         CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");\r
85         \r
86     if( _update_data && data_root )\r
87     {\r
88         CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");\r
89     }\r
90 \r
91     clear();\r
92 \r
93     var_all = 0;\r
94     rng = cvRNG(-1);\r
95 \r
96     CV_CALL( set_params( _params ));\r
97 \r
98     // check parameter types and sizes\r
99     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));\r
100 \r
101     train_data = _train_data;\r
102     responses = _responses;\r
103     missing_mask = _missing_mask;\r
104 \r
105     if( _tflag == CV_ROW_SAMPLE )\r
106     {\r
107         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
108         dv_step = 1;\r
109         if( _missing_mask )\r
110             ms_step = _missing_mask->step, mv_step = 1;\r
111     }\r
112     else\r
113     {\r
114         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
115         ds_step = 1;\r
116         if( _missing_mask )\r
117             mv_step = _missing_mask->step, ms_step = 1;\r
118     }\r
119     tflag = _tflag;\r
120 \r
121     sample_count = sample_all;\r
122     var_count = var_all;\r
123 \r
124     if( _sample_idx )\r
125     {\r
126         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));\r
127         sidx = sample_indices->data.i;\r
128         sample_count = sample_indices->rows + sample_indices->cols - 1;\r
129     }\r
130 \r
131     if( _var_idx )\r
132     {\r
133         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));\r
134         vidx = var_idx->data.i;\r
135         var_count = var_idx->rows + var_idx->cols - 1;\r
136     }\r
137 \r
138     if( !CV_IS_MAT(_responses) ||\r
139         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
140          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
141         (_responses->rows != 1 && _responses->cols != 1) ||\r
142         _responses->rows + _responses->cols - 1 != sample_all )\r
143         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
144                   "floating-point vector containing as many elements as "\r
145                   "the total number of samples in the training data matrix" );\r
146    \r
147     is_buf_16u = false;\r
148     if ( sample_count < 65536 )  \r
149         is_buf_16u = true;                                \r
150     \r
151   \r
152     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));\r
153 \r
154     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));\r
155    \r
156     \r
157     cat_var_count = 0;\r
158     ord_var_count = -1;\r
159 \r
160     is_classifier = r_type == CV_VAR_CATEGORICAL;\r
161 \r
162     // step 0. calc the number of categorical vars\r
163     for( vi = 0; vi < var_count; vi++ )\r
164     {\r
165         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?\r
166             cat_var_count++ : ord_var_count--;\r
167     }\r
168 \r
169     ord_var_count = ~ord_var_count;\r
170     cv_n = params.cv_folds;\r
171     // set the two last elements of var_type array to be able\r
172     // to locate responses and cross-validation labels using\r
173     // the corresponding get_* functions.\r
174     var_type->data.i[var_count] = cat_var_count;\r
175     var_type->data.i[var_count+1] = cat_var_count+1;\r
176 \r
177     // in case of single ordered predictor we need dummy cv_labels\r
178     // for safe split_node_data() operation\r
179     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;\r
180 \r
181     work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);\r
182     buf_size = (work_var_count + 1)*sample_count;\r
183     shared = _shared;\r
184     buf_count = shared ? 2 : 1;\r
185     \r
186     if ( is_buf_16u )\r
187     {\r
188         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));\r
189         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));\r
190     }\r
191     else\r
192     {\r
193         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));\r
194         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));\r
195     }    \r
196 \r
197     size = is_classifier ? cat_var_count+1 : cat_var_count;\r
198     size = !size ? 1 : size;\r
199     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));\r
200     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));\r
201     \r
202     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;\r
203     size = !size ? 1 : size;\r
204     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));\r
205 \r
206     // now calculate the maximum size of split,\r
207     // create memory storage that will keep nodes and splits of the decision tree\r
208     // allocate root node and the buffer for the whole training data\r
209     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
210         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));\r
211     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
212     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
213     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
214     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));\r
215 \r
216     nv_size = var_count*sizeof(int);\r
217     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));\r
218 \r
219     temp_block_size = nv_size;\r
220 \r
221     if( cv_n )\r
222     {\r
223         if( sample_count < cv_n*MAX(params.min_sample_count,10) )\r
224             CV_ERROR( CV_StsOutOfRange,\r
225                 "The many folds in cross-validation for such a small dataset" );\r
226 \r
227         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );\r
228         temp_block_size = MAX(temp_block_size, cv_size);\r
229     }\r
230 \r
231     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );\r
232     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));\r
233     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));\r
234     if( cv_size )\r
235         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));\r
236 \r
237     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));\r
238 \r
239     max_c_count = 1;\r
240 \r
241     _fdst = 0;\r
242     _idst = 0;\r
243     if (ord_var_count)\r
244         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));\r
245     if (is_buf_16u && (cat_var_count || is_classifier))\r
246         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));\r
247 \r
248     // transform the training data to convenient representation\r
249     for( vi = 0; vi <= var_count; vi++ )\r
250     {\r
251         int ci;\r
252         const uchar* mask = 0;\r
253         int m_step = 0, step;\r
254         const int* idata = 0;\r
255         const float* fdata = 0;\r
256         int num_valid = 0;\r
257 \r
258         if( vi < var_count ) // analyze i-th input variable\r
259         {\r
260             int vi0 = vidx ? vidx[vi] : vi;\r
261             ci = get_var_type(vi);\r
262             step = ds_step; m_step = ms_step;\r
263             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
264                 idata = _train_data->data.i + vi0*dv_step;\r
265             else\r
266                 fdata = _train_data->data.fl + vi0*dv_step;\r
267             if( _missing_mask )\r
268                 mask = _missing_mask->data.ptr + vi0*mv_step;\r
269         }\r
270         else // analyze _responses\r
271         {\r
272             ci = cat_var_count;\r
273             step = CV_IS_MAT_CONT(_responses->type) ?\r
274                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);\r
275             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )\r
276                 idata = _responses->data.i;\r
277             else\r
278                 fdata = _responses->data.fl;\r
279         }\r
280 \r
281         if( (vi < var_count && ci>=0) ||\r
282             (vi == var_count && is_classifier) ) // process categorical variable or response\r
283         {\r
284             int c_count, prev_label;\r
285             int* c_map;\r
286             \r
287             if (is_buf_16u)\r
288                 udst = (unsigned short*)(buf->data.s + ci*sample_count);\r
289             else\r
290                 idst = buf->data.i + ci*sample_count;\r
291             \r
292             // copy data\r
293             for( i = 0; i < sample_count; i++ )\r
294             {\r
295                 int val = INT_MAX, si = sidx ? sidx[i] : i;\r
296                 if( !mask || !mask[si*m_step] )\r
297                 {\r
298                     if( idata )\r
299                         val = idata[si*step];\r
300                     else\r
301                     {\r
302                         float t = fdata[si*step];\r
303                         val = cvRound(t);\r
304                         if( val != t )\r
305                         {\r
306                             sprintf( err, "%d-th value of %d-th (categorical) "\r
307                                 "variable is not an integer", i, vi );\r
308                             CV_ERROR( CV_StsBadArg, err );\r
309                         }\r
310                     }\r
311 \r
312                     if( val == INT_MAX )\r
313                     {\r
314                         sprintf( err, "%d-th value of %d-th (categorical) "\r
315                             "variable is too large", i, vi );\r
316                         CV_ERROR( CV_StsBadArg, err );\r
317                     }\r
318                     num_valid++;\r
319                 }\r
320                 if (is_buf_16u)\r
321                 {\r
322                     _idst[i] = val;\r
323                     pair16u32s_ptr[i].u = udst + i;\r
324                     pair16u32s_ptr[i].i = _idst + i;\r
325                 }   \r
326                 else\r
327                 {\r
328                     idst[i] = val;\r
329                     int_ptr[i] = idst + i;\r
330                 }\r
331             }\r
332 \r
333             c_count = num_valid > 0;\r
334 \r
335             if (is_buf_16u)\r
336             {\r
337                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );\r
338                 // count the categories\r
339                 for( i = 1; i < num_valid; i++ )\r
340                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)\r
341                         c_count ++ ;\r
342             }\r
343             else\r
344             {\r
345                 icvSortIntPtr( int_ptr, sample_count, 0 );\r
346                 // count the categories\r
347                 for( i = 1; i < num_valid; i++ )\r
348                     c_count += *int_ptr[i] != *int_ptr[i-1];\r
349             }\r
350 \r
351             if( vi > 0 )\r
352                 max_c_count = MAX( max_c_count, c_count );\r
353             cat_count->data.i[ci] = c_count;\r
354             cat_ofs->data.i[ci] = total_c_count;\r
355 \r
356             // resize cat_map, if need\r
357             if( cat_map->cols < total_c_count + c_count )\r
358             {\r
359                 tmp_map = cat_map;\r
360                 CV_CALL( cat_map = cvCreateMat( 1,\r
361                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));\r
362                 for( i = 0; i < total_c_count; i++ )\r
363                     cat_map->data.i[i] = tmp_map->data.i[i];\r
364                 cvReleaseMat( &tmp_map );\r
365             }\r
366 \r
367             c_map = cat_map->data.i + total_c_count;\r
368             total_c_count += c_count;\r
369 \r
370             c_count = -1;\r
371             if (is_buf_16u)\r
372             {\r
373                 // compact the class indices and build the map\r
374                 prev_label = ~*pair16u32s_ptr[0].i;\r
375                 for( i = 0; i < num_valid; i++ )\r
376                 {\r
377                     int cur_label = *pair16u32s_ptr[i].i;\r
378                     if( cur_label != prev_label )\r
379                         c_map[++c_count] = prev_label = cur_label;\r
380                     *pair16u32s_ptr[i].u = (unsigned short)c_count;\r
381                 }\r
382                 // replace labels for missing values with 65535\r
383                 for( ; i < sample_count; i++ )\r
384                     *pair16u32s_ptr[i].u = 65535;\r
385             }\r
386             else\r
387             {\r
388                 // compact the class indices and build the map\r
389                 prev_label = ~*int_ptr[0];\r
390                 for( i = 0; i < num_valid; i++ )\r
391                 {\r
392                     int cur_label = *int_ptr[i];\r
393                     if( cur_label != prev_label )\r
394                         c_map[++c_count] = prev_label = cur_label;\r
395                     *int_ptr[i] = c_count;\r
396                 }\r
397                 // replace labels for missing values with -1\r
398                 for( ; i < sample_count; i++ )\r
399                     *int_ptr[i] = -1;\r
400             }           \r
401         }\r
402         else if( ci < 0 ) // process ordered variable\r
403         {\r
404             for( i = 0; i < sample_count; i++ )\r
405             {\r
406                 float val = ord_nan;\r
407                 int si = sidx ? sidx[i] : i;\r
408                 if( !mask || !mask[si*m_step] )\r
409                 {\r
410                     if( idata )\r
411                         val = (float)idata[si*step];\r
412                     else\r
413                         val = fdata[si*step];\r
414 \r
415                     if( fabs(val) >= ord_nan )\r
416                     {\r
417                         sprintf( err, "%d-th value of %d-th (ordered) "\r
418                             "variable (=%g) is too large", i, vi, val );\r
419                         CV_ERROR( CV_StsBadArg, err );\r
420                     }\r
421                 }\r
422                 num_valid++;\r
423             }\r
424         }\r
425         if( vi < var_count )\r
426             data_root->set_num_valid(vi, num_valid);\r
427     }\r
428 \r
429     // set sample labels\r
430     if (is_buf_16u)\r
431         udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);\r
432     else\r
433         idst = buf->data.i + get_work_var_count()*sample_count;\r
434 \r
435     for (i = 0; i < sample_count; i++)\r
436     {\r
437         if (udst)\r
438             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;\r
439         else\r
440             idst[i] = sidx ? sidx[i] : i;\r
441     }\r
442 \r
443     if( cv_n )\r
444     {\r
445         unsigned short* udst = 0;\r
446         int* idst = 0;\r
447         CvRNG* r = &rng;\r
448 \r
449         if (is_buf_16u)\r
450         {\r
451             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);\r
452             for( i = vi = 0; i < sample_count; i++ )\r
453             {\r
454                 udst[i] = (unsigned short)vi++;\r
455                 vi &= vi < cv_n ? -1 : 0;\r
456             }\r
457 \r
458             for( i = 0; i < sample_count; i++ )\r
459             {\r
460                 int a = cvRandInt(r) % sample_count;\r
461                 int b = cvRandInt(r) % sample_count;\r
462                 unsigned short unsh = (unsigned short)vi;\r
463                 CV_SWAP( udst[a], udst[b], unsh );\r
464             }\r
465         }\r
466         else\r
467         {\r
468             idst = buf->data.i + (get_work_var_count()-1)*sample_count;\r
469             for( i = vi = 0; i < sample_count; i++ )\r
470             {\r
471                 idst[i] = vi++;\r
472                 vi &= vi < cv_n ? -1 : 0;\r
473             }\r
474 \r
475             for( i = 0; i < sample_count; i++ )\r
476             {\r
477                 int a = cvRandInt(r) % sample_count;\r
478                 int b = cvRandInt(r) % sample_count;\r
479                 CV_SWAP( idst[a], idst[b], vi );\r
480             }\r
481         }\r
482     }\r
483 \r
484     if ( cat_map ) \r
485         cat_map->cols = MAX( total_c_count, 1 );\r
486 \r
487     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
488         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
489     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));\r
490 \r
491     have_priors = is_classifier && params.priors;\r
492     if( is_classifier )\r
493     {\r
494         int m = get_num_classes();\r
495         double sum = 0;\r
496         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));\r
497         for( i = 0; i < m; i++ )\r
498         {\r
499             double val = have_priors ? params.priors[i] : 1.;\r
500             if( val <= 0 )\r
501                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );\r
502             priors->data.db[i] = val;\r
503             sum += val;\r
504         }\r
505 \r
506         // normalize weights\r
507         if( have_priors )\r
508             cvScale( priors, priors, 1./sum );\r
509 \r
510         CV_CALL( priors_mult = cvCloneMat( priors ));\r
511         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));\r
512     }\r
513 \r
514     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));\r
515     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));\r
516 \r
517     {\r
518         int maxNumThreads = 1;\r
519 #ifdef _OPENMP\r
520         maxNumThreads = cv::getNumThreads();\r
521 #endif\r
522         pred_float_buf.resize(maxNumThreads);\r
523         pred_int_buf.resize(maxNumThreads);\r
524         resp_float_buf.resize(maxNumThreads);\r
525         resp_int_buf.resize(maxNumThreads);\r
526         cv_lables_buf.resize(maxNumThreads);\r
527         sample_idx_buf.resize(maxNumThreads);\r
528         for( int ti = 0; ti < maxNumThreads; ti++ )\r
529         {\r
530             pred_float_buf[ti].resize(sample_count);\r
531             pred_int_buf[ti].resize(sample_count);\r
532             resp_float_buf[ti].resize(sample_count);\r
533             resp_int_buf[ti].resize(sample_count);\r
534             cv_lables_buf[ti].resize(sample_count);\r
535             sample_idx_buf[ti].resize(sample_count);\r
536         }\r
537     }\r
538     \r
539     __END__;\r
540 \r
541     if( data )\r
542         delete data;\r
543 \r
544     if (_fdst)\r
545         cvFree( &_fdst );\r
546     if (_idst)\r
547         cvFree( &_idst );\r
548     cvFree( &int_ptr );\r
549     cvReleaseMat( &var_type0 );\r
550     cvReleaseMat( &sample_indices );\r
551     cvReleaseMat( &tmp_map );\r
552 }\r
553 \r
554 int CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, const float** ord_values, const int** missing )\r
555 {\r
556     int vidx = var_idx ? var_idx->data.i[vi] : vi;\r
557     int node_sample_count = n->sample_count; \r
558     int* sample_indices_buf = get_sample_idx_buf();\r
559     const int* sample_indices = 0;\r
560 \r
561     get_sample_indices(n, sample_indices_buf, &sample_indices);\r
562 \r
563     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);\r
564     int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;\r
565     if( tflag == CV_ROW_SAMPLE )\r
566     {\r
567         for( int i = 0; i < node_sample_count; i++ )\r
568         {\r
569             int idx = sample_indices[i];\r
570             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;\r
571             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);\r
572         }\r
573     }\r
574     else\r
575         for( int i = 0; i < node_sample_count; i++ )\r
576         {\r
577             int idx = sample_indices[i];\r
578             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;\r
579             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);\r
580         }\r
581     *ord_values = ord_values_buf;\r
582     *missing = missing_buf;\r
583     return 0; //TODO: return the number of non-missing values\r
584 }\r
585 \r
586 \r
587 void CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )\r
588 {\r
589     get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf, indices );\r
590 }\r
591 \r
592 \r
593 void CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
594 {\r
595     if (have_labels)\r
596         get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf, labels );\r
597 }\r
598 \r
599 \r
600 int CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )\r
601 {\r
602     int ci = get_var_type( vi);\r
603     if( !is_buf_16u )\r
604         *cat_values = buf->data.i + n->buf_idx*buf->cols + \r
605         ci*sample_count + n->offset;\r
606     else {\r
607         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
608             ci*sample_count + n->offset);\r
609         for( int i = 0; i < n->sample_count; i++ )\r
610             cat_values_buf[i] = short_values[i];\r
611         *cat_values = cat_values_buf;\r
612     }\r
613 \r
614     return 0; //TODO: return the number of non-missing values\r
615 }\r
616 \r
617 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,\r
618                                     float* values, uchar* missing,\r
619                                     float* responses, bool get_class_idx )\r
620 {\r
621     CvMat* subsample_idx = 0;\r
622     CvMat* subsample_co = 0;\r
623 \r
624     CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );\r
625 \r
626     __BEGIN__;\r
627 \r
628     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
629     int* sidx = 0;\r
630     int* co = 0;\r
631 \r
632     if( _subsample_idx )\r
633     {\r
634         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
635         sidx = subsample_idx->data.i;\r
636         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
637         co = subsample_co->data.i;\r
638         cvZero( subsample_co );\r
639         count = subsample_idx->cols + subsample_idx->rows - 1;\r
640         for( i = 0; i < count; i++ )\r
641             co[sidx[i]*2]++;\r
642         for( i = 0; i < total; i++ )\r
643         {\r
644             int count_i = co[i*2];\r
645             if( count_i )\r
646             {\r
647                 co[i*2+1] = cur_ofs*var_count;\r
648                 cur_ofs += count_i;\r
649             }\r
650         }\r
651     }\r
652 \r
653     if( missing )\r
654         memset( missing, 1, count*var_count );\r
655 \r
656     for( vi = 0; vi < var_count; vi++ )\r
657     {\r
658         int ci = get_var_type(vi);\r
659         if( ci >= 0 ) // categorical\r
660         {\r
661             float* dst = values + vi;\r
662             uchar* m = missing ? missing + vi : 0;\r
663             int* src_buf = get_pred_int_buf();\r
664             const int* src = 0; \r
665             get_cat_var_data(data_root, vi, src_buf, &src);\r
666 \r
667             for( i = 0; i < count; i++, dst += var_count )\r
668             {\r
669                 int idx = sidx ? sidx[i] : i;\r
670                 int val = src[idx];\r
671                 *dst = (float)val;\r
672                 if( m )\r
673                 {\r
674                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));\r
675                     m += var_count;\r
676                 }\r
677             }\r
678         }\r
679         else // ordered\r
680         {\r
681             float* dst_buf = values + vi;\r
682             int* m_buf = get_pred_int_buf();\r
683             const float *dst = 0;\r
684             const int* m = 0;\r
685             get_ord_var_data(data_root, vi, dst_buf, m_buf, &dst, &m);\r
686             for (int si = 0; si < total; si++)\r
687                 *(missing + vi + si) = m[si] == 0 ? 0 : 1;\r
688         }\r
689     }\r
690 \r
691     // copy responses\r
692     if( responses )\r
693     {\r
694         if( is_classifier )\r
695         {\r
696             int* src_buf = get_resp_int_buf();\r
697             const int* src = 0;\r
698             get_class_labels(data_root, src_buf, &src);\r
699             for( i = 0; i < count; i++ )\r
700             {\r
701                 int idx = sidx ? sidx[i] : i;\r
702                 int val = get_class_idx ? src[idx] :\r
703                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];\r
704                 responses[i] = (float)val;\r
705             }\r
706         }\r
707         else           \r
708         {\r
709             float *_values_buf = get_resp_float_buf();\r
710             const float* _values = 0;\r
711             get_ord_responses(data_root, _values_buf, &_values);\r
712             for( i = 0; i < count; i++ )\r
713             {\r
714                 int idx = sidx ? sidx[i] : i;\r
715                 responses[i] = _values[idx];\r
716             }\r
717         }\r
718     }\r
719 \r
720     __END__;\r
721 \r
722     cvReleaseMat( &subsample_idx );\r
723     cvReleaseMat( &subsample_co );\r
724 }\r
725 \r
726 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )\r
727 {\r
728     CvDTreeNode* root = 0;\r
729     \r
730     CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );\r
731 \r
732     __BEGIN__;\r
733 \r
734     if( !data_root )\r
735         CV_ERROR( CV_StsError, "No training data has been set" );\r
736 \r
737     if( !_subsample_idx )\r
738     {\r
739         // make a copy of the root node\r
740         CvDTreeNode temp;\r
741         int i;\r
742         root = new_node( 0, 1, 0, 0 );\r
743         temp = *root;\r
744         *root = *data_root;\r
745         root->num_valid = temp.num_valid;\r
746         if( root->num_valid )\r
747         {\r
748             for( i = 0; i < var_count; i++ )\r
749                 root->num_valid[i] = data_root->num_valid[i];\r
750         }\r
751         root->cv_Tn = temp.cv_Tn;\r
752         root->cv_node_risk = temp.cv_node_risk;\r
753         root->cv_node_error = temp.cv_node_error;\r
754     }\r
755     else\r
756         CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );\r
757     __END__;\r
758 \r
759     return root;\r
760 }\r
761 \r
762 double CvForestERTree::calc_node_dir( CvDTreeNode* node )\r
763 {\r
764     char* dir = (char*)data->direction->data.ptr;\r
765     int i, n = node->sample_count, vi = node->split->var_idx;\r
766     double L, R;\r
767 \r
768     assert( !node->split->inversed );\r
769 \r
770     if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
771     {\r
772         int* labels_buf = data->get_pred_int_buf();\r
773         const int* labels = 0;\r
774         const int* subset = node->split->subset;\r
775         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
776         if( !data->have_priors )\r
777         {\r
778             int sum = 0, sum_abs = 0;\r
779 \r
780             for( i = 0; i < n; i++ )\r
781             {\r
782                 int idx = labels[i];\r
783                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?\r
784                     CV_DTREE_CAT_DIR(idx,subset) : 0;\r
785                 sum += d; sum_abs += d & 1;\r
786                 dir[i] = (char)d;\r
787             }\r
788 \r
789             R = (sum_abs + sum) >> 1;\r
790             L = (sum_abs - sum) >> 1;\r
791         }\r
792         else\r
793         {\r
794             const double* priors = data->priors_mult->data.db;\r
795             double sum = 0, sum_abs = 0;\r
796             int *responses_buf = data->get_resp_int_buf();\r
797             const int* responses;\r
798             data->get_class_labels(node, responses_buf, &responses);\r
799 \r
800             for( i = 0; i < n; i++ )\r
801             {\r
802                 int idx = labels[i];\r
803                 double w = priors[responses[i]];\r
804                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;\r
805                 sum += d*w; sum_abs += (d & 1)*w;\r
806                 dir[i] = (char)d;\r
807             }\r
808 \r
809             R = (sum_abs + sum) * 0.5;\r
810             L = (sum_abs - sum) * 0.5;\r
811         }\r
812     }\r
813     else // split on ordered var\r
814     {\r
815         float split_val = node->split->ord.c;\r
816         float* val_buf = data->get_pred_float_buf();\r
817         const float* val = 0;\r
818         int* missing_buf = data->get_pred_int_buf();\r
819         const int* missing = 0;\r
820         data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing );\r
821 \r
822         if( !data->have_priors )\r
823         {\r
824             L = R = 0;\r
825             for( i = 0; i < n; i++ )\r
826             {\r
827                 if ( missing[i] )\r
828                     dir[i] = (char)0;\r
829                 else\r
830                 {\r
831                     if ( val[i] < split_val)\r
832                     {\r
833                         dir[i] = (char)-1;\r
834                         L++;\r
835                     }\r
836                     else\r
837                     {\r
838                         dir[i] = (char)1;\r
839                         R++;\r
840                     }\r
841                 }\r
842             }\r
843         }\r
844         else\r
845         {\r
846             const double* priors = data->priors_mult->data.db;\r
847             int* responses_buf = data->get_resp_int_buf();\r
848             const int* responses = 0;\r
849             data->get_class_labels(node, responses_buf, &responses);\r
850             L = R = 0;\r
851             for( i = 0; i < n; i++ )\r
852             {\r
853                 if ( missing[i] )\r
854                     dir[i] = (char)0;\r
855                 else\r
856                 {\r
857                     double w = priors[responses[i]];\r
858                     if ( val[i] < split_val)\r
859                     {\r
860                         dir[i] = (char)-1;\r
861                          L += w;\r
862                     }\r
863                     else\r
864                     {\r
865                         dir[i] = (char)1;\r
866                         R += w;\r
867                     }\r
868                 }\r
869             }\r
870         }\r
871     }\r
872 \r
873     node->maxlr = MAX( L, R );\r
874     return node->split->quality/(L + R);\r
875 }\r
876 \r
877 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
878 {\r
879     const float epsilon = FLT_EPSILON*2;\r
880     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;\r
881 \r
882     int n = node->sample_count;\r
883     int m = data->get_num_classes();\r
884 \r
885     float* values_buf = data->get_pred_float_buf();\r
886     const float* values = 0;\r
887     int* missing_buf = data->get_pred_int_buf();\r
888     const int* missing = 0;\r
889     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );\r
890     int* responses_buf =  data->get_resp_int_buf();\r
891     const int* responses = 0;\r
892     data->get_class_labels( node, responses_buf, &responses );\r
893 \r
894     double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;\r
895     \r
896     int i;\r
897 \r
898     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
899 \r
900     bool is_find_split = false;\r
901     \r
902     float pmin, pmax;\r
903     int smpi = 0;\r
904     while ( missing[smpi] && (smpi < n) )\r
905         smpi++;\r
906     assert(smpi < n);\r
907 \r
908     pmin = values[smpi];\r
909     pmax = pmin;\r
910     for (; smpi < n; smpi++)\r
911     {\r
912         float ptemp = values[smpi];\r
913         int m = missing[smpi];\r
914         if (m) continue;\r
915         if ( ptemp < pmin)\r
916             pmin = ptemp;\r
917         if ( ptemp > pmax)\r
918             pmax = ptemp;\r
919     }\r
920     float fdiff = pmax-pmin;\r
921     if (fdiff > epsilon)\r
922     {\r
923         is_find_split = true;\r
924         CvRNG* rng = &data->rng;\r
925         split_val = pmin + cvRandReal(rng) * fdiff ;\r
926         if (split_val - pmin <= FLT_EPSILON)\r
927             split_val = pmin + split_delta;\r
928         if (pmax - split_val <= FLT_EPSILON)\r
929             split_val = pmax - split_delta;       \r
930 \r
931         // calculate Gini index\r
932         if ( !priors )\r
933         {\r
934             int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
935             int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
936             int L = 0, R = 0;\r
937     \r
938             // init arrays of class instance counters on both sides of the split\r
939             for( i = 0; i < m; i++ )\r
940             {\r
941                 lc[i] = 0;\r
942                 rc[i] = 0;\r
943             }\r
944             for( int si = 0; si < n; si++ )\r
945             {\r
946                 int r = responses[si];\r
947                 float val = values[si];\r
948                 int m = missing[si];\r
949                 if (m) continue;\r
950                 if ( val < split_val )\r
951                 {\r
952                     lc[r]++;\r
953                     L++;\r
954                 }\r
955                 else\r
956                 {\r
957                     rc[r]++;\r
958                     R++;\r
959                 }\r
960             }\r
961             for (int i = 0; i < m; i++)\r
962             {\r
963                 lbest_val += lc[i]*lc[i];\r
964                 rbest_val += rc[i]*rc[i];\r
965             }\r
966             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));\r
967         }\r
968         else\r
969         {\r
970             double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));\r
971             double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));\r
972             double L = 0, R = 0;\r
973     \r
974             // init arrays of class instance counters on both sides of the split\r
975             for( i = 0; i < m; i++ )\r
976             {\r
977                 lc[i] = 0;\r
978                 rc[i] = 0;\r
979             }\r
980             for( int si = 0; si < n; si++ )\r
981             {\r
982                 int r = responses[si];\r
983                 float val = values[si];\r
984                 int m = missing[si];\r
985                 double p = priors[si];\r
986                 if (m) continue;\r
987                 if ( val < split_val )\r
988                 {\r
989                     lc[r] += p;\r
990                     L += p;\r
991                 }\r
992                 else\r
993                 {\r
994                     rc[r] += p;\r
995                     R += p;\r
996                 }\r
997             }\r
998             for (int i = 0; i < m; i++)\r
999             {\r
1000                 lbest_val += lc[i]*lc[i];\r
1001                 rbest_val += rc[i]*rc[i];\r
1002             }\r
1003             best_val = (lbest_val*R + rbest_val*L) / (L*R);\r
1004         }\r
1005         \r
1006     }\r
1007 \r
1008     CvDTreeSplit* split = 0;\r
1009     if( is_find_split )\r
1010     {\r
1011         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );\r
1012         split->var_idx = vi;
1013         split->ord.c = (float)split_val;
1014         split->ord.split_point = -1;
1015         split->inversed = 0;
1016         split->quality = (float)best_val;
1017     }
1018     return split;\r
1019 }\r
1020 \r
1021 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
1022 {\r
1023     int ci = data->get_var_type(vi);\r
1024     int n = node->sample_count;\r
1025     int cm = data->get_num_classes(); \r
1026     int vm = data->cat_count->data.i[ci];\r
1027     double best_val = init_quality;\r
1028     CvDTreeSplit *split = 0;\r
1029 \r
1030     if ( vm > 1 )\r
1031     {\r
1032         int* labels_buf = data->get_pred_int_buf();\r
1033         const int* labels = 0;\r
1034         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1035 \r
1036         int* responses_buf =  data->get_resp_int_buf();\r
1037         const int* responses = 0;\r
1038         data->get_class_labels( node, responses_buf, &responses );\r
1039     \r
1040         const double* priors = data->have_priors ? data->priors_mult->data.db : 0;       \r
1041 \r
1042         // create random class mask\r
1043         int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));\r
1044         for (int i = 0; i < vm; i++)\r
1045         {\r
1046             valid_cidx[i] = -1;\r
1047         }\r
1048         for (int si = 0; si < n; si++)\r
1049         {\r
1050             int c = labels[si];\r
1051             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )\r
1052                 continue;\r
1053             valid_cidx[c]++;\r
1054         }\r
1055 \r
1056         int valid_ccount = 0;\r
1057         for (int i = 0; i < vm; i++)\r
1058             if (valid_cidx[i] >= 0)\r
1059             {\r
1060                 valid_cidx[i] = valid_ccount;\r
1061                 valid_ccount++;\r
1062             }\r
1063         if (valid_ccount > 1)\r
1064         {\r
1065             CvRNG* rng = forest->get_rng();\r
1066             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);\r
1067 \r
1068             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );\r
1069             CvMat submask;\r
1070             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));\r
1071             cvGetCols( var_class_mask, &submask, 0, l_cval_count );\r
1072             cvSet( &submask, cvScalar(1) );\r
1073             for (int i = 0; i < valid_ccount; i++)\r
1074             {\r
1075                 uchar temp;\r
1076                 int i1 = cvRandInt( rng ) % valid_ccount;\r
1077                 int i2 = cvRandInt( rng ) % valid_ccount;\r
1078                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );\r
1079             }\r
1080 \r
1081             split = _split ? _split : data->new_split_cat( 0, -1.0f );\r
1082             split->var_idx = vi;\r
1083             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));\r
1084 \r
1085             // calculate Gini index\r
1086             double lbest_val = 0, rbest_val = 0;\r
1087             if( !priors )\r
1088             {\r
1089                 int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));\r
1090                 int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));\r
1091                 int L = 0, R = 0;\r
1092                 // init arrays of class instance counters on both sides of the split\r
1093                 for(int i = 0; i < cm; i++ )\r
1094                 {\r
1095                     lc[i] = 0;\r
1096                     rc[i] = 0;\r
1097                 }\r
1098                 for( int si = 0; si < n; si++ )\r
1099                 {\r
1100                     int r = responses[si];\r
1101                     int var_class_idx = labels[si];\r
1102                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )\r
1103                         continue;\r
1104                     int mask_class_idx = valid_cidx[var_class_idx];\r
1105                     if (var_class_mask->data.ptr[mask_class_idx])\r
1106                     {\r
1107                         lc[r]++;\r
1108                         L++;                 \r
1109                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);\r
1110                     }\r
1111                     else\r
1112                     {\r
1113                         rc[r]++;\r
1114                         R++;\r
1115                     }\r
1116                 }\r
1117                 for (int i = 0; i < cm; i++)\r
1118                 {\r
1119                     lbest_val += lc[i]*lc[i];\r
1120                     rbest_val += rc[i]*rc[i];\r
1121                 }                \r
1122                 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));\r
1123             }\r
1124             else\r
1125             {\r
1126                 double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));\r
1127                 double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));\r
1128                 double L = 0, R = 0;\r
1129                 // init arrays of class instance counters on both sides of the split\r
1130                 for(int i = 0; i < cm; i++ )\r
1131                 {\r
1132                     lc[i] = 0;\r
1133                     rc[i] = 0;\r
1134                 }\r
1135                 for( int si = 0; si < n; si++ )\r
1136                 {\r
1137                     int r = responses[si];\r
1138                     int var_class_idx = labels[si];\r
1139                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )\r
1140                         continue;\r
1141                     double p = priors[si];\r
1142                     int mask_class_idx = valid_cidx[var_class_idx];\r
1143                     \r
1144                     if (var_class_mask->data.ptr[mask_class_idx])\r
1145                     {\r
1146                         lc[r]+=p;\r
1147                         L+=p;                 \r
1148                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);\r
1149                     }\r
1150                     else\r
1151                     {\r
1152                         rc[r]+=p;\r
1153                         R+=p;\r
1154                     }\r
1155                 }\r
1156                 for (int i = 0; i < cm; i++)\r
1157                 {\r
1158                     lbest_val += lc[i]*lc[i];\r
1159                     rbest_val += rc[i]*rc[i];\r
1160                 }\r
1161                 best_val = (lbest_val*R + rbest_val*L) / (L*R);\r
1162             }\r
1163             split->quality = (float)best_val;\r
1164 \r
1165             cvReleaseMat(&var_class_mask);\r
1166         }   \r
1167     }  \r
1168 \r
1169     return split;\r
1170 }\r
1171 \r
1172 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
1173 {\r
1174     const float epsilon = FLT_EPSILON*2;\r
1175     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;\r
1176     int n = node->sample_count;\r
1177     float* values_buf = data->get_pred_float_buf();\r
1178     const float* values = 0;\r
1179     int* missing_buf = data->get_pred_int_buf();\r
1180     const int* missing = 0;\r
1181     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );\r
1182     float* responses_buf =  data->get_resp_float_buf();\r
1183     const float* responses = 0;\r
1184     data->get_ord_responses( node, responses_buf, &responses );\r
1185 \r
1186     double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;\r
1187     int L = 0, R = 0;\r
1188 \r
1189     bool is_find_split = false;\r
1190     float pmin, pmax;\r
1191     int smpi = 0;\r
1192     while ( missing[smpi] && (smpi < n) )\r
1193         smpi++;\r
1194 \r
1195     assert(smpi < n);\r
1196 \r
1197     pmin = values[smpi];\r
1198     pmax = pmin;\r
1199     for (; smpi < n; smpi++)\r
1200     {\r
1201         float ptemp = values[smpi];\r
1202         int m = missing[smpi];\r
1203         if (m) continue;\r
1204         if ( ptemp < pmin)\r
1205             pmin = ptemp;\r
1206         if ( ptemp > pmax)\r
1207             pmax = ptemp;\r
1208     }\r
1209     float fdiff = pmax-pmin;\r
1210     if (fdiff > epsilon)\r
1211     {\r
1212         is_find_split = true;\r
1213         CvRNG* rng = &data->rng;\r
1214         split_val = pmin + cvRandReal(rng) * fdiff ;\r
1215         if (split_val - pmin <= FLT_EPSILON)\r
1216             split_val = pmin + split_delta;\r
1217         if (pmax - split_val <= FLT_EPSILON)\r
1218             split_val = pmax - split_delta;    \r
1219 \r
1220         for (int si = 0; si < n; si++)\r
1221         {\r
1222             float r = responses[si];\r
1223             float val = values[si];\r
1224             int m = missing[si];\r
1225             if (m) continue;\r
1226             if (val < split_val)\r
1227             {\r
1228                 lsum += r;\r
1229                 L++;\r
1230             }\r
1231             else\r
1232             {\r
1233                 rsum += r;\r
1234                 R++;            \r
1235             }\r
1236         }\r
1237         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
1238     }\r
1239 \r
1240     CvDTreeSplit* split = 0;\r
1241     if( is_find_split )\r
1242     {\r
1243         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );\r
1244         split->var_idx = vi;
1245         split->ord.c = (float)split_val;
1246         split->ord.split_point = -1;
1247         split->inversed = 0;
1248         split->quality = (float)best_val;
1249     }
1250     return split;\r
1251 }\r
1252 \r
1253 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
1254 {\r
1255     int ci = data->get_var_type(vi);\r
1256     int n = node->sample_count;\r
1257     int vm = data->cat_count->data.i[ci];\r
1258     double best_val = init_quality;\r
1259     CvDTreeSplit *split = 0;\r
1260     float lsum = 0, rsum = 0;\r
1261 \r
1262     if ( vm > 1 )\r
1263     {\r
1264         int* labels_buf = data->get_pred_int_buf();\r
1265         const int* labels = 0;\r
1266         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1267 \r
1268         float* responses_buf =  data->get_resp_float_buf();\r
1269         const float* responses = 0;\r
1270         data->get_ord_responses( node, responses_buf, &responses );\r
1271 \r
1272         // create random class mask\r
1273         int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));\r
1274         for (int i = 0; i < vm; i++)\r
1275         {\r
1276             valid_cidx[i] = -1;\r
1277         }\r
1278         for (int si = 0; si < n; si++)\r
1279         {\r
1280             int c = labels[si];\r
1281             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )\r
1282                         continue;\r
1283             valid_cidx[c]++;\r
1284         }\r
1285 \r
1286         int valid_ccount = 0;\r
1287         for (int i = 0; i < vm; i++)\r
1288             if (valid_cidx[i] >= 0)\r
1289             {\r
1290                 valid_cidx[i] = valid_ccount;\r
1291                 valid_ccount++;\r
1292             }\r
1293         if (valid_ccount > 1)\r
1294         {\r
1295             CvRNG* rng = forest->get_rng();\r
1296             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);\r
1297 \r
1298             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );\r
1299             CvMat submask;\r
1300             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));\r
1301             cvGetCols( var_class_mask, &submask, 0, l_cval_count );\r
1302             cvSet( &submask, cvScalar(1) );\r
1303             for (int i = 0; i < valid_ccount; i++)\r
1304             {\r
1305                 uchar temp;\r
1306                 int i1 = cvRandInt( rng ) % valid_ccount;\r
1307                 int i2 = cvRandInt( rng ) % valid_ccount;\r
1308                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );\r
1309             }\r
1310 \r
1311             split = _split ? _split : data->new_split_cat( 0, -1.0f);\r
1312             split->var_idx = vi;\r
1313             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));\r
1314 \r
1315             int L = 0, R = 0;\r
1316             for( int si = 0; si < n; si++ )\r
1317             {\r
1318                 float r = responses[si];\r
1319                 int var_class_idx = labels[si];\r
1320                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )\r
1321                         continue;\r
1322                 int mask_class_idx = valid_cidx[var_class_idx];\r
1323                 if (var_class_mask->data.ptr[mask_class_idx])\r
1324                 {\r
1325                     lsum += r;\r
1326                     L++;                 \r
1327                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);\r
1328                 }\r
1329                 else\r
1330                 {\r
1331                     rsum += r;\r
1332                     R++;\r
1333                 }\r
1334             }\r
1335             best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
1336 \r
1337             split->quality = (float)best_val;\r
1338 \r
1339             cvReleaseMat(&var_class_mask);\r
1340         }   \r
1341     }  \r
1342 \r
1343     return split;\r
1344 }\r
1345 \r
1346 //void CvForestERTree::complete_node_dir( CvDTreeNode* node )\r
1347 //{\r
1348 //    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;\r
1349 //    int nz = n - node->get_num_valid(node->split->var_idx);\r
1350 //    char* dir = (char*)data->direction->data.ptr;\r
1351 //\r
1352 //    // try to complete direction using surrogate splits\r
1353 //    if( nz && data->params.use_surrogates )\r
1354 //    {\r
1355 //        CvDTreeSplit* split = node->split->next;\r
1356 //        for( ; split != 0 && nz; split = split->next )\r
1357 //        {\r
1358 //            int inversed_mask = split->inversed ? -1 : 0;\r
1359 //            vi = split->var_idx;\r
1360 //\r
1361 //            if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
1362 //            {\r
1363 //                int* labels_buf = data->pred_int_buf;\r
1364 //                const int* labels = 0;\r
1365 //                data->get_cat_var_data(node, vi, labels_buf, &labels);\r
1366 //                const int* subset = split->subset;\r
1367 //\r
1368 //                for( i = 0; i < n; i++ )\r
1369 //                {\r
1370 //                    int idx = labels[i];\r
1371 //                    if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))\r
1372 //                        \r
1373 //                    {\r
1374 //                        int d = CV_DTREE_CAT_DIR(idx,subset);\r
1375 //                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
1376 //                        if( --nz )\r
1377 //                            break;\r
1378 //                    }\r
1379 //                }\r
1380 //            }\r
1381 //            else // split on ordered var\r
1382 //            {\r
1383 //                float* values_buf = data->pred_float_buf;\r
1384 //                const float* values = 0;\r
1385 //                uchar* missing_buf = (uchar*)data->pred_int_buf;\r
1386 //                const uchar* missing = 0;\r
1387 //                data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );\r
1388 //                float split_val = node->split->ord.c;\r
1389 //                \r
1390 //                for( i = 0; i < n; i++ )\r
1391 //                {\r
1392 //                    if( !dir[i] && !missing[i])\r
1393 //                    {\r
1394 //                        int d = values[i] <= split_val ? -1 : 1;\r
1395 //                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
1396 //                        if( --nz )\r
1397 //                            break;\r
1398 //                    }\r
1399 //                }\r
1400 //            }\r
1401 //        }\r
1402 //    }\r
1403 //\r
1404 //    // find the default direction for the rest\r
1405 //    if( nz )\r
1406 //    {\r
1407 //        for( i = nr = 0; i < n; i++ )\r
1408 //            nr += dir[i] > 0;\r
1409 //        nl = n - nr - nz;\r
1410 //        d0 = nl > nr ? -1 : nr > nl;\r
1411 //    }\r
1412 //\r
1413 //    // make sure that every sample is directed either to the left or to the right\r
1414 //    for( i = 0; i < n; i++ )\r
1415 //    {\r
1416 //        int d = dir[i];\r
1417 //        if( !d )\r
1418 //        {\r
1419 //            d = d0;\r
1420 //            if( !d )\r
1421 //                d = d1, d1 = -d1;\r
1422 //        }\r
1423 //        d = d > 0;\r
1424 //        dir[i] = (char)d; // remap (-1,1) to (0,1)\r
1425 //    }\r
1426 //}\r
1427 \r
1428 void CvForestERTree::split_node_data( CvDTreeNode* node )\r
1429 {\r
1430     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;\r
1431     char* dir = (char*)data->direction->data.ptr;\r
1432     CvDTreeNode *left = 0, *right = 0;\r
1433     int new_buf_idx = data->get_child_buf_idx( node );\r
1434     CvMat* buf = data->buf;\r
1435     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));\r
1436 \r
1437     complete_node_dir(node);\r
1438 \r
1439     for( i = nl = nr = 0; i < n; i++ )\r
1440     {\r
1441         int d = dir[i];\r
1442         nr += d;\r
1443         nl += d^1;\r
1444     }\r
1445 \r
1446     bool split_input_data;\r
1447     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );\r
1448     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );\r
1449 \r
1450     split_input_data = node->depth + 1 < data->params.max_depth &&\r
1451         (node->left->sample_count > data->params.min_sample_count ||\r
1452         node->right->sample_count > data->params.min_sample_count);\r
1453 \r
1454     // split ordered vars\r
1455     for( vi = 0; vi < data->var_count; vi++ )\r
1456     {\r
1457         int ci = data->get_var_type(vi);\r
1458         if (ci >= 0) continue;\r
1459         \r
1460         int n1 = node->get_num_valid(vi), nr1 = 0;\r
1461 \r
1462         float* values_buf = data->get_pred_float_buf();\r
1463         const float* values = 0;\r
1464         int* missing_buf = data->get_pred_int_buf();\r
1465         const int* missing = 0;\r
1466         data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );\r
1467 \r
1468         for( i = 0; i < n; i++ )\r
1469             nr1 += (!missing[i] & dir[i]);\r
1470         left->set_num_valid(vi, n1 - nr1);\r
1471         right->set_num_valid(vi, nr1);                \r
1472     }\r
1473     // split categorical vars, responses and cv_labels using new_idx relocation table\r
1474     for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )\r
1475     {\r
1476         int ci = data->get_var_type(vi);\r
1477         if (ci < 0) continue;\r
1478 \r
1479         int n1 = node->get_num_valid(vi), nr1 = 0;\r
1480         \r
1481         int *src_lbls_buf = data->get_pred_int_buf();\r
1482         const int* src_lbls = 0;\r
1483         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);\r
1484 \r
1485         for(i = 0; i < n; i++)\r
1486             temp_buf[i] = src_lbls[i];\r
1487 \r
1488         if (data->is_buf_16u)\r
1489         {\r
1490             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + \r
1491                 ci*scount + left->offset);\r
1492             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + \r
1493                 ci*scount + right->offset);\r
1494             \r
1495             for( i = 0; i < n; i++ )\r
1496             {\r
1497                 int d = dir[i];\r
1498                 int idx = temp_buf[i];\r
1499                 if (d)\r
1500                 {\r
1501                     *rdst = (unsigned short)idx;\r
1502                     rdst++;\r
1503                     nr1 += (idx != 65535);\r
1504                 }\r
1505                 else\r
1506                 {\r
1507                     *ldst = (unsigned short)idx;\r
1508                     ldst++;\r
1509                 }\r
1510             }\r
1511 \r
1512             if( vi < data->var_count )\r
1513             {\r
1514                 left->set_num_valid(vi, n1 - nr1);\r
1515                 right->set_num_valid(vi, nr1);\r
1516             }\r
1517         }\r
1518         else\r
1519         {\r
1520             int *ldst = buf->data.i + left->buf_idx*buf->cols + \r
1521                 ci*scount + left->offset;\r
1522             int *rdst = buf->data.i + right->buf_idx*buf->cols + \r
1523                 ci*scount + right->offset;\r
1524             \r
1525             for( i = 0; i < n; i++ )\r
1526             {\r
1527                 int d = dir[i];\r
1528                 int idx = temp_buf[i];\r
1529                 if (d)\r
1530                 {\r
1531                     *rdst = idx;\r
1532                     rdst++;\r
1533                     nr1 += (idx >= 0);\r
1534                 }\r
1535                 else\r
1536                 {\r
1537                     *ldst = idx;\r
1538                     ldst++;\r
1539                 }\r
1540                 \r
1541             }\r
1542 \r
1543             if( vi < data->var_count )\r
1544             {\r
1545                 left->set_num_valid(vi, n1 - nr1);\r
1546                 right->set_num_valid(vi, nr1);\r
1547             }\r
1548         }        \r
1549     }\r
1550 \r
1551 \r
1552     // split sample indices\r
1553     int *sample_idx_src_buf = data->get_sample_idx_buf();\r
1554     const int* sample_idx_src = 0;\r
1555     if (split_input_data)\r
1556     {\r
1557         data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);\r
1558 \r
1559         for(i = 0; i < n; i++)\r
1560             temp_buf[i] = sample_idx_src[i];\r
1561 \r
1562         int pos = data->get_work_var_count();\r
1563        \r
1564         if (data->is_buf_16u)\r
1565         {\r
1566             unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
1567                 pos*scount + left->offset);\r
1568             unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + \r
1569                 pos*scount + right->offset);\r
1570             \r
1571             for (i = 0; i < n; i++)\r
1572             {\r
1573                 int d = dir[i];\r
1574                 unsigned short idx = (unsigned short)temp_buf[i];\r
1575                 if (d)\r
1576                 {\r
1577                     *rdst = idx;\r
1578                     rdst++;\r
1579                 }\r
1580                 else\r
1581                 {\r
1582                     *ldst = idx;\r
1583                     ldst++;\r
1584                 }\r
1585             }\r
1586         }\r
1587         else\r
1588         {\r
1589             int* ldst = buf->data.i + left->buf_idx*buf->cols + \r
1590                 pos*scount + left->offset;\r
1591             int* rdst = buf->data.i + right->buf_idx*buf->cols + \r
1592                 pos*scount + right->offset;\r
1593             for (i = 0; i < n; i++)\r
1594             {\r
1595                 int d = dir[i];\r
1596                 int idx = temp_buf[i];\r
1597                 if (d)\r
1598                 {\r
1599                     *rdst = idx;\r
1600                     rdst++;\r
1601                 }\r
1602                 else\r
1603                 {\r
1604                     *ldst = idx;\r
1605                     ldst++;\r
1606                 }\r
1607             }\r
1608         }\r
1609     }\r
1610     \r
1611     // deallocate the parent node data that is not needed anymore\r
1612     data->free_node_data(node);    \r
1613 }\r
1614 \r
1615 CvERTrees::CvERTrees()\r
1616 {\r
1617 }\r
1618 \r
1619 CvERTrees::~CvERTrees()\r
1620 {\r
1621 }\r
1622 \r
1623 bool CvERTrees::train( const CvMat* _train_data, int _tflag,\r
1624                         const CvMat* _responses, const CvMat* _var_idx,\r
1625                         const CvMat* _sample_idx, const CvMat* _var_type,\r
1626                         const CvMat* _missing_mask, CvRTParams params )\r
1627 {\r
1628     bool result = false;\r
1629 \r
1630     CV_FUNCNAME("CvERTrees::train");\r
1631     __BEGIN__\r
1632     int var_count = 0;\r
1633 \r
1634     clear();\r
1635 \r
1636     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,\r
1637         params.regression_accuracy, params.use_surrogates, params.max_categories,\r
1638         params.cv_folds, params.use_1se_rule, false, params.priors );\r
1639 \r
1640     data = new CvERTreeTrainData();\r
1641     CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,\r
1642         _sample_idx, _var_type, _missing_mask, tree_params, true));\r
1643 \r
1644     var_count = data->var_count;\r
1645     if( params.nactive_vars > var_count )\r
1646         params.nactive_vars = var_count;\r
1647     else if( params.nactive_vars == 0 )\r
1648         params.nactive_vars = (int)sqrt((double)var_count);\r
1649     else if( params.nactive_vars < 0 )\r
1650         CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );\r
1651 \r
1652     // Create mask of active variables at the tree nodes\r
1653     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));\r
1654     if( params.calc_var_importance )\r
1655     {\r
1656         CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));\r
1657         cvZero(var_importance);\r
1658     }\r
1659     { // initialize active variables mask\r
1660         CvMat submask1, submask2;\r
1661         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );\r
1662         cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );\r
1663         cvSet( &submask1, cvScalar(1) );\r
1664         cvZero( &submask2 );\r
1665     }\r
1666 \r
1667     CV_CALL(result = grow_forest( params.term_crit ));\r
1668 \r
1669     result = true;\r
1670 \r
1671     __END__\r
1672     return result;\r
1673     \r
1674 }\r
1675 \r
1676 bool CvERTrees::train( CvMLData* data, CvRTParams params)\r
1677 {\r
1678    bool result = false;\r
1679 \r
1680     CV_FUNCNAME( "CvERTrees::train" );\r
1681 \r
1682     __BEGIN__;\r
1683 \r
1684     CV_CALL( result = CvRTrees::train( data, params) );\r
1685 \r
1686     __END__;\r
1687 \r
1688     return result;\r
1689 }\r
1690 \r
1691 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )\r
1692 {\r
1693     bool result = false;\r
1694 \r
1695     CvMat* sample_idx_for_tree      = 0;\r
1696 \r
1697     CV_FUNCNAME("CvERTrees::grow_forest");\r
1698     __BEGIN__;\r
1699 \r
1700     const int max_ntrees = term_crit.max_iter;\r
1701     const double max_oob_err = term_crit.epsilon;\r
1702 \r
1703     const int dims = data->var_count;\r
1704     float maximal_response = 0;\r
1705 \r
1706     CvMat* oob_sample_votes        = 0;\r
1707     CvMat* oob_responses       = 0;\r
1708 \r
1709     float* oob_samples_perm_ptr= 0;\r
1710 \r
1711     float* samples_ptr     = 0;\r
1712     uchar* missing_ptr     = 0;\r
1713     float* true_resp_ptr   = 0;\r
1714     bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;\r
1715 \r
1716     // oob_predictions_sum[i] = sum of predicted values for the i-th sample\r
1717     // oob_num_of_predictions[i] = number of summands\r
1718     //                            (number of predictions for the i-th sample)\r
1719     // initialize these variable to avoid warning C4701\r
1720     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );\r
1721     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );\r
1722      \r
1723     nsamples = data->sample_count;\r
1724     nclasses = data->get_num_classes();\r
1725 \r
1726     if ( is_oob_or_vimportance )\r
1727     {\r
1728         if( data->is_classifier )\r
1729         {\r
1730             CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));\r
1731             cvZero(oob_sample_votes);\r
1732         }\r
1733         else\r
1734         {\r
1735             // oob_responses[0,i] = oob_predictions_sum[i]\r
1736             //    = sum of predicted values for the i-th sample\r
1737             // oob_responses[1,i] = oob_num_of_predictions[i]\r
1738             //    = number of summands (number of predictions for the i-th sample)\r
1739             CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));\r
1740             cvZero(oob_responses);\r
1741             cvGetRow( oob_responses, &oob_predictions_sum, 0 );\r
1742             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );\r
1743         }\r
1744         \r
1745         CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
1746         CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
1747         CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));\r
1748         CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));            \r
1749 \r
1750         CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));\r
1751         {\r
1752             double minval, maxval;\r
1753             CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);\r
1754             cvMinMaxLoc( &responses, &minval, &maxval );\r
1755             maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );\r
1756         }\r
1757     }\r
1758    \r
1759     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
1760     memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
1761 \r
1762     CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));\r
1763 \r
1764     for (int i = 0; i < nsamples; i++)\r
1765         sample_idx_for_tree->data.i[i] = i;\r
1766     ntrees = 0;\r
1767     while( ntrees < max_ntrees )\r
1768     {\r
1769         int i, oob_samples_count = 0;\r
1770         double ncorrect_responses = 0; // used for estimation of variable importance\r
1771         CvForestTree* tree = 0;\r
1772 \r
1773         trees[ntrees] = new CvForestERTree();\r
1774         tree = (CvForestERTree*)trees[ntrees];\r
1775         CV_CALL(tree->train( data, 0, this ));\r
1776 \r
1777         if ( is_oob_or_vimportance )\r
1778         {\r
1779             CvMat sample, missing;\r
1780             // form array of OOB samples indices and get these samples\r
1781             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );\r
1782             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );\r
1783 \r
1784             oob_error = 0;\r
1785             for( i = 0; i < nsamples; i++,\r
1786                 sample.data.fl += dims, missing.data.ptr += dims )\r
1787             {\r
1788                 CvDTreeNode* predicted_node = 0;\r
1789                 \r
1790                 // predict oob samples\r
1791                 if( !predicted_node )\r
1792                     CV_CALL(predicted_node = tree->predict(&sample, &missing, true));\r
1793 \r
1794                 if( !data->is_classifier ) //regression\r
1795                 {\r
1796                     double avg_resp, resp = predicted_node->value;\r
1797                     oob_predictions_sum.data.fl[i] += (float)resp;\r
1798                     oob_num_of_predictions.data.fl[i] += 1;\r
1799 \r
1800                     // compute oob error\r
1801                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];\r
1802                     avg_resp -= true_resp_ptr[i];\r
1803                     oob_error += avg_resp*avg_resp;\r
1804                     resp = (resp - true_resp_ptr[i])/maximal_response;\r
1805                     ncorrect_responses += exp( -resp*resp );\r
1806                 }\r
1807                 else //classification\r
1808                 {\r
1809                     double prdct_resp;\r
1810                     CvPoint max_loc;\r
1811                     CvMat votes;\r
1812 \r
1813                     cvGetRow(oob_sample_votes, &votes, i);\r
1814                     votes.data.i[predicted_node->class_idx]++;\r
1815 \r
1816                     // compute oob error\r
1817                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );\r
1818 \r
1819                     prdct_resp = data->cat_map->data.i[max_loc.x];\r
1820                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;\r
1821 \r
1822                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;\r
1823                 }\r
1824                 oob_samples_count++;\r
1825             }\r
1826             if( oob_samples_count > 0 )\r
1827                 oob_error /= (double)oob_samples_count;\r
1828 \r
1829             // estimate variable importance\r
1830             if( var_importance && oob_samples_count > 0 )\r
1831             {\r
1832                 int m;\r
1833 \r
1834                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));\r
1835                 for( m = 0; m < dims; m++ )\r
1836                 {\r
1837                     double ncorrect_responses_permuted = 0;\r
1838                     // randomly permute values of the m-th variable in the oob samples\r
1839                     float* mth_var_ptr = oob_samples_perm_ptr + m;\r
1840 \r
1841                     for( i = 0; i < nsamples; i++ )\r
1842                     {\r
1843                         int i1, i2;\r
1844                         float temp;\r
1845 \r
1846                         i1 = cvRandInt( &rng ) % nsamples;\r
1847                         i2 = cvRandInt( &rng ) % nsamples;\r
1848                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );\r
1849 \r
1850                         // turn values of (m-1)-th variable, that were permuted\r
1851                         // at the previous iteration, untouched\r
1852                         if( m > 1 )\r
1853                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];\r
1854                     }\r
1855 \r
1856                     // predict "permuted" cases and calculate the number of votes for the\r
1857                     // correct class in the variable-m-permuted oob data\r
1858                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );\r
1859                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );\r
1860                     for( i = 0; i < nsamples; i++,\r
1861                         sample.data.fl += dims, missing.data.ptr += dims )\r
1862                     {\r
1863                         double predct_resp, true_resp;\r
1864 \r
1865                         predct_resp = tree->predict(&sample, &missing, true)->value;\r
1866                         true_resp   = true_resp_ptr[i];\r
1867                         if( data->is_classifier )\r
1868                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;\r
1869                         else\r
1870                         {\r
1871                             true_resp = (true_resp - predct_resp)/maximal_response;\r
1872                             ncorrect_responses_permuted += exp( -true_resp*true_resp );\r
1873                         }\r
1874                     }\r
1875                     var_importance->data.fl[m] += (float)(ncorrect_responses\r
1876                         - ncorrect_responses_permuted);\r
1877                 }\r
1878             }\r
1879         }\r
1880         ntrees++;\r
1881         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )\r
1882             break;\r
1883     }\r
1884     if ( is_oob_or_vimportance )\r
1885     {\r
1886         for ( int vi = 0; vi < var_importance->cols; vi++ )\r
1887                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?\r
1888                     var_importance->data.fl[vi] : 0;\r
1889         if( var_importance )\r
1890             cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );\r
1891     }\r
1892 \r
1893     result = true;\r
1894     \r
1895     cvFree( &oob_samples_perm_ptr );\r
1896     cvFree( &samples_ptr );\r
1897     cvFree( &missing_ptr );\r
1898     cvFree( &true_resp_ptr );\r
1899     \r
1900     cvReleaseMat( &sample_idx_for_tree );\r
1901 \r
1902     cvReleaseMat( &oob_sample_votes );\r
1903     cvReleaseMat( &oob_responses );\r
1904 \r
1905     __END__;\r
1906 \r
1907     return result;\r
1908 }\r
1909 \r
1910 // End of file.\r
1911 \r