]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
fixed mll
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 //\r
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 //\r
5 //  By downloading, copying, installing or using the software you agree to this license.\r
6 //  If you do not agree to this license, do not download, install,\r
7 //  copy or use the software.\r
8 //\r
9 //\r
10 //                        Intel License Agreement\r
11 //\r
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.\r
13 // Third party copyrights are property of their respective owners.\r
14 //\r
15 // Redistribution and use in source and binary forms, with or without modification,\r
16 // are permitted provided that the following conditions are met:\r
17 //\r
18 //   * Redistribution's of source code must retain the above copyright notice,\r
19 //     this list of conditions and the following disclaimer.\r
20 //\r
21 //   * Redistribution's in binary form must reproduce the above copyright notice,\r
22 //     this list of conditions and the following disclaimer in the documentation\r
23 //     and/or other materials provided with the distribution.\r
24 //\r
25 //   * The name of Intel Corporation may not be used to endorse or promote products\r
26 //     derived from this software without specific prior written permission.\r
27 //\r
28 // This software is provided by the copyright holders and contributors "as is" and\r
29 // any express or implied warranties, including, but not limited to, the implied\r
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.\r
31 // In no event shall the Intel Corporation or contributors be liable for any direct,\r
32 // indirect, incidental, special, exemplary, or consequential damages\r
33 // (including, but not limited to, procurement of substitute goods or services;\r
34 // loss of use, data, or profits; or business interruption) however caused\r
35 // and on any theory of liability, whether in contract, strict liability,\r
36 // or tort (including negligence or otherwise) arising in any way out of\r
37 // the use of this software, even if advised of the possibility of such damage.\r
38 //\r
39 //M*/\r
40 \r
41 #include "_ml.h"\r
42 #include <ctype.h>\r
43 \r
44 static const float ord_nan = FLT_MAX*0.5f;\r
45 static const int min_block_size = 1 << 16;\r
46 static const int block_size_delta = 1 << 10;\r
47 \r
48 CvDTreeTrainData::CvDTreeTrainData()\r
49 {\r
50     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
51         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
52     tree_storage = temp_storage = 0;\r
53 \r
54     clear();\r
55 }\r
56 \r
57 \r
58 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,\r
59                       const CvMat* _responses, const CvMat* _var_idx,\r
60                       const CvMat* _sample_idx, const CvMat* _var_type,\r
61                       const CvMat* _missing_mask, const CvDTreeParams& _params,\r
62                       bool _shared, bool _add_labels )\r
63 {\r
64     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
65         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
66 \r
67     tree_storage = temp_storage = 0;\r
68 \r
69     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,\r
70               _var_type, _missing_mask, _params, _shared, _add_labels );\r
71 }\r
72 \r
73 \r
74 CvDTreeTrainData::~CvDTreeTrainData()\r
75 {\r
76     clear();\r
77 }\r
78 \r
79 \r
80 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )\r
81 {\r
82     bool ok = false;\r
83 \r
84     CV_FUNCNAME( "CvDTreeTrainData::set_params" );\r
85 \r
86     __BEGIN__;\r
87 \r
88     // set parameters\r
89     params = _params;\r
90 \r
91     if( params.max_categories < 2 )\r
92         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );\r
93     params.max_categories = MIN( params.max_categories, 15 );\r
94 \r
95     if( params.max_depth < 0 )\r
96         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );\r
97     params.max_depth = MIN( params.max_depth, 25 );\r
98 \r
99     params.min_sample_count = MAX(params.min_sample_count,1);\r
100 \r
101     if( params.cv_folds < 0 )\r
102         CV_ERROR( CV_StsOutOfRange,\r
103         "params.cv_folds should be =0 (the tree is not pruned) "\r
104         "or n>0 (tree is pruned using n-fold cross-validation)" );\r
105 \r
106     if( params.cv_folds == 1 )\r
107         params.cv_folds = 0;\r
108 \r
109     if( params.regression_accuracy < 0 )\r
110         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );\r
111 \r
112     ok = true;\r
113 \r
114     __END__;\r
115 \r
116     return ok;\r
117 }\r
118 \r
119 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))\r
120 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )\r
121 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )\r
122 \r
123 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])\r
124 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )\r
125 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )\r
126 \r
127 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))\r
128 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )\r
129 \r
130 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
131     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
132     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
133     bool _shared, bool _add_labels, bool _update_data )\r
134 {\r
135     CvMat* sample_indices = 0;\r
136     CvMat* var_type0 = 0;\r
137     CvMat* tmp_map = 0;\r
138     int** int_ptr = 0;\r
139     CvPair16u32s* pair16u32s_ptr = 0;\r
140     CvDTreeTrainData* data = 0;\r
141     float *_fdst = 0;\r
142     int *_idst = 0;\r
143     unsigned short* udst = 0;\r
144     int* idst = 0;\r
145 \r
146     CV_FUNCNAME( "CvDTreeTrainData::set_data" );\r
147 \r
148     __BEGIN__;\r
149 \r
150     int sample_all = 0, r_type = 0, cv_n;\r
151     int total_c_count = 0;\r
152     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
153     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step\r
154     int vi, i, size;\r
155     char err[100];\r
156     const int *sidx = 0, *vidx = 0;\r
157     \r
158     if( _update_data && data_root )\r
159     {\r
160         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,\r
161             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );\r
162 \r
163         // compare new and old train data\r
164         if( !(data->var_count == var_count &&\r
165             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
166             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&\r
167             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )\r
168             CV_ERROR( CV_StsBadArg,\r
169             "The new training data must have the same types and the input and output variables "\r
170             "and the same categories for categorical variables" );\r
171 \r
172         cvReleaseMat( &priors );\r
173         cvReleaseMat( &priors_mult );\r
174         cvReleaseMat( &buf );\r
175         cvReleaseMat( &direction );\r
176         cvReleaseMat( &split_buf );\r
177         cvReleaseMemStorage( &temp_storage );\r
178 \r
179         priors = data->priors; data->priors = 0;\r
180         priors_mult = data->priors_mult; data->priors_mult = 0;\r
181         buf = data->buf; data->buf = 0;\r
182         buf_count = data->buf_count; buf_size = data->buf_size;\r
183         sample_count = data->sample_count;\r
184 \r
185         direction = data->direction; data->direction = 0;\r
186         split_buf = data->split_buf; data->split_buf = 0;\r
187         temp_storage = data->temp_storage; data->temp_storage = 0;\r
188         nv_heap = data->nv_heap; cv_heap = data->cv_heap;\r
189 \r
190         data_root = new_node( 0, sample_count, 0, 0 );\r
191         EXIT;\r
192     }\r
193 \r
194     clear();\r
195 \r
196     var_all = 0;\r
197     rng = cvRNG(-1);\r
198 \r
199     CV_CALL( set_params( _params ));\r
200 \r
201     // check parameter types and sizes\r
202     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));\r
203 \r
204     train_data = _train_data;\r
205     responses = _responses;\r
206 \r
207     if( _tflag == CV_ROW_SAMPLE )\r
208     {\r
209         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
210         dv_step = 1;\r
211         if( _missing_mask )\r
212             ms_step = _missing_mask->step, mv_step = 1;\r
213     }\r
214     else\r
215     {\r
216         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
217         ds_step = 1;\r
218         if( _missing_mask )\r
219             mv_step = _missing_mask->step, ms_step = 1;\r
220     }\r
221     tflag = _tflag;\r
222 \r
223     sample_count = sample_all;\r
224     var_count = var_all;\r
225     \r
226     if( _sample_idx )\r
227     {\r
228         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));\r
229         sidx = sample_indices->data.i;\r
230         sample_count = sample_indices->rows + sample_indices->cols - 1;\r
231     }\r
232 \r
233     if( _var_idx )\r
234     {\r
235         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));\r
236         vidx = var_idx->data.i;\r
237         var_count = var_idx->rows + var_idx->cols - 1;\r
238     }\r
239 \r
240     is_buf_16u = false;     \r
241     if ( sample_count < 65536 ) \r
242         is_buf_16u = true;                                \r
243     \r
244     if( !CV_IS_MAT(_responses) ||\r
245         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
246          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
247         (_responses->rows != 1 && _responses->cols != 1) ||\r
248         _responses->rows + _responses->cols - 1 != sample_all )\r
249         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
250                   "floating-point vector containing as many elements as "\r
251                   "the total number of samples in the training data matrix" );\r
252    \r
253   \r
254     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));\r
255 \r
256     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));\r
257    \r
258     \r
259     cat_var_count = 0;\r
260     ord_var_count = -1;\r
261 \r
262     is_classifier = r_type == CV_VAR_CATEGORICAL;\r
263 \r
264     // step 0. calc the number of categorical vars\r
265     for( vi = 0; vi < var_count; vi++ )\r
266     {\r
267         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?\r
268             cat_var_count++ : ord_var_count--;\r
269     }\r
270 \r
271     ord_var_count = ~ord_var_count;\r
272     cv_n = params.cv_folds;\r
273     // set the two last elements of var_type array to be able\r
274     // to locate responses and cross-validation labels using\r
275     // the corresponding get_* functions.\r
276     var_type->data.i[var_count] = cat_var_count;\r
277     var_type->data.i[var_count+1] = cat_var_count+1;\r
278 \r
279     // in case of single ordered predictor we need dummy cv_labels\r
280     // for safe split_node_data() operation\r
281     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;\r
282 \r
283     work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);\r
284     buf_size = (work_var_count + 1)*sample_count;\r
285     shared = _shared;\r
286     buf_count = shared ? 2 : 1;\r
287     \r
288     if ( is_buf_16u )\r
289     {\r
290         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));\r
291         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));\r
292     }\r
293     else\r
294     {\r
295         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));\r
296         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));\r
297     }    \r
298 \r
299     size = is_classifier ? (cat_var_count+1) : cat_var_count;\r
300     size = !size ? 1 : size;\r
301     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));\r
302     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));\r
303         \r
304     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;\r
305     size = !size ? 1 : size;\r
306     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));\r
307 \r
308     // now calculate the maximum size of split,\r
309     // create memory storage that will keep nodes and splits of the decision tree\r
310     // allocate root node and the buffer for the whole training data\r
311     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
312         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));\r
313     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
314     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
315     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
316     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));\r
317 \r
318     nv_size = var_count*sizeof(int);\r
319     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));\r
320 \r
321     temp_block_size = nv_size;\r
322 \r
323     if( cv_n )\r
324     {\r
325         if( sample_count < cv_n*MAX(params.min_sample_count,10) )\r
326             CV_ERROR( CV_StsOutOfRange,\r
327                 "The many folds in cross-validation for such a small dataset" );\r
328 \r
329         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );\r
330         temp_block_size = MAX(temp_block_size, cv_size);\r
331     }\r
332 \r
333     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );\r
334     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));\r
335     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));\r
336     if( cv_size )\r
337         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));\r
338 \r
339     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));\r
340 \r
341     max_c_count = 1;\r
342 \r
343     _fdst = 0;\r
344     _idst = 0;\r
345     if (ord_var_count)\r
346         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));\r
347     if (is_buf_16u && (cat_var_count || is_classifier))\r
348         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));\r
349 \r
350     // transform the training data to convenient representation\r
351     for( vi = 0; vi <= var_count; vi++ )\r
352     {\r
353         int ci;\r
354         const uchar* mask = 0;\r
355         int m_step = 0, step;\r
356         const int* idata = 0;\r
357         const float* fdata = 0;\r
358         int num_valid = 0;\r
359 \r
360         if( vi < var_count ) // analyze i-th input variable\r
361         {\r
362             int vi0 = vidx ? vidx[vi] : vi;\r
363             ci = get_var_type(vi);\r
364             step = ds_step; m_step = ms_step;\r
365             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
366                 idata = _train_data->data.i + vi0*dv_step;\r
367             else\r
368                 fdata = _train_data->data.fl + vi0*dv_step;\r
369             if( _missing_mask )\r
370                 mask = _missing_mask->data.ptr + vi0*mv_step;\r
371         }\r
372         else // analyze _responses\r
373         {\r
374             ci = cat_var_count;\r
375             step = CV_IS_MAT_CONT(_responses->type) ?\r
376                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);\r
377             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )\r
378                 idata = _responses->data.i;\r
379             else\r
380                 fdata = _responses->data.fl;\r
381         }\r
382 \r
383         if( (vi < var_count && ci>=0) ||\r
384             (vi == var_count && is_classifier) ) // process categorical variable or response\r
385         {\r
386             int c_count, prev_label;\r
387             int* c_map;\r
388             \r
389             if (is_buf_16u)\r
390                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
391             else\r
392                 idst = buf->data.i + vi*sample_count;\r
393             \r
394             // copy data\r
395             for( i = 0; i < sample_count; i++ )\r
396             {\r
397                 int val = INT_MAX, si = sidx ? sidx[i] : i;\r
398                 if( !mask || !mask[si*m_step] )\r
399                 {\r
400                     if( idata )\r
401                         val = idata[si*step];\r
402                     else\r
403                     {\r
404                         float t = fdata[si*step];\r
405                         val = cvRound(t);\r
406                         if( val != t )\r
407                         {\r
408                             sprintf( err, "%d-th value of %d-th (categorical) "\r
409                                 "variable is not an integer", i, vi );\r
410                             CV_ERROR( CV_StsBadArg, err );\r
411                         }\r
412                     }\r
413 \r
414                     if( val == INT_MAX )\r
415                     {\r
416                         sprintf( err, "%d-th value of %d-th (categorical) "\r
417                             "variable is too large", i, vi );\r
418                         CV_ERROR( CV_StsBadArg, err );\r
419                     }\r
420                     num_valid++;\r
421                 }\r
422                 if (is_buf_16u)\r
423                 {\r
424                     _idst[i] = val;\r
425                     pair16u32s_ptr[i].u = udst + i;\r
426                     pair16u32s_ptr[i].i = _idst + i;\r
427                 }   \r
428                 else\r
429                 {\r
430                     idst[i] = val;\r
431                     int_ptr[i] = idst + i;\r
432                 }\r
433             }\r
434 \r
435             c_count = num_valid > 0;\r
436             if (is_buf_16u)\r
437             {\r
438                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );\r
439                 // count the categories\r
440                 for( i = 1; i < num_valid; i++ )\r
441                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)\r
442                         c_count ++ ;\r
443             }\r
444             else\r
445             {\r
446                 icvSortIntPtr( int_ptr, sample_count, 0 );\r
447                 // count the categories\r
448                 for( i = 1; i < num_valid; i++ )\r
449                     c_count += *int_ptr[i] != *int_ptr[i-1];\r
450             }\r
451 \r
452             if( vi > 0 )\r
453                 max_c_count = MAX( max_c_count, c_count );\r
454             cat_count->data.i[ci] = c_count;\r
455             cat_ofs->data.i[ci] = total_c_count;\r
456 \r
457             // resize cat_map, if need\r
458             if( cat_map->cols < total_c_count + c_count )\r
459             {\r
460                 tmp_map = cat_map;\r
461                 CV_CALL( cat_map = cvCreateMat( 1,\r
462                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));\r
463                 for( i = 0; i < total_c_count; i++ )\r
464                     cat_map->data.i[i] = tmp_map->data.i[i];\r
465                 cvReleaseMat( &tmp_map );\r
466             }\r
467 \r
468             c_map = cat_map->data.i + total_c_count;\r
469             total_c_count += c_count;\r
470 \r
471             c_count = -1;\r
472             if (is_buf_16u)\r
473             {\r
474                 // compact the class indices and build the map\r
475                 prev_label = ~*pair16u32s_ptr[0].i;\r
476                 for( i = 0; i < num_valid; i++ )\r
477                 {\r
478                     int cur_label = *pair16u32s_ptr[i].i;\r
479                     if( cur_label != prev_label )\r
480                         c_map[++c_count] = prev_label = cur_label;\r
481                     *pair16u32s_ptr[i].u = (unsigned short)c_count;\r
482                 }\r
483                 // replace labels for missing values with -1\r
484                 for( ; i < sample_count; i++ )\r
485                     *pair16u32s_ptr[i].u = 65535;\r
486             }\r
487             else\r
488             {\r
489                 // compact the class indices and build the map\r
490                 prev_label = ~*int_ptr[0];\r
491                 for( i = 0; i < num_valid; i++ )\r
492                 {\r
493                     int cur_label = *int_ptr[i];\r
494                     if( cur_label != prev_label )\r
495                         c_map[++c_count] = prev_label = cur_label;\r
496                     *int_ptr[i] = c_count;\r
497                 }\r
498                 // replace labels for missing values with -1\r
499                 for( ; i < sample_count; i++ )\r
500                     *int_ptr[i] = -1;\r
501             }           \r
502         }\r
503         else if( ci < 0 ) // process ordered variable\r
504         {\r
505             if (is_buf_16u)\r
506                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
507             else\r
508                 idst = buf->data.i + vi*sample_count;\r
509 \r
510             for( i = 0; i < sample_count; i++ )\r
511             {\r
512                 float val = ord_nan;\r
513                 int si = sidx ? sidx[i] : i;\r
514                 if( !mask || !mask[si*m_step] )\r
515                 {\r
516                     if( idata )\r
517                         val = (float)idata[si*step];\r
518                     else\r
519                         val = fdata[si*step];\r
520 \r
521                     if( fabs(val) >= ord_nan )\r
522                     {\r
523                         sprintf( err, "%d-th value of %d-th (ordered) "\r
524                             "variable (=%g) is too large", i, vi, val );\r
525                         CV_ERROR( CV_StsBadArg, err );\r
526                     }\r
527                 }\r
528                 num_valid++;\r
529                 if (is_buf_16u)\r
530                     udst[i] = (unsigned short)i;\r
531                 else\r
532                     idst[i] = i; // Ã¯Ã¥Ã°Ã¥Ã­Ã¥Ã±Ã²Ã¨ Ã¢Ã»Ã¸Ã¥ Ã¢ if( idata )\r
533                 _fdst[i] = val;\r
534                 \r
535             }\r
536             if (is_buf_16u)\r
537                 icvSortUShAux( udst, num_valid, _fdst);\r
538             else\r
539                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );\r
540         }\r
541        \r
542         if( vi < var_count )\r
543             data_root->set_num_valid(vi, num_valid);\r
544     }\r
545 \r
546     // set sample labels\r
547     if (is_buf_16u)\r
548         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);\r
549     else\r
550         idst = buf->data.i + work_var_count*sample_count;\r
551 \r
552     for (i = 0; i < sample_count; i++)\r
553     {\r
554         if (udst)\r
555             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;\r
556         else\r
557             idst[i] = sidx ? sidx[i] : i;\r
558     }\r
559 \r
560     if( cv_n )\r
561     {\r
562         unsigned short* udst = 0;\r
563         int* idst = 0;\r
564         CvRNG* r = &rng;\r
565 \r
566         if (is_buf_16u)\r
567         {\r
568             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);\r
569             for( i = vi = 0; i < sample_count; i++ )\r
570             {\r
571                 udst[i] = (unsigned short)vi++;\r
572                 vi &= vi < cv_n ? -1 : 0;\r
573             }\r
574 \r
575             for( i = 0; i < sample_count; i++ )\r
576             {\r
577                 int a = cvRandInt(r) % sample_count;\r
578                 int b = cvRandInt(r) % sample_count;\r
579                 unsigned short unsh = (unsigned short)vi;\r
580                 CV_SWAP( udst[a], udst[b], unsh );\r
581             }\r
582         }\r
583         else\r
584         {\r
585             idst = buf->data.i + (get_work_var_count()-1)*sample_count;\r
586             for( i = vi = 0; i < sample_count; i++ )\r
587             {\r
588                 idst[i] = vi++;\r
589                 vi &= vi < cv_n ? -1 : 0;\r
590             }\r
591 \r
592             for( i = 0; i < sample_count; i++ )\r
593             {\r
594                 int a = cvRandInt(r) % sample_count;\r
595                 int b = cvRandInt(r) % sample_count;\r
596                 CV_SWAP( idst[a], idst[b], vi );\r
597             }\r
598         }\r
599     }\r
600 \r
601     if ( cat_map ) \r
602         cat_map->cols = MAX( total_c_count, 1 );\r
603 \r
604     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
605         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
606     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));\r
607 \r
608     have_priors = is_classifier && params.priors;\r
609     if( is_classifier )\r
610     {\r
611         int m = get_num_classes();\r
612         double sum = 0;\r
613         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));\r
614         for( i = 0; i < m; i++ )\r
615         {\r
616             double val = have_priors ? params.priors[i] : 1.;\r
617             if( val <= 0 )\r
618                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );\r
619             priors->data.db[i] = val;\r
620             sum += val;\r
621         }\r
622 \r
623         // normalize weights\r
624         if( have_priors )\r
625             cvScale( priors, priors, 1./sum );\r
626 \r
627         CV_CALL( priors_mult = cvCloneMat( priors ));\r
628         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));\r
629     }\r
630 \r
631 \r
632     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));\r
633     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));\r
634 \r
635     {\r
636         int maxNumThreads = 1;\r
637 #ifdef _OPENMP\r
638         maxNumThreads = cv::getNumThreads();\r
639 #endif\r
640         pred_float_buf.resize(maxNumThreads);\r
641         pred_int_buf.resize(maxNumThreads);\r
642         resp_float_buf.resize(maxNumThreads);\r
643         resp_int_buf.resize(maxNumThreads);\r
644         cv_lables_buf.resize(maxNumThreads);\r
645         sample_idx_buf.resize(maxNumThreads);\r
646         for( int ti = 0; ti < maxNumThreads; ti++ )\r
647         {\r
648             pred_float_buf[ti].resize(sample_count);\r
649             pred_int_buf[ti].resize(sample_count);\r
650             resp_float_buf[ti].resize(sample_count);\r
651             resp_int_buf[ti].resize(sample_count);\r
652             cv_lables_buf[ti].resize(sample_count);\r
653             sample_idx_buf[ti].resize(sample_count);\r
654         }\r
655     }\r
656 \r
657     __END__;\r
658 \r
659     if( data )\r
660         delete data;\r
661 \r
662     if (_fdst)\r
663         cvFree( &_fdst );\r
664     if (_idst)\r
665         cvFree( &_idst );\r
666     cvFree( &int_ptr );\r
667     cvReleaseMat( &var_type0 );\r
668     cvReleaseMat( &sample_indices );\r
669     cvReleaseMat( &tmp_map );\r
670 }\r
671 \r
672 void CvDTreeTrainData::do_responses_copy()\r
673 {\r
674     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );\r
675     cvCopy( responses, responses_copy);\r
676     responses = responses_copy;\r
677 }\r
678 \r
679 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )\r
680 {\r
681     CvDTreeNode* root = 0;\r
682     CvMat* isubsample_idx = 0;\r
683     CvMat* subsample_co = 0;\r
684 \r
685     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );\r
686 \r
687     __BEGIN__;\r
688 \r
689     if( !data_root )\r
690         CV_ERROR( CV_StsError, "No training data has been set" );\r
691 \r
692     if( _subsample_idx )\r
693         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
694 \r
695     if( !isubsample_idx )\r
696     {\r
697         // make a copy of the root node\r
698         CvDTreeNode temp;\r
699         int i;\r
700         root = new_node( 0, 1, 0, 0 );\r
701         temp = *root;\r
702         *root = *data_root;\r
703         root->num_valid = temp.num_valid;\r
704         if( root->num_valid )\r
705         {\r
706             for( i = 0; i < var_count; i++ )\r
707                 root->num_valid[i] = data_root->num_valid[i];\r
708         }\r
709         root->cv_Tn = temp.cv_Tn;\r
710         root->cv_node_risk = temp.cv_node_risk;\r
711         root->cv_node_error = temp.cv_node_error;\r
712     }\r
713     else\r
714     {\r
715         int* sidx = isubsample_idx->data.i;\r
716         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)\r
717         int* co, cur_ofs = 0;\r
718         int vi, i;\r
719         int work_var_count = get_work_var_count();\r
720         int count = isubsample_idx->rows + isubsample_idx->cols - 1;\r
721 \r
722         root = new_node( 0, count, 1, 0 );\r
723 \r
724         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
725         cvZero( subsample_co );\r
726         co = subsample_co->data.i;\r
727         for( i = 0; i < count; i++ )\r
728             co[sidx[i]*2]++;\r
729         for( i = 0; i < sample_count; i++ )\r
730         {\r
731             if( co[i*2] )\r
732             {\r
733                 co[i*2+1] = cur_ofs;\r
734                 cur_ofs += co[i*2];\r
735             }\r
736             else\r
737                 co[i*2+1] = -1;\r
738         }\r
739 \r
740         for( vi = 0; vi < work_var_count; vi++ )\r
741         {\r
742             int ci = get_var_type(vi);\r
743 \r
744             if( ci >= 0 || vi >= var_count )\r
745             {\r
746                 int* src_buf = get_pred_int_buf();\r
747                 const int* src = 0;\r
748                 int num_valid = 0;\r
749                 \r
750                 get_cat_var_data( data_root, vi, src_buf, &src );\r
751 \r
752                 if (is_buf_16u)\r
753                 {\r
754                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
755                         vi*sample_count + root->offset);\r
756                     for( i = 0; i < count; i++ )\r
757                     {\r
758                         int val = src[sidx[i]];\r
759                         udst[i] = (unsigned short)val;\r
760                         num_valid += val >= 0;\r
761                     }\r
762                 }\r
763                 else\r
764                 {\r
765                     int* idst = buf->data.i + root->buf_idx*buf->cols + \r
766                         vi*sample_count + root->offset;\r
767                     for( i = 0; i < count; i++ )\r
768                     {\r
769                         int val = src[sidx[i]];\r
770                         idst[i] = val;\r
771                         num_valid += val >= 0;\r
772                     }\r
773                 }\r
774 \r
775                 if( vi < var_count )\r
776                     root->set_num_valid(vi, num_valid);\r
777             }\r
778             else\r
779             {\r
780                 int *src_idx_buf = get_pred_int_buf();\r
781                 const int* src_idx = 0;\r
782                 float *src_val_buf = get_pred_float_buf();\r
783                 const float* src_val = 0;\r
784                 int j = 0, idx, count_i;\r
785                 int num_valid = data_root->get_num_valid(vi);\r
786 \r
787                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );\r
788                 if (is_buf_16u)\r
789                 {\r
790                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
791                         vi*sample_count + data_root->offset);\r
792                     for( i = 0; i < num_valid; i++ )\r
793                     {\r
794                         idx = src_idx[i];\r
795                         count_i = co[idx*2];\r
796                         if( count_i )\r
797                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
798                                 udst_idx[j] = (unsigned short)cur_ofs;\r
799                     }\r
800 \r
801                     root->set_num_valid(vi, j);\r
802 \r
803                     for( ; i < sample_count; i++ )\r
804                     {\r
805                         idx = src_idx[i];\r
806                         count_i = co[idx*2];\r
807                         if( count_i )\r
808                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
809                                 udst_idx[j] = (unsigned short)cur_ofs;\r
810                     }\r
811                 }\r
812                 else\r
813                 {\r
814                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + \r
815                         vi*sample_count + root->offset;\r
816                     for( i = 0; i < num_valid; i++ )\r
817                     {\r
818                         idx = src_idx[i];\r
819                         count_i = co[idx*2];\r
820                         if( count_i )\r
821                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
822                                 idst_idx[j] = cur_ofs;\r
823                     }\r
824 \r
825                     root->set_num_valid(vi, j);\r
826 \r
827                     for( ; i < sample_count; i++ )\r
828                     {\r
829                         idx = src_idx[i];\r
830                         count_i = co[idx*2];\r
831                         if( count_i )\r
832                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
833                                 idst_idx[j] = cur_ofs;\r
834                     }\r
835                 }\r
836             }\r
837         }\r
838         // sample indices subsampling\r
839         int* sample_idx_src_buf = get_sample_idx_buf();\r
840         const int* sample_idx_src = 0;\r
841         get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);\r
842         if (is_buf_16u)\r
843         {\r
844             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
845                 get_work_var_count()*sample_count + root->offset);            \r
846             for (i = 0; i < count; i++)\r
847                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];\r
848         }\r
849         else\r
850         {\r
851             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + \r
852                 get_work_var_count()*sample_count + root->offset;            \r
853             for (i = 0; i < count; i++)\r
854                 sample_idx_dst[i] = sample_idx_src[sidx[i]];\r
855         }\r
856     }\r
857 \r
858     __END__;\r
859 \r
860     cvReleaseMat( &isubsample_idx );\r
861     cvReleaseMat( &subsample_co );\r
862 \r
863     return root;\r
864 }\r
865 \r
866 \r
867 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,\r
868                                     float* values, uchar* missing,\r
869                                     float* responses, bool get_class_idx )\r
870 {\r
871     CvMat* subsample_idx = 0;\r
872     CvMat* subsample_co = 0;\r
873 \r
874     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );\r
875 \r
876     __BEGIN__;\r
877 \r
878     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
879     int* sidx = 0;\r
880     int* co = 0;\r
881 \r
882     if( _subsample_idx )\r
883     {\r
884         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
885         sidx = subsample_idx->data.i;\r
886         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
887         co = subsample_co->data.i;\r
888         cvZero( subsample_co );\r
889         count = subsample_idx->cols + subsample_idx->rows - 1;\r
890         for( i = 0; i < count; i++ )\r
891             co[sidx[i]*2]++;\r
892         for( i = 0; i < total; i++ )\r
893         {\r
894             int count_i = co[i*2];\r
895             if( count_i )\r
896             {\r
897                 co[i*2+1] = cur_ofs*var_count;\r
898                 cur_ofs += count_i;\r
899             }\r
900         }\r
901     }\r
902 \r
903     if( missing )\r
904         memset( missing, 1, count*var_count );\r
905 \r
906     for( vi = 0; vi < var_count; vi++ )\r
907     {\r
908         int ci = get_var_type(vi);\r
909         if( ci >= 0 ) // categorical\r
910         {\r
911             float* dst = values + vi;\r
912             uchar* m = missing ? missing + vi : 0;\r
913             int* src_buf = get_pred_int_buf();\r
914             const int* src = 0; \r
915             get_cat_var_data(data_root, vi, src_buf, &src);\r
916 \r
917             for( i = 0; i < count; i++, dst += var_count )\r
918             {\r
919                 int idx = sidx ? sidx[i] : i;\r
920                 int val = src[idx];\r
921                 *dst = (float)val;\r
922                 if( m )\r
923                 {\r
924                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));\r
925                     m += var_count;\r
926                 }\r
927             }\r
928         }\r
929         else // ordered\r
930         {\r
931             float* dst = values + vi;\r
932             uchar* m = missing ? missing + vi : 0;\r
933             int count1 = data_root->get_num_valid(vi);\r
934             float *src_val_buf = get_pred_float_buf();\r
935             const float *src_val = 0;\r
936             int* src_idx_buf = get_pred_int_buf();\r
937             const int* src_idx = 0;\r
938             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
939 \r
940             for( i = 0; i < count1; i++ )\r
941             {\r
942                 int idx = src_idx[i];\r
943                 int count_i = 1;\r
944                 if( co )\r
945                 {\r
946                     count_i = co[idx*2];\r
947                     cur_ofs = co[idx*2+1];\r
948                 }\r
949                 else\r
950                     cur_ofs = idx*var_count;\r
951                 if( count_i )\r
952                 {\r
953                     float val = src_val[i];\r
954                     for( ; count_i > 0; count_i--, cur_ofs += var_count )\r
955                     {\r
956                         dst[cur_ofs] = val;\r
957                         if( m )\r
958                             m[cur_ofs] = 0;\r
959                     }\r
960                 }\r
961             }\r
962         }\r
963     }\r
964 \r
965     // copy responses\r
966     if( responses )\r
967     {\r
968         if( is_classifier )\r
969         {\r
970             int* src_buf = get_resp_int_buf();\r
971             const int* src = 0;\r
972             get_class_labels(data_root, src_buf, &src);\r
973             for( i = 0; i < count; i++ )\r
974             {\r
975                 int idx = sidx ? sidx[i] : i;\r
976                 int val = get_class_idx ? src[idx] :\r
977                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];\r
978                 responses[i] = (float)val;\r
979             }\r
980         }\r
981         else\r
982         {\r
983             float *_values_buf = get_resp_float_buf();\r
984             const float* _values = 0;\r
985             get_ord_responses(data_root, _values_buf, &_values);\r
986             for( i = 0; i < count; i++ )\r
987             {\r
988                 int idx = sidx ? sidx[i] : i;\r
989                 responses[i] = _values[idx];\r
990             }\r
991         }\r
992     }\r
993 \r
994     __END__;\r
995 \r
996     cvReleaseMat( &subsample_idx );\r
997     cvReleaseMat( &subsample_co );\r
998 }\r
999 \r
1000 \r
1001 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,\r
1002                                          int storage_idx, int offset )\r
1003 {\r
1004     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );\r
1005 \r
1006     node->sample_count = count;\r
1007     node->depth = parent ? parent->depth + 1 : 0;\r
1008     node->parent = parent;\r
1009     node->left = node->right = 0;\r
1010     node->split = 0;\r
1011     node->value = 0;\r
1012     node->class_idx = 0;\r
1013     node->maxlr = 0.;\r
1014 \r
1015     node->buf_idx = storage_idx;\r
1016     node->offset = offset;\r
1017     if( nv_heap )\r
1018         node->num_valid = (int*)cvSetNew( nv_heap );\r
1019     else\r
1020         node->num_valid = 0;\r
1021     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;\r
1022     node->complexity = 0;\r
1023 \r
1024     if( params.cv_folds > 0 && cv_heap )\r
1025     {\r
1026         int cv_n = params.cv_folds;\r
1027         node->Tn = INT_MAX;\r
1028         node->cv_Tn = (int*)cvSetNew( cv_heap );\r
1029         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));\r
1030         node->cv_node_error = node->cv_node_risk + cv_n;\r
1031     }\r
1032     else\r
1033     {\r
1034         node->Tn = 0;\r
1035         node->cv_Tn = 0;\r
1036         node->cv_node_risk = 0;\r
1037         node->cv_node_error = 0;\r
1038     }\r
1039 \r
1040     return node;\r
1041 }\r
1042 \r
1043 \r
1044 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,\r
1045                 int split_point, int inversed, float quality )\r
1046 {\r
1047     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1048     split->var_idx = vi;\r
1049     split->condensed_idx = INT_MIN;\r
1050     split->ord.c = cmp_val;\r
1051     split->ord.split_point = split_point;\r
1052     split->inversed = inversed;\r
1053     split->quality = quality;\r
1054     split->next = 0;\r
1055 \r
1056     return split;\r
1057 }\r
1058 \r
1059 \r
1060 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )\r
1061 {\r
1062     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1063     int i, n = (max_c_count + 31)/32;\r
1064 \r
1065     split->var_idx = vi;\r
1066     split->condensed_idx = INT_MIN;\r
1067     split->inversed = 0;\r
1068     split->quality = quality;\r
1069     for( i = 0; i < n; i++ )\r
1070         split->subset[i] = 0;\r
1071     split->next = 0;\r
1072 \r
1073     return split;\r
1074 }\r
1075 \r
1076 \r
1077 void CvDTreeTrainData::free_node( CvDTreeNode* node )\r
1078 {\r
1079     CvDTreeSplit* split = node->split;\r
1080     free_node_data( node );\r
1081     while( split )\r
1082     {\r
1083         CvDTreeSplit* next = split->next;\r
1084         cvSetRemoveByPtr( split_heap, split );\r
1085         split = next;\r
1086     }\r
1087     node->split = 0;\r
1088     cvSetRemoveByPtr( node_heap, node );\r
1089 }\r
1090 \r
1091 \r
1092 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )\r
1093 {\r
1094     if( node->num_valid )\r
1095     {\r
1096         cvSetRemoveByPtr( nv_heap, node->num_valid );\r
1097         node->num_valid = 0;\r
1098     }\r
1099     // do not free cv_* fields, as all the cross-validation related data is released at once.\r
1100 }\r
1101 \r
1102 \r
1103 void CvDTreeTrainData::free_train_data()\r
1104 {\r
1105     cvReleaseMat( &counts );\r
1106     cvReleaseMat( &buf );\r
1107     cvReleaseMat( &direction );\r
1108     cvReleaseMat( &split_buf );\r
1109     cvReleaseMemStorage( &temp_storage );\r
1110     cvReleaseMat( &responses_copy );\r
1111     pred_float_buf.clear();\r
1112     pred_int_buf.clear();\r
1113     resp_float_buf.clear();\r
1114     resp_int_buf.clear();\r
1115     cv_lables_buf.clear();\r
1116     sample_idx_buf.clear();\r
1117 \r
1118     cv_heap = nv_heap = 0;\r
1119 }\r
1120 \r
1121 \r
1122 void CvDTreeTrainData::clear()\r
1123 {\r
1124     free_train_data();\r
1125 \r
1126     cvReleaseMemStorage( &tree_storage );\r
1127 \r
1128     cvReleaseMat( &var_idx );\r
1129     cvReleaseMat( &var_type );\r
1130     cvReleaseMat( &cat_count );\r
1131     cvReleaseMat( &cat_ofs );\r
1132     cvReleaseMat( &cat_map );\r
1133     cvReleaseMat( &priors );\r
1134     cvReleaseMat( &priors_mult );\r
1135     \r
1136     node_heap = split_heap = 0;\r
1137 \r
1138     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;\r
1139     have_labels = have_priors = is_classifier = false;\r
1140 \r
1141     buf_count = buf_size = 0;\r
1142     shared = false;\r
1143     \r
1144     data_root = 0;\r
1145 \r
1146     rng = cvRNG(-1);\r
1147 }\r
1148 \r
1149 \r
1150 int CvDTreeTrainData::get_num_classes() const\r
1151 {\r
1152     return is_classifier ? cat_count->data.i[cat_var_count] : 0;\r
1153 }\r
1154 \r
1155 \r
1156 int CvDTreeTrainData::get_var_type(int vi) const\r
1157 {\r
1158     return var_type->data.i[vi];\r
1159 }\r
1160 \r
1161 int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )\r
1162 {\r
1163     int vidx = var_idx ? var_idx->data.i[vi] : vi;\r
1164     int node_sample_count = n->sample_count; \r
1165     int* sample_indices_buf = get_sample_idx_buf();\r
1166     const int* sample_indices = 0;\r
1167     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);\r
1168 \r
1169     get_sample_indices(n, sample_indices_buf, &sample_indices);\r
1170 \r
1171     if( !is_buf_16u )\r
1172         *indices = buf->data.i + n->buf_idx*buf->cols + \r
1173         vi*sample_count + n->offset;\r
1174     else {\r
1175         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1176             vi*sample_count + n->offset );\r
1177         for( int i = 0; i < node_sample_count; i++ )\r
1178             indices_buf[i] = short_indices[i];\r
1179         *indices = indices_buf;\r
1180     }\r
1181     \r
1182     if( tflag == CV_ROW_SAMPLE )\r
1183     {\r
1184         for( int i = 0; i < node_sample_count && \r
1185             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1186         {\r
1187             int idx = (*indices)[i];\r
1188             idx = sample_indices[idx];\r
1189             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);\r
1190         }\r
1191     }\r
1192     else\r
1193         for( int i = 0; i < node_sample_count && \r
1194             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1195         {\r
1196             int idx = (*indices)[i];\r
1197             idx = sample_indices[idx];\r
1198             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);\r
1199         }\r
1200     \r
1201     *ord_values = ord_values_buf;\r
1202     return 0; //TODO: return the number of non-missing values\r
1203 }\r
1204 \r
1205 \r
1206 void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1207 {\r
1208     if (is_classifier)\r
1209         get_cat_var_data( n, var_count, labels_buf, labels );\r
1210 }\r
1211 \r
1212 void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )\r
1213 {\r
1214     get_cat_var_data( n, get_work_var_count(), indices_buf, indices );\r
1215 }\r
1216 \r
1217 void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)\r
1218 {\r
1219     int sample_count = n->sample_count;\r
1220     int* indices_buf = get_sample_idx_buf();\r
1221     const int* indices = 0;\r
1222 \r
1223     int r_step = responses->step/CV_ELEM_SIZE(responses->type);\r
1224 \r
1225     get_sample_indices(n, indices_buf, &indices);\r
1226 \r
1227     \r
1228     for( int i = 0; i < sample_count && \r
1229         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )\r
1230     {\r
1231         int idx = indices[i];\r
1232         values_buf[i] = *(responses->data.fl + idx * r_step);\r
1233     }\r
1234     \r
1235     *values = values_buf;    \r
1236 }\r
1237 \r
1238 \r
1239 void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1240 {\r
1241     if (have_labels)\r
1242         get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );\r
1243 }\r
1244 \r
1245 \r
1246 int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )\r
1247 {\r
1248     if( !is_buf_16u )\r
1249         *cat_values = buf->data.i + n->buf_idx*buf->cols + \r
1250         vi*sample_count + n->offset;\r
1251     else {\r
1252         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1253             vi*sample_count + n->offset);\r
1254         for( int i = 0; i < n->sample_count; i++ )\r
1255             cat_values_buf[i] = short_values[i];\r
1256         *cat_values = cat_values_buf;\r
1257     }\r
1258 \r
1259     return 0; //TODO: return the number of non-missing values\r
1260 }\r
1261 \r
1262 \r
1263 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )\r
1264 {\r
1265     int idx = n->buf_idx + 1;\r
1266     if( idx >= buf_count )\r
1267         idx = shared ? 1 : 0;\r
1268     return idx;\r
1269 }\r
1270 \r
1271 \r
1272 void CvDTreeTrainData::write_params( CvFileStorage* fs )\r
1273 {\r
1274     CV_FUNCNAME( "CvDTreeTrainData::write_params" );\r
1275 \r
1276     __BEGIN__;\r
1277 \r
1278     int vi, vcount = var_count;\r
1279 \r
1280     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );\r
1281     cvWriteInt( fs, "var_all", var_all );\r
1282     cvWriteInt( fs, "var_count", var_count );\r
1283     cvWriteInt( fs, "ord_var_count", ord_var_count );\r
1284     cvWriteInt( fs, "cat_var_count", cat_var_count );\r
1285 \r
1286     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );\r
1287     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );\r
1288 \r
1289     if( is_classifier )\r
1290     {\r
1291         cvWriteInt( fs, "max_categories", params.max_categories );\r
1292     }\r
1293     else\r
1294     {\r
1295         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );\r
1296     }\r
1297 \r
1298     cvWriteInt( fs, "max_depth", params.max_depth );\r
1299     cvWriteInt( fs, "min_sample_count", params.min_sample_count );\r
1300     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );\r
1301 \r
1302     if( params.cv_folds > 1 )\r
1303     {\r
1304         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );\r
1305         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );\r
1306     }\r
1307 \r
1308     if( priors )\r
1309         cvWrite( fs, "priors", priors );\r
1310 \r
1311     cvEndWriteStruct( fs );\r
1312 \r
1313     if( var_idx )\r
1314         cvWrite( fs, "var_idx", var_idx );\r
1315 \r
1316     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );\r
1317 \r
1318     for( vi = 0; vi < vcount; vi++ )\r
1319         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );\r
1320 \r
1321     cvEndWriteStruct( fs );\r
1322 \r
1323     if( cat_count && (cat_var_count > 0 || is_classifier) )\r
1324     {\r
1325         CV_ASSERT( cat_count != 0 );\r
1326         cvWrite( fs, "cat_count", cat_count );\r
1327         cvWrite( fs, "cat_map", cat_map );\r
1328     }\r
1329 \r
1330     __END__;\r
1331 }\r
1332 \r
1333 \r
1334 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )\r
1335 {\r
1336     CV_FUNCNAME( "CvDTreeTrainData::read_params" );\r
1337 \r
1338     __BEGIN__;\r
1339 \r
1340     CvFileNode *tparams_node, *vartype_node;\r
1341     CvSeqReader reader;\r
1342     int vi, max_split_size, tree_block_size;\r
1343 \r
1344     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);\r
1345     var_all = cvReadIntByName( fs, node, "var_all" );\r
1346     var_count = cvReadIntByName( fs, node, "var_count", var_all );\r
1347     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );\r
1348     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );\r
1349 \r
1350     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );\r
1351 \r
1352     if( tparams_node ) // training parameters are not necessary\r
1353     {\r
1354         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;\r
1355 \r
1356         if( is_classifier )\r
1357         {\r
1358             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );\r
1359         }\r
1360         else\r
1361         {\r
1362             params.regression_accuracy =\r
1363                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );\r
1364         }\r
1365 \r
1366         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );\r
1367         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );\r
1368         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );\r
1369 \r
1370         if( params.cv_folds > 1 )\r
1371         {\r
1372             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;\r
1373             params.truncate_pruned_tree =\r
1374                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;\r
1375         }\r
1376 \r
1377         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );\r
1378         if( priors )\r
1379         {\r
1380             if( !CV_IS_MAT(priors) )\r
1381                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );\r
1382             priors_mult = cvCloneMat( priors );\r
1383         }\r
1384     }\r
1385 \r
1386     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));\r
1387     if( var_idx )\r
1388     {\r
1389         if( !CV_IS_MAT(var_idx) ||\r
1390             (var_idx->cols != 1 && var_idx->rows != 1) ||\r
1391             var_idx->cols + var_idx->rows - 1 != var_count ||\r
1392             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )\r
1393             CV_ERROR( CV_StsParseError,\r
1394                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );\r
1395 \r
1396         for( vi = 0; vi < var_count; vi++ )\r
1397             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )\r
1398                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );\r
1399     }\r
1400 \r
1401     ////// read var type\r
1402     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));\r
1403 \r
1404     cat_var_count = 0;\r
1405     ord_var_count = -1;\r
1406     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );\r
1407 \r
1408     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )\r
1409         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;\r
1410     else\r
1411     {\r
1412         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||\r
1413             vartype_node->data.seq->total != var_count )\r
1414             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1415 \r
1416         cvStartReadSeq( vartype_node->data.seq, &reader );\r
1417 \r
1418         for( vi = 0; vi < var_count; vi++ )\r
1419         {\r
1420             CvFileNode* n = (CvFileNode*)reader.ptr;\r
1421             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )\r
1422                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1423             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;\r
1424             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
1425         }\r
1426     }\r
1427     var_type->data.i[var_count] = cat_var_count;\r
1428 \r
1429     ord_var_count = ~ord_var_count;\r
1430     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )\r
1431         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );\r
1432     //////\r
1433 \r
1434     if( cat_var_count > 0 || is_classifier )\r
1435     {\r
1436         int ccount, total_c_count = 0;\r
1437         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));\r
1438         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));\r
1439 \r
1440         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||\r
1441             (cat_count->cols != 1 && cat_count->rows != 1) ||\r
1442             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||\r
1443             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||\r
1444             (cat_map->cols != 1 && cat_map->rows != 1) ||\r
1445             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )\r
1446             CV_ERROR( CV_StsParseError,\r
1447             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );\r
1448 \r
1449         ccount = cat_var_count + is_classifier;\r
1450 \r
1451         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));\r
1452         cat_ofs->data.i[0] = 0;\r
1453         max_c_count = 1;\r
1454 \r
1455         for( vi = 0; vi < ccount; vi++ )\r
1456         {\r
1457             int val = cat_count->data.i[vi];\r
1458             if( val <= 0 )\r
1459                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );\r
1460             max_c_count = MAX( max_c_count, val );\r
1461             cat_ofs->data.i[vi+1] = total_c_count += val;\r
1462         }\r
1463 \r
1464         if( cat_map->cols + cat_map->rows - 1 != total_c_count )\r
1465             CV_ERROR( CV_StsBadSize,\r
1466             "cat_map vector length is not equal to the total number of categories in all categorical vars" );\r
1467     }\r
1468 \r
1469     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
1470         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
1471 \r
1472     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
1473     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
1474     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
1475     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),\r
1476             sizeof(CvDTreeNode), tree_storage ));\r
1477     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),\r
1478             max_split_size, tree_storage ));\r
1479 \r
1480     __END__;\r
1481 }\r
1482 \r
1483 float* CvDTreeTrainData::get_pred_float_buf()\r
1484 {\r
1485     return &pred_float_buf[cv::getThreadNum()][0];\r
1486 }\r
1487 int* CvDTreeTrainData::get_pred_int_buf()\r
1488 {\r
1489     return &pred_int_buf[cv::getThreadNum()][0];\r
1490 }\r
1491 float* CvDTreeTrainData::get_resp_float_buf()\r
1492 {\r
1493     return &resp_float_buf[cv::getThreadNum()][0];\r
1494 }\r
1495 int* CvDTreeTrainData::get_resp_int_buf()\r
1496 {\r
1497     return &resp_int_buf[cv::getThreadNum()][0];\r
1498 }\r
1499 int* CvDTreeTrainData::get_cv_lables_buf()\r
1500 {\r
1501     return &cv_lables_buf[cv::getThreadNum()][0];\r
1502 }\r
1503 int* CvDTreeTrainData::get_sample_idx_buf()\r
1504 {\r
1505     return &sample_idx_buf[cv::getThreadNum()][0];\r
1506 }\r
1507 \r
1508 /////////////////////// Decision Tree /////////////////////////\r
1509 \r
1510 CvDTree::CvDTree()\r
1511 {\r
1512     data = 0;\r
1513     var_importance = 0;\r
1514     default_model_name = "my_tree";\r
1515 \r
1516     clear();\r
1517 }\r
1518 \r
1519 \r
1520 void CvDTree::clear()\r
1521 {\r
1522     cvReleaseMat( &var_importance );\r
1523     if( data )\r
1524     {\r
1525         if( !data->shared )\r
1526             delete data;\r
1527         else\r
1528             free_tree();\r
1529         data = 0;\r
1530     }\r
1531     root = 0;\r
1532     pruned_tree_idx = -1;\r
1533 }\r
1534 \r
1535 \r
1536 CvDTree::~CvDTree()\r
1537 {\r
1538     clear();\r
1539 }\r
1540 \r
1541 \r
1542 const CvDTreeNode* CvDTree::get_root() const\r
1543 {\r
1544     return root;\r
1545 }\r
1546 \r
1547 \r
1548 int CvDTree::get_pruned_tree_idx() const\r
1549 {\r
1550     return pruned_tree_idx;\r
1551 }\r
1552 \r
1553 \r
1554 CvDTreeTrainData* CvDTree::get_data()\r
1555 {\r
1556     return data;\r
1557 }\r
1558 \r
1559 \r
1560 bool CvDTree::train( const CvMat* _train_data, int _tflag,\r
1561                      const CvMat* _responses, const CvMat* _var_idx,\r
1562                      const CvMat* _sample_idx, const CvMat* _var_type,\r
1563                      const CvMat* _missing_mask, CvDTreeParams _params )\r
1564 {\r
1565     bool result = false;\r
1566 \r
1567     CV_FUNCNAME( "CvDTree::train" );\r
1568 \r
1569     __BEGIN__;\r
1570 \r
1571     clear();\r
1572     data = new CvDTreeTrainData( _train_data, _tflag, _responses,\r
1573                                  _var_idx, _sample_idx, _var_type,\r
1574                                  _missing_mask, _params, false );\r
1575     CV_CALL( result = do_train(0) );\r
1576 \r
1577     __END__;\r
1578 \r
1579     return result;\r
1580 }\r
1581 \r
1582 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )\r
1583 {\r
1584    bool result = false;\r
1585 \r
1586     CV_FUNCNAME( "CvDTree::train" );\r
1587 \r
1588     __BEGIN__;\r
1589 \r
1590     const CvMat* values = _data->get_values();\r
1591     const CvMat* response = _data->get_response();\r
1592     const CvMat* missing = _data->get_missing();\r
1593     const CvMat* var_types = _data->get_var_types();\r
1594     const CvMat* train_sidx = _data->get_train_sample_idx();\r
1595     const CvMat* var_idx = _data->get_var_idx();\r
1596 \r
1597     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,\r
1598         train_sidx, var_types, missing, _params ) );\r
1599 \r
1600     __END__;\r
1601 \r
1602     return result;\r
1603 }\r
1604 \r
1605 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )\r
1606 {\r
1607     bool result = false;\r
1608 \r
1609     CV_FUNCNAME( "CvDTree::train" );\r
1610 \r
1611     __BEGIN__;\r
1612 \r
1613     clear();\r
1614     data = _data;\r
1615     data->shared = true;\r
1616     CV_CALL( result = do_train(_subsample_idx));\r
1617 \r
1618     __END__;\r
1619 \r
1620     return result;\r
1621 }\r
1622 \r
1623 \r
1624 bool CvDTree::do_train( const CvMat* _subsample_idx )\r
1625 {\r
1626     bool result = false;\r
1627 \r
1628     CV_FUNCNAME( "CvDTree::do_train" );\r
1629 \r
1630     __BEGIN__;\r
1631 \r
1632     root = data->subsample_data( _subsample_idx );\r
1633 \r
1634     CV_CALL( try_split_node(root));\r
1635 \r
1636     if( data->params.cv_folds > 0 )\r
1637         CV_CALL( prune_cv());\r
1638 \r
1639     if( !data->shared )\r
1640         data->free_train_data();\r
1641 \r
1642     result = true;\r
1643 \r
1644     __END__;\r
1645 \r
1646     return result;\r
1647 }\r
1648 \r
1649 \r
1650 void CvDTree::try_split_node( CvDTreeNode* node )\r
1651 {\r
1652     CvDTreeSplit* best_split = 0;\r
1653     int i, n = node->sample_count, vi;\r
1654     bool can_split = true;\r
1655     double quality_scale;\r
1656 \r
1657     calc_node_value( node );\r
1658 \r
1659     if( node->sample_count <= data->params.min_sample_count ||\r
1660         node->depth >= data->params.max_depth )\r
1661         can_split = false;\r
1662 \r
1663     if( can_split && data->is_classifier )\r
1664     {\r
1665         // check if we have a "pure" node,\r
1666         // we assume that cls_count is filled by calc_node_value()\r
1667         int* cls_count = data->counts->data.i;\r
1668         int nz = 0, m = data->get_num_classes();\r
1669         for( i = 0; i < m; i++ )\r
1670             nz += cls_count[i] != 0;\r
1671         if( nz == 1 ) // there is only one class\r
1672             can_split = false;\r
1673     }\r
1674     else if( can_split )\r
1675     {\r
1676         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )\r
1677             can_split = false;\r
1678     }\r
1679 \r
1680     if( can_split )\r
1681     {\r
1682         best_split = find_best_split(node);\r
1683         // TODO: check the split quality ...\r
1684         node->split = best_split;\r
1685     }\r
1686     if( !can_split || !best_split )\r
1687     {\r
1688         data->free_node_data(node);\r
1689         return;\r
1690     }\r
1691 \r
1692     quality_scale = calc_node_dir( node );\r
1693     if( data->params.use_surrogates )\r
1694     {\r
1695         // find all the surrogate splits\r
1696         // and sort them by their similarity to the primary one\r
1697         for( vi = 0; vi < data->var_count; vi++ )\r
1698         {\r
1699             CvDTreeSplit* split;\r
1700             int ci = data->get_var_type(vi);\r
1701 \r
1702             if( vi == best_split->var_idx )\r
1703                 continue;\r
1704 \r
1705             if( ci >= 0 )\r
1706                 split = find_surrogate_split_cat( node, vi );\r
1707             else\r
1708                 split = find_surrogate_split_ord( node, vi );\r
1709 \r
1710             if( split )\r
1711             {\r
1712                 // insert the split\r
1713                 CvDTreeSplit* prev_split = node->split;\r
1714                 split->quality = (float)(split->quality*quality_scale);\r
1715 \r
1716                 while( prev_split->next &&\r
1717                        prev_split->next->quality > split->quality )\r
1718                     prev_split = prev_split->next;\r
1719                 split->next = prev_split->next;\r
1720                 prev_split->next = split;\r
1721             }\r
1722         }\r
1723     }\r
1724     split_node_data( node );\r
1725     try_split_node( node->left );\r
1726     try_split_node( node->right );\r
1727 }\r
1728 \r
1729 \r
1730 // calculate direction (left(-1),right(1),missing(0))\r
1731 // for each sample using the best split\r
1732 // the function returns scale coefficients for surrogate split quality factors.\r
1733 // the scale is applied to normalize surrogate split quality relatively to the\r
1734 // best (primary) split quality. That is, if a surrogate split is absolutely\r
1735 // identical to the primary split, its quality will be set to the maximum value =\r
1736 // quality of the primary split; otherwise, it will be lower.\r
1737 // besides, the function compute node->maxlr,\r
1738 // minimum possible quality (w/o considering the above mentioned scale)\r
1739 // for a surrogate split. Surrogate splits with quality less than node->maxlr\r
1740 // are not discarded.\r
1741 double CvDTree::calc_node_dir( CvDTreeNode* node )\r
1742 {\r
1743     char* dir = (char*)data->direction->data.ptr;\r
1744     int i, n = node->sample_count, vi = node->split->var_idx;\r
1745     double L, R;\r
1746 \r
1747     assert( !node->split->inversed );\r
1748 \r
1749     if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
1750     {\r
1751         int* labels_buf = data->get_pred_int_buf();\r
1752         const int* labels = 0;\r
1753         const int* subset = node->split->subset;\r
1754         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1755         if( !data->have_priors )\r
1756         {\r
1757             int sum = 0, sum_abs = 0;\r
1758 \r
1759             for( i = 0; i < n; i++ )\r
1760             {\r
1761                 int idx = labels[i];\r
1762                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?\r
1763                     CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1764                 sum += d; sum_abs += d & 1;\r
1765                 dir[i] = (char)d;\r
1766             }\r
1767 \r
1768             R = (sum_abs + sum) >> 1;\r
1769             L = (sum_abs - sum) >> 1;\r
1770         }\r
1771         else\r
1772         {\r
1773             const double* priors = data->priors_mult->data.db;\r
1774             double sum = 0, sum_abs = 0;\r
1775             int *responses_buf = data->get_resp_int_buf();\r
1776             const int* responses;\r
1777             data->get_class_labels(node, responses_buf, &responses);\r
1778 \r
1779             for( i = 0; i < n; i++ )\r
1780             {\r
1781                 int idx = labels[i];\r
1782                 double w = priors[responses[i]];\r
1783                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1784                 sum += d*w; sum_abs += (d & 1)*w;\r
1785                 dir[i] = (char)d;\r
1786             }\r
1787 \r
1788             R = (sum_abs + sum) * 0.5;\r
1789             L = (sum_abs - sum) * 0.5;\r
1790         }\r
1791     }\r
1792     else // split on ordered var\r
1793     {\r
1794         int split_point = node->split->ord.split_point;\r
1795         int n1 = node->get_num_valid(vi);\r
1796         \r
1797         float* val_buf = data->get_pred_float_buf();\r
1798         const float* val = 0;\r
1799         int* sorted_buf = data->get_pred_int_buf();\r
1800         const int* sorted = 0;\r
1801         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);\r
1802         \r
1803         assert( 0 <= split_point && split_point < n1-1 );\r
1804 \r
1805         if( !data->have_priors )\r
1806         {\r
1807             for( i = 0; i <= split_point; i++ )\r
1808                 dir[sorted[i]] = (char)-1;\r
1809             for( ; i < n1; i++ )\r
1810                 dir[sorted[i]] = (char)1;\r
1811             for( ; i < n; i++ )\r
1812                 dir[sorted[i]] = (char)0;\r
1813 \r
1814             L = split_point-1;\r
1815             R = n1 - split_point + 1;\r
1816         }\r
1817         else\r
1818         {\r
1819             const double* priors = data->priors_mult->data.db;\r
1820             int* responses_buf = data->get_resp_int_buf();\r
1821             const int* responses = 0;\r
1822             data->get_class_labels(node, responses_buf, &responses);\r
1823             L = R = 0;\r
1824 \r
1825             for( i = 0; i <= split_point; i++ )\r
1826             {\r
1827                 int idx = sorted[i];\r
1828                 double w = priors[responses[idx]];\r
1829                 dir[idx] = (char)-1;\r
1830                 L += w;\r
1831             }\r
1832 \r
1833             for( ; i < n1; i++ )\r
1834             {\r
1835                 int idx = sorted[i];\r
1836                 double w = priors[responses[idx]];\r
1837                 dir[idx] = (char)1;\r
1838                 R += w;\r
1839             }\r
1840 \r
1841             for( ; i < n; i++ )\r
1842                 dir[sorted[i]] = (char)0;\r
1843         }\r
1844     }\r
1845     node->maxlr = MAX( L, R );\r
1846     return node->split->quality/(L + R);\r
1847 }\r
1848 \r
1849 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )\r
1850 {\r
1851     int vi;\r
1852     CvDTreeSplit *bestSplit = 0;\r
1853     int maxNumThreads = 1;\r
1854 #ifdef _OPENMP\r
1855     maxNumThreads = cv::getNumThreads();\r
1856 #endif\r
1857     vector<CvDTreeSplit*> splits(maxNumThreads);\r
1858     vector<CvDTreeSplit*> bestSplits(maxNumThreads);\r
1859     for (int i = 0; i < maxNumThreads; i++)\r
1860     {\r
1861         splits[i] = data->new_split_cat( 0, -1.0f );\r
1862         bestSplits[i] = data->new_split_cat( 0, -1.0f );\r
1863     }\r
1864 \r
1865     bool can_split = false;\r
1866 #ifdef _OPENMP\r
1867 #pragma omp parallel for num_threads(maxNumThreads) schedule(dynamic)\r
1868 #endif\r
1869     for( vi = 0; vi < data->var_count; vi++ )\r
1870     {\r
1871         CvDTreeSplit *res, *t;\r
1872         int threadIdx = cv::getThreadNum();\r
1873         int ci = data->get_var_type(vi);\r
1874         if( node->get_num_valid(vi) <= 1 )\r
1875             continue;\r
1876 \r
1877         if( data->is_classifier )\r
1878         {\r
1879             if( ci >= 0 )\r
1880                 res = find_split_cat_class( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );\r
1881             else\r
1882                 res = find_split_ord_class( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );\r
1883         }\r
1884         else\r
1885         {\r
1886             if( ci >= 0 )\r
1887                 res = find_split_cat_reg( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );\r
1888             else\r
1889                 res = find_split_ord_reg( node, vi, bestSplits[threadIdx]->quality, splits[threadIdx] );\r
1890         }\r
1891 \r
1892         if( res )\r
1893         {\r
1894             can_split = true;\r
1895             if( bestSplits[threadIdx]->quality < splits[threadIdx]->quality )\r
1896                 CV_SWAP( bestSplits[threadIdx], splits[threadIdx], t );\r
1897         }\r
1898     }\r
1899     if ( can_split )\r
1900     {\r
1901         bestSplit = bestSplits[0];\r
1902         for(int i = 1; i < maxNumThreads; i++)\r
1903         {\r
1904             if( bestSplit->quality < bestSplits[i]->quality )\r
1905                 bestSplit = bestSplits[i];\r
1906         }\r
1907     }\r
1908     for(int i = 0; i < maxNumThreads; i++)\r
1909     {\r
1910         cvSetRemoveByPtr( data->split_heap, splits[i] );\r
1911         if( bestSplits[i] != bestSplit )\r
1912             cvSetRemoveByPtr( data->split_heap, bestSplits[i] );\r
1913     }\r
1914     return bestSplit;\r
1915 }\r
1916 \r
1917 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,\r
1918                                             float init_quality, CvDTreeSplit* _split )\r
1919 {\r
1920     const float epsilon = FLT_EPSILON*2;\r
1921     int n = node->sample_count;\r
1922     int n1 = node->get_num_valid(vi);\r
1923     int m = data->get_num_classes();\r
1924 \r
1925     float* values_buf = data->get_pred_float_buf();\r
1926     const float* values = 0;\r
1927     int* indices_buf = data->get_pred_int_buf();\r
1928     const int* indices = 0;\r
1929     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
1930     int* responses_buf =  data->get_resp_int_buf();\r
1931     const int* responses = 0;\r
1932     data->get_class_labels( node, responses_buf, &responses );\r
1933 \r
1934     const int* rc0 = data->counts->data.i;\r
1935     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
1936     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
1937     int i, best_i = -1;\r
1938     double lsum2 = 0, rsum2 = 0, best_val = init_quality;\r
1939     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
1940 \r
1941     // init arrays of class instance counters on both sides of the split\r
1942     for( i = 0; i < m; i++ )\r
1943     {\r
1944         lc[i] = 0;\r
1945         rc[i] = rc0[i];\r
1946     }\r
1947 \r
1948     // compensate for missing values\r
1949     for( i = n1; i < n; i++ )\r
1950     {\r
1951         rc[responses[indices[i]]]--;\r
1952     }\r
1953 \r
1954     if( !priors )\r
1955     {\r
1956         int L = 0, R = n1;\r
1957 \r
1958         for( i = 0; i < m; i++ )\r
1959             rsum2 += (double)rc[i]*rc[i];\r
1960 \r
1961         for( i = 0; i < n1 - 1; i++ )\r
1962         {\r
1963             int idx = responses[indices[i]];\r
1964             int lv, rv;\r
1965             L++; R--;\r
1966             lv = lc[idx]; rv = rc[idx];\r
1967             lsum2 += lv*2 + 1;\r
1968             rsum2 -= rv*2 - 1;\r
1969             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1970 \r
1971             if( values[i] + epsilon < values[i+1] )\r
1972             {\r
1973                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1974                 if( best_val < val )\r
1975                 {\r
1976                     best_val = val;\r
1977                     best_i = i;\r
1978                 }\r
1979             }\r
1980         }\r
1981     }\r
1982     else\r
1983     {\r
1984         double L = 0, R = 0;\r
1985         for( i = 0; i < m; i++ )\r
1986         {\r
1987             double wv = rc[i]*priors[i];\r
1988             R += wv;\r
1989             rsum2 += wv*wv;\r
1990         }\r
1991 \r
1992         for( i = 0; i < n1 - 1; i++ )\r
1993         {\r
1994             int idx = responses[indices[i]];\r
1995             int lv, rv;\r
1996             double p = priors[idx], p2 = p*p;\r
1997             L += p; R -= p;\r
1998             lv = lc[idx]; rv = rc[idx];\r
1999             lsum2 += p2*(lv*2 + 1);\r
2000             rsum2 -= p2*(rv*2 - 1);\r
2001             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
2002 \r
2003             if( values[i] + epsilon < values[i+1] )\r
2004             {\r
2005                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
2006                 if( best_val < val )\r
2007                 {\r
2008                     best_val = val;\r
2009                     best_i = i;\r
2010                 }\r
2011             }\r
2012         }\r
2013     }\r
2014 \r
2015     CvDTreeSplit* split = 0;\r
2016     if( best_i >= 0 )\r
2017     {\r
2018         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );\r
2019         split->var_idx = vi;\r
2020         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;\r
2021         split->ord.split_point = best_i;\r
2022         split->inversed = 0;\r
2023         split->quality = (float)best_val;\r
2024     }\r
2025     return split;\r
2026 }\r
2027 \r
2028 \r
2029 void CvDTree::cluster_categories( const int* vectors, int n, int m,\r
2030                                 int* csums, int k, int* labels )\r
2031 {\r
2032     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm\r
2033     int iters = 0, max_iters = 100;\r
2034     int i, j, idx;\r
2035     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );\r
2036     double *v_weights = buf, *c_weights = buf + k;\r
2037     bool modified = true;\r
2038     CvRNG* r = &data->rng;\r
2039 \r
2040     // assign labels randomly\r
2041     for( i = idx = 0; i < n; i++ )\r
2042     {\r
2043         int sum = 0;\r
2044         const int* v = vectors + i*m;\r
2045         labels[i] = idx++;\r
2046         idx &= idx < k ? -1 : 0;\r
2047 \r
2048         // compute weight of each vector\r
2049         for( j = 0; j < m; j++ )\r
2050             sum += v[j];\r
2051         v_weights[i] = sum ? 1./sum : 0.;\r
2052     }\r
2053 \r
2054     for( i = 0; i < n; i++ )\r
2055     {\r
2056         int i1 = cvRandInt(r) % n;\r
2057         int i2 = cvRandInt(r) % n;\r
2058         CV_SWAP( labels[i1], labels[i2], j );\r
2059     }\r
2060 \r
2061     for( iters = 0; iters <= max_iters; iters++ )\r
2062     {\r
2063         // calculate csums\r
2064         for( i = 0; i < k; i++ )\r
2065         {\r
2066             for( j = 0; j < m; j++ )\r
2067                 csums[i*m + j] = 0;\r
2068         }\r
2069 \r
2070         for( i = 0; i < n; i++ )\r
2071         {\r
2072             const int* v = vectors + i*m;\r
2073             int* s = csums + labels[i]*m;\r
2074             for( j = 0; j < m; j++ )\r
2075                 s[j] += v[j];\r
2076         }\r
2077 \r
2078         // exit the loop here, when we have up-to-date csums\r
2079         if( iters == max_iters || !modified )\r
2080             break;\r
2081 \r
2082         modified = false;\r
2083 \r
2084         // calculate weight of each cluster\r
2085         for( i = 0; i < k; i++ )\r
2086         {\r
2087             const int* s = csums + i*m;\r
2088             int sum = 0;\r
2089             for( j = 0; j < m; j++ )\r
2090                 sum += s[j];\r
2091             c_weights[i] = sum ? 1./sum : 0;\r
2092         }\r
2093 \r
2094         // now for each vector determine the closest cluster\r
2095         for( i = 0; i < n; i++ )\r
2096         {\r
2097             const int* v = vectors + i*m;\r
2098             double alpha = v_weights[i];\r
2099             double min_dist2 = DBL_MAX;\r
2100             int min_idx = -1;\r
2101 \r
2102             for( idx = 0; idx < k; idx++ )\r
2103             {\r
2104                 const int* s = csums + idx*m;\r
2105                 double dist2 = 0., beta = c_weights[idx];\r
2106                 for( j = 0; j < m; j++ )\r
2107                 {\r
2108                     double t = v[j]*alpha - s[j]*beta;\r
2109                     dist2 += t*t;\r
2110                 }\r
2111                 if( min_dist2 > dist2 )\r
2112                 {\r
2113                     min_dist2 = dist2;\r
2114                     min_idx = idx;\r
2115                 }\r
2116             }\r
2117 \r
2118             if( min_idx != labels[i] )\r
2119                 modified = true;\r
2120             labels[i] = min_idx;\r
2121         }\r
2122     }\r
2123 }\r
2124 \r
2125 \r
2126 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
2127 {\r
2128     int ci = data->get_var_type(vi);\r
2129     int n = node->sample_count;\r
2130     int m = data->get_num_classes();\r
2131     int _mi = data->cat_count->data.i[ci], mi = _mi;\r
2132 \r
2133     int* labels_buf = data->get_pred_int_buf();\r
2134     const int* labels = 0;\r
2135     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2136     int *responses_buf = data->get_resp_int_buf();\r
2137     const int* responses = 0;\r
2138     data->get_class_labels(node, responses_buf, &responses);\r
2139 \r
2140     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
2141     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
2142     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;\r
2143     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );\r
2144     int* cluster_labels = 0;\r
2145     int** int_ptr = 0;\r
2146     int i, j, k, idx;\r
2147     double L = 0, R = 0;\r
2148     double best_val = init_quality;\r
2149     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;\r
2150     const double* priors = data->priors_mult->data.db;\r
2151 \r
2152     // init array of counters:\r
2153     // c_{jk} - number of samples that have vi-th input variable = j and response = k.\r
2154     for( j = -1; j < mi; j++ )\r
2155         for( k = 0; k < m; k++ )\r
2156             cjk[j*m + k] = 0;\r
2157 \r
2158     for( i = 0; i < n; i++ )\r
2159     {\r
2160        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];\r
2161        k = responses[i];\r
2162        cjk[j*m + k]++;\r
2163     }\r
2164 \r
2165     if( m > 2 )\r
2166     {\r
2167         if( mi > data->params.max_categories )\r
2168         {\r
2169             mi = MIN(data->params.max_categories, n);\r
2170             cjk += _mi*m;\r
2171             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));\r
2172             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );\r
2173         }\r
2174         subset_i = 1;\r
2175         subset_n = 1 << mi;\r
2176     }\r
2177     else\r
2178     {\r
2179         assert( m == 2 );\r
2180         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );\r
2181         for( j = 0; j < mi; j++ )\r
2182             int_ptr[j] = cjk + j*2 + 1;\r
2183         icvSortIntPtr( int_ptr, mi, 0 );\r
2184         subset_i = 0;\r
2185         subset_n = mi;\r
2186     }\r
2187 \r
2188     for( k = 0; k < m; k++ )\r
2189     {\r
2190         int sum = 0;\r
2191         for( j = 0; j < mi; j++ )\r
2192             sum += cjk[j*m + k];\r
2193         rc[k] = sum;\r
2194         lc[k] = 0;\r
2195     }\r
2196 \r
2197     for( j = 0; j < mi; j++ )\r
2198     {\r
2199         double sum = 0;\r
2200         for( k = 0; k < m; k++ )\r
2201             sum += cjk[j*m + k]*priors[k];\r
2202         c_weights[j] = sum;\r
2203         R += c_weights[j];\r
2204     }\r
2205 \r
2206     for( ; subset_i < subset_n; subset_i++ )\r
2207     {\r
2208         double weight;\r
2209         int* crow;\r
2210         double lsum2 = 0, rsum2 = 0;\r
2211 \r
2212         if( m == 2 )\r
2213             idx = (int)(int_ptr[subset_i] - cjk)/2;\r
2214         else\r
2215         {\r
2216             int graycode = (subset_i>>1)^subset_i;\r
2217             int diff = graycode ^ prevcode;\r
2218 \r
2219             // determine index of the changed bit.\r
2220             Cv32suf u;\r
2221             idx = diff >= (1 << 16) ? 16 : 0;\r
2222             u.f = (float)(((diff >> 16) | diff) & 65535);\r
2223             idx += (u.i >> 23) - 127;\r
2224             subtract = graycode < prevcode;\r
2225             prevcode = graycode;\r
2226         }\r
2227 \r
2228         crow = cjk + idx*m;\r
2229         weight = c_weights[idx];\r
2230         if( weight < FLT_EPSILON )\r
2231             continue;\r
2232 \r
2233         if( !subtract )\r
2234         {\r
2235             for( k = 0; k < m; k++ )\r
2236             {\r
2237                 int t = crow[k];\r
2238                 int lval = lc[k] + t;\r
2239                 int rval = rc[k] - t;\r
2240                 double p = priors[k], p2 = p*p;\r
2241                 lsum2 += p2*lval*lval;\r
2242                 rsum2 += p2*rval*rval;\r
2243                 lc[k] = lval; rc[k] = rval;\r
2244             }\r
2245             L += weight;\r
2246             R -= weight;\r
2247         }\r
2248         else\r
2249         {\r
2250             for( k = 0; k < m; k++ )\r
2251             {\r
2252                 int t = crow[k];\r
2253                 int lval = lc[k] - t;\r
2254                 int rval = rc[k] + t;\r
2255                 double p = priors[k], p2 = p*p;\r
2256                 lsum2 += p2*lval*lval;\r
2257                 rsum2 += p2*rval*rval;\r
2258                 lc[k] = lval; rc[k] = rval;\r
2259             }\r
2260             L -= weight;\r
2261             R += weight;\r
2262         }\r
2263 \r
2264         if( L > FLT_EPSILON && R > FLT_EPSILON )\r
2265         {\r
2266             double val = (lsum2*R + rsum2*L)/((double)L*R);\r
2267             if( best_val < val )\r
2268             {\r
2269                 best_val = val;\r
2270                 best_subset = subset_i;\r
2271             }\r
2272         }\r
2273     }\r
2274 \r
2275     CvDTreeSplit* split = 0;\r
2276     if( best_subset >= 0 ) \r
2277     {\r
2278         split = _split ? _split : data->new_split_cat( 0, -1.0f );\r
2279         split->var_idx = vi;\r
2280         split->quality = (float)best_val;\r
2281         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));\r
2282         if( m == 2 )\r
2283         {\r
2284             for( i = 0; i <= best_subset; i++ )\r
2285             {\r
2286                 idx = (int)(int_ptr[i] - cjk) >> 1;\r
2287                 split->subset[idx >> 5] |= 1 << (idx & 31);\r
2288             }\r
2289         }\r
2290         else\r
2291         {\r
2292             for( i = 0; i < _mi; i++ )\r
2293             {\r
2294                 idx = cluster_labels ? cluster_labels[i] : i;\r
2295                 if( best_subset & (1 << idx) )\r
2296                     split->subset[i >> 5] |= 1 << (i & 31);\r
2297             }\r
2298         }\r
2299     }\r
2300     return split;\r
2301 }\r
2302 \r
2303 \r
2304 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
2305 {\r
2306     const float epsilon = FLT_EPSILON*2;\r
2307     int n = node->sample_count;\r
2308     int n1 = node->get_num_valid(vi);\r
2309 \r
2310     float* values_buf = data->get_pred_float_buf();\r
2311     const float* values = 0;\r
2312     int* indices_buf = data->get_pred_int_buf();\r
2313     const int* indices = 0;\r
2314     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2315     float* responses_buf =  data->get_resp_float_buf();\r
2316     const float* responses = 0;\r
2317     data->get_ord_responses( node, responses_buf, &responses );\r
2318 \r
2319     int i, best_i = -1;\r
2320     double best_val = init_quality, lsum = 0, rsum = node->value*n;\r
2321     int L = 0, R = n1;\r
2322 \r
2323     // compensate for missing values\r
2324     for( i = n1; i < n; i++ )\r
2325         rsum -= responses[indices[i]];\r
2326 \r
2327     // find the optimal split\r
2328     for( i = 0; i < n1 - 1; i++ )\r
2329     {\r
2330         float t = responses[indices[i]];\r
2331         L++; R--;\r
2332         lsum += t;\r
2333         rsum -= t;\r
2334 \r
2335         if( values[i] + epsilon < values[i+1] )\r
2336         {\r
2337             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2338             if( best_val < val )\r
2339             {\r
2340                 best_val = val;\r
2341                 best_i = i;\r
2342             }\r
2343         }\r
2344     }\r
2345 \r
2346     CvDTreeSplit* split = 0;\r
2347     if( best_i >= 0 )\r
2348     {\r
2349         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );\r
2350         split->var_idx = vi;\r
2351         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;\r
2352         split->ord.split_point = best_i;\r
2353         split->inversed = 0;\r
2354         split->quality = (float)best_val;\r
2355     }\r
2356     return split;\r
2357 }\r
2358 \r
2359 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )\r
2360 {\r
2361     int ci = data->get_var_type(vi);\r
2362     int n = node->sample_count;\r
2363     int mi = data->cat_count->data.i[ci];\r
2364     int* labels_buf = data->get_pred_int_buf();\r
2365     const int* labels = 0;\r
2366     float* responses_buf = data->get_resp_float_buf();\r
2367     const float* responses = 0;\r
2368     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2369     data->get_ord_responses(node, responses_buf, &responses);\r
2370 \r
2371     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;\r
2372     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;\r
2373     double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );\r
2374     int i, L = 0, R = 0;\r
2375     double best_val = init_quality, lsum = 0, rsum = 0;\r
2376     int best_subset = -1, subset_i;\r
2377 \r
2378     for( i = -1; i < mi; i++ )\r
2379         sum[i] = counts[i] = 0;\r
2380 \r
2381     // calculate sum response and weight of each category of the input var\r
2382     for( i = 0; i < n; i++ )\r
2383     {\r
2384         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];\r
2385         double s = sum[idx] + responses[i];\r
2386         int nc = counts[idx] + 1;\r
2387         sum[idx] = s;\r
2388         counts[idx] = nc;\r
2389     }\r
2390 \r
2391     // calculate average response in each category\r
2392     for( i = 0; i < mi; i++ )\r
2393     {\r
2394         R += counts[i];\r
2395         rsum += sum[i];\r
2396         sum[i] /= MAX(counts[i],1);\r
2397         sum_ptr[i] = sum + i;\r
2398     }\r
2399 \r
2400     icvSortDblPtr( sum_ptr, mi, 0 );\r
2401 \r
2402     // revert back to unnormalized sums\r
2403     // (there should be a very little loss of accuracy)\r
2404     for( i = 0; i < mi; i++ )\r
2405         sum[i] *= counts[i];\r
2406 \r
2407     for( subset_i = 0; subset_i < mi-1; subset_i++ )\r
2408     {\r
2409         int idx = (int)(sum_ptr[subset_i] - sum);\r
2410         int ni = counts[idx];\r
2411 \r
2412         if( ni )\r
2413         {\r
2414             double s = sum[idx];\r
2415             lsum += s; L += ni;\r
2416             rsum -= s; R -= ni;\r
2417 \r
2418             if( L && R )\r
2419             {\r
2420                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2421                 if( best_val < val )\r
2422                 {\r
2423                     best_val = val;\r
2424                     best_subset = subset_i;\r
2425                 }\r
2426             }\r
2427         }\r
2428     }\r
2429 \r
2430     CvDTreeSplit* split = 0;\r
2431     if( best_subset >= 0 )\r
2432     {\r
2433         split = _split ? _split : data->new_split_cat( 0, -1.0f);\r
2434         split->var_idx = vi;\r
2435         split->quality = (float)best_val;\r
2436         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));\r
2437         for( i = 0; i <= best_subset; i++ )\r
2438         {\r
2439             int idx = (int)(sum_ptr[i] - sum);\r
2440             split->subset[idx >> 5] |= 1 << (idx & 31);\r
2441         }\r
2442     }\r
2443     return split;\r
2444 }\r
2445 \r
2446 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )\r
2447 {\r
2448     const float epsilon = FLT_EPSILON*2;\r
2449     const char* dir = (char*)data->direction->data.ptr;\r
2450     int n1 = node->get_num_valid(vi);\r
2451     float* values_buf = data->get_pred_float_buf();\r
2452     const float* values = 0;\r
2453     int* indices_buf = data->get_pred_int_buf();\r
2454     const int* indices = 0;\r
2455     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2456     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2457     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2458     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2459     // RR - ... both send to the right\r
2460     int i, best_i = -1, best_inversed = 0;\r
2461     double best_val;\r
2462 \r
2463     if( !data->have_priors )\r
2464     {\r
2465         int LL = 0, RL = 0, LR, RR;\r
2466         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;\r
2467         int sum = 0, sum_abs = 0;\r
2468 \r
2469         for( i = 0; i < n1; i++ )\r
2470         {\r
2471             int d = dir[indices[i]];\r
2472             sum += d; sum_abs += d & 1;\r
2473         }\r
2474 \r
2475         // sum_abs = R + L; sum = R - L\r
2476         RR = (sum_abs + sum) >> 1;\r
2477         LR = (sum_abs - sum) >> 1;\r
2478 \r
2479         // initially all the samples are sent to the right by the surrogate split,\r
2480         // LR of them are sent to the left by primary split, and RR - to the right.\r
2481         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2482         for( i = 0; i < n1 - 1; i++ )\r
2483         {\r
2484             int d = dir[indices[i]];\r
2485 \r
2486             if( d < 0 )\r
2487             {\r
2488                 LL++; LR--;\r
2489                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )\r
2490                 {\r
2491                     best_val = LL + RR;\r
2492                     best_i = i; best_inversed = 0;\r
2493                 }\r
2494             }\r
2495             else if( d > 0 )\r
2496             {\r
2497                 RL++; RR--;\r
2498                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )\r
2499                 {\r
2500                     best_val = RL + LR;\r
2501                     best_i = i; best_inversed = 1;\r
2502                 }\r
2503             }\r
2504         }\r
2505         best_val = _best_val;\r
2506     }\r
2507     else\r
2508     {\r
2509         double LL = 0, RL = 0, LR, RR;\r
2510         double worst_val = node->maxlr;\r
2511         double sum = 0, sum_abs = 0;\r
2512         const double* priors = data->priors_mult->data.db;\r
2513         int* responses_buf = data->get_resp_int_buf();\r
2514         const int* responses = 0;\r
2515         data->get_class_labels(node, responses_buf, &responses);\r
2516         best_val = worst_val;\r
2517 \r
2518         for( i = 0; i < n1; i++ )\r
2519         {\r
2520             int idx = indices[i];\r
2521             double w = priors[responses[idx]];\r
2522             int d = dir[idx];\r
2523             sum += d*w; sum_abs += (d & 1)*w;\r
2524         }\r
2525 \r
2526         // sum_abs = R + L; sum = R - L\r
2527         RR = (sum_abs + sum)*0.5;\r
2528         LR = (sum_abs - sum)*0.5;\r
2529 \r
2530         // initially all the samples are sent to the right by the surrogate split,\r
2531         // LR of them are sent to the left by primary split, and RR - to the right.\r
2532         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2533         for( i = 0; i < n1 - 1; i++ )\r
2534         {\r
2535             int idx = indices[i];\r
2536             double w = priors[responses[idx]];\r
2537             int d = dir[idx];\r
2538 \r
2539             if( d < 0 )\r
2540             {\r
2541                 LL += w; LR -= w;\r
2542                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )\r
2543                 {\r
2544                     best_val = LL + RR;\r
2545                     best_i = i; best_inversed = 0;\r
2546                 }\r
2547             }\r
2548             else if( d > 0 )\r
2549             {\r
2550                 RL += w; RR -= w;\r
2551                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )\r
2552                 {\r
2553                     best_val = RL + LR;\r
2554                     best_i = i; best_inversed = 1;\r
2555                 }\r
2556             }\r
2557         }\r
2558     }\r
2559     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,\r
2560         (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;\r
2561 }\r
2562 \r
2563 \r
2564 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )\r
2565 {\r
2566     const char* dir = (char*)data->direction->data.ptr;\r
2567     int n = node->sample_count;\r
2568     int* labels_buf = data->get_pred_int_buf();\r
2569     const int* labels = 0;\r
2570     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2571     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2572     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2573     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2574     // RR - ... both send to the right\r
2575     CvDTreeSplit* split = data->new_split_cat( vi, 0 );\r
2576     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;\r
2577     double best_val = 0;\r
2578     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;\r
2579     double* rc = lc + mi + 1;\r
2580 \r
2581     for( i = -1; i < mi; i++ )\r
2582         lc[i] = rc[i] = 0;\r
2583 \r
2584     // for each category calculate the weight of samples\r
2585     // sent to the left (lc) and to the right (rc) by the primary split\r
2586     if( !data->have_priors )\r
2587     {\r
2588         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;\r
2589         int* _rc = _lc + mi + 1;\r
2590 \r
2591         for( i = -1; i < mi; i++ )\r
2592             _lc[i] = _rc[i] = 0;\r
2593 \r
2594         for( i = 0; i < n; i++ )\r
2595         {\r
2596             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2597             int d = dir[i];\r
2598             int sum = _lc[idx] + d;\r
2599             int sum_abs = _rc[idx] + (d & 1);\r
2600             _lc[idx] = sum; _rc[idx] = sum_abs;\r
2601         }\r
2602 \r
2603         for( i = 0; i < mi; i++ )\r
2604         {\r
2605             int sum = _lc[i];\r
2606             int sum_abs = _rc[i];\r
2607             lc[i] = (sum_abs - sum) >> 1;\r
2608             rc[i] = (sum_abs + sum) >> 1;\r
2609         }\r
2610     }\r
2611     else\r
2612     {\r
2613         const double* priors = data->priors_mult->data.db;\r
2614         int* responses_buf = data->get_resp_int_buf();\r
2615         const int* responses = 0;\r
2616         data->get_class_labels(node, responses_buf, &responses);\r
2617 \r
2618         for( i = 0; i < n; i++ )\r
2619         {\r
2620             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2621             double w = priors[responses[i]];\r
2622             int d = dir[i];\r
2623             double sum = lc[idx] + d*w;\r
2624             double sum_abs = rc[idx] + (d & 1)*w;\r
2625             lc[idx] = sum; rc[idx] = sum_abs;\r
2626         }\r
2627 \r
2628         for( i = 0; i < mi; i++ )\r
2629         {\r
2630             double sum = lc[i];\r
2631             double sum_abs = rc[i];\r
2632             lc[i] = (sum_abs - sum) * 0.5;\r
2633             rc[i] = (sum_abs + sum) * 0.5;\r
2634         }\r
2635     }\r
2636 \r
2637     // 2. now form the split.\r
2638     // in each category send all the samples to the same direction as majority\r
2639     for( i = 0; i < mi; i++ )\r
2640     {\r
2641         double lval = lc[i], rval = rc[i];\r
2642         if( lval > rval )\r
2643         {\r
2644             split->subset[i >> 5] |= 1 << (i & 31);\r
2645             best_val += lval;\r
2646             l_win++;\r
2647         }\r
2648         else\r
2649             best_val += rval;\r
2650     }\r
2651 \r
2652     split->quality = (float)best_val;\r
2653     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )\r
2654         cvSetRemoveByPtr( data->split_heap, split ), split = 0;\r
2655 \r
2656     return split;\r
2657 }\r
2658 \r
2659 \r
2660 void CvDTree::calc_node_value( CvDTreeNode* node )\r
2661 {\r
2662     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;\r
2663     int* cv_labels_buf = data->get_cv_lables_buf();\r
2664     const int* cv_labels = 0;\r
2665     data->get_cv_labels(node, cv_labels_buf, &cv_labels);\r
2666 \r
2667     if( data->is_classifier )\r
2668     {\r
2669         // in case of classification tree:\r
2670         //  * node value is the label of the class that has the largest weight in the node.\r
2671         //  * node risk is the weighted number of misclassified samples,\r
2672         //  * j-th cross-validation fold value and risk are calculated as above,\r
2673         //    but using the samples with cv_labels(*)!=j.\r
2674         //  * j-th cross-validation fold error is calculated as the weighted number of\r
2675         //    misclassified samples with cv_labels(*)==j.\r
2676 \r
2677         // compute the number of instances of each class\r
2678         int* cls_count = data->counts->data.i;\r
2679         int* responses_buf = data->get_resp_int_buf();\r
2680         const int* responses = 0;\r
2681         data->get_class_labels(node, responses_buf, &responses);\r
2682         int m = data->get_num_classes();\r
2683         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));\r
2684         double max_val = -1, total_weight = 0;\r
2685         int max_k = -1;\r
2686         double* priors = data->priors_mult->data.db;\r
2687 \r
2688         for( k = 0; k < m; k++ )\r
2689             cls_count[k] = 0;\r
2690 \r
2691         if( cv_n == 0 )\r
2692         {\r
2693             for( i = 0; i < n; i++ )\r
2694                 cls_count[responses[i]]++;\r
2695         }\r
2696         else\r
2697         {\r
2698             for( j = 0; j < cv_n; j++ )\r
2699                 for( k = 0; k < m; k++ )\r
2700                     cv_cls_count[j*m + k] = 0;\r
2701 \r
2702             for( i = 0; i < n; i++ )\r
2703             {\r
2704                 j = cv_labels[i]; k = responses[i];\r
2705                 cv_cls_count[j*m + k]++;\r
2706             }\r
2707 \r
2708             for( j = 0; j < cv_n; j++ )\r
2709                 for( k = 0; k < m; k++ )\r
2710                     cls_count[k] += cv_cls_count[j*m + k];\r
2711         }\r
2712 \r
2713         if( data->have_priors && node->parent == 0 )\r
2714         {\r
2715             // compute priors_mult from priors, take the sample ratio into account.\r
2716             double sum = 0;\r
2717             for( k = 0; k < m; k++ )\r
2718             {\r
2719                 int n_k = cls_count[k];\r
2720                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);\r
2721                 sum += priors[k];\r
2722             }\r
2723             sum = 1./sum;\r
2724             for( k = 0; k < m; k++ )\r
2725                 priors[k] *= sum;\r
2726         }\r
2727 \r
2728         for( k = 0; k < m; k++ )\r
2729         {\r
2730             double val = cls_count[k]*priors[k];\r
2731             total_weight += val;\r
2732             if( max_val < val )\r
2733             {\r
2734                 max_val = val;\r
2735                 max_k = k;\r
2736             }\r
2737         }\r
2738 \r
2739         node->class_idx = max_k;\r
2740         node->value = data->cat_map->data.i[\r
2741             data->cat_ofs->data.i[data->cat_var_count] + max_k];\r
2742         node->node_risk = total_weight - max_val;\r
2743 \r
2744         for( j = 0; j < cv_n; j++ )\r
2745         {\r
2746             double sum_k = 0, sum = 0, max_val_k = 0;\r
2747             max_val = -1; max_k = -1;\r
2748 \r
2749             for( k = 0; k < m; k++ )\r
2750             {\r
2751                 double w = priors[k];\r
2752                 double val_k = cv_cls_count[j*m + k]*w;\r
2753                 double val = cls_count[k]*w - val_k;\r
2754                 sum_k += val_k;\r
2755                 sum += val;\r
2756                 if( max_val < val )\r
2757                 {\r
2758                     max_val = val;\r
2759                     max_val_k = val_k;\r
2760                     max_k = k;\r
2761                 }\r
2762             }\r
2763 \r
2764             node->cv_Tn[j] = INT_MAX;\r
2765             node->cv_node_risk[j] = sum - max_val;\r
2766             node->cv_node_error[j] = sum_k - max_val_k;\r
2767         }\r
2768     }\r
2769     else\r
2770     {\r
2771         // in case of regression tree:\r
2772         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,\r
2773         //    n is the number of samples in the node.\r
2774         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)\r
2775         //  * j-th cross-validation fold value and risk are calculated as above,\r
2776         //    but using the samples with cv_labels(*)!=j.\r
2777         //  * j-th cross-validation fold error is calculated\r
2778         //    using samples with cv_labels(*)==j as the test subset:\r
2779         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),\r
2780         //    where node_value_j is the node value calculated\r
2781         //    as described in the previous bullet, and summation is done\r
2782         //    over the samples with cv_labels(*)==j.\r
2783 \r
2784         double sum = 0, sum2 = 0;\r
2785         float* values_buf = data->get_resp_float_buf();\r
2786         const float* values = 0;\r
2787         data->get_ord_responses(node, values_buf, &values);\r
2788         double *cv_sum = 0, *cv_sum2 = 0;\r
2789         int* cv_count = 0;\r
2790 \r
2791         if( cv_n == 0 )\r
2792         {\r
2793             for( i = 0; i < n; i++ )\r
2794             {\r
2795                 double t = values[i];\r
2796                 sum += t;\r
2797                 sum2 += t*t;\r
2798             }\r
2799         }\r
2800         else\r
2801         {\r
2802             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );\r
2803             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );\r
2804             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );\r
2805 \r
2806             for( j = 0; j < cv_n; j++ )\r
2807             {\r
2808                 cv_sum[j] = cv_sum2[j] = 0.;\r
2809                 cv_count[j] = 0;\r
2810             }\r
2811 \r
2812             for( i = 0; i < n; i++ )\r
2813             {\r
2814                 j = cv_labels[i];\r
2815                 double t = values[i];\r
2816                 double s = cv_sum[j] + t;\r
2817                 double s2 = cv_sum2[j] + t*t;\r
2818                 int nc = cv_count[j] + 1;\r
2819                 cv_sum[j] = s;\r
2820                 cv_sum2[j] = s2;\r
2821                 cv_count[j] = nc;\r
2822             }\r
2823 \r
2824             for( j = 0; j < cv_n; j++ )\r
2825             {\r
2826                 sum += cv_sum[j];\r
2827                 sum2 += cv_sum2[j];\r
2828             }\r
2829         }\r
2830 \r
2831         node->node_risk = sum2 - (sum/n)*sum;\r
2832         node->value = sum/n;\r
2833 \r
2834         for( j = 0; j < cv_n; j++ )\r
2835         {\r
2836             double s = cv_sum[j], si = sum - s;\r
2837             double s2 = cv_sum2[j], s2i = sum2 - s2;\r
2838             int c = cv_count[j], ci = n - c;\r
2839             double r = si/MAX(ci,1);\r
2840             node->cv_node_risk[j] = s2i - r*r*ci;\r
2841             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;\r
2842             node->cv_Tn[j] = INT_MAX;\r
2843         }\r
2844     }\r
2845 }\r
2846 \r
2847 \r
2848 void CvDTree::complete_node_dir( CvDTreeNode* node )\r
2849 {\r
2850     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;\r
2851     int nz = n - node->get_num_valid(node->split->var_idx);\r
2852     char* dir = (char*)data->direction->data.ptr;\r
2853 \r
2854     // try to complete direction using surrogate splits\r
2855     if( nz && data->params.use_surrogates )\r
2856     {\r
2857         CvDTreeSplit* split = node->split->next;\r
2858         for( ; split != 0 && nz; split = split->next )\r
2859         {\r
2860             int inversed_mask = split->inversed ? -1 : 0;\r
2861             vi = split->var_idx;\r
2862 \r
2863             if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
2864             {\r
2865                 int* labels_buf = data->get_pred_int_buf();\r
2866                 const int* labels = 0;\r
2867                 data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2868                 const int* subset = split->subset;\r
2869 \r
2870                 for( i = 0; i < n; i++ )\r
2871                 {\r
2872                     int idx = labels[i];\r
2873                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))\r
2874                         \r
2875                     {\r
2876                         int d = CV_DTREE_CAT_DIR(idx,subset);\r
2877                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
2878                         if( --nz )\r
2879                             break;\r
2880                     }\r
2881                 }\r
2882             }\r
2883             else // split on ordered var\r
2884             {\r
2885                 float* values_buf = data->get_pred_float_buf();\r
2886                 const float* values = 0;\r
2887                 int* indices_buf = data->get_pred_int_buf();\r
2888                 const int* indices = 0;\r
2889                 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2890                 int split_point = split->ord.split_point;\r
2891                 int n1 = node->get_num_valid(vi);\r
2892 \r
2893                 assert( 0 <= split_point && split_point < n-1 );\r
2894 \r
2895                 for( i = 0; i < n1; i++ )\r
2896                 {\r
2897                     int idx = indices[i];\r
2898                     if( !dir[idx] )\r
2899                     {\r
2900                         int d = i <= split_point ? -1 : 1;\r
2901                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);\r
2902                         if( --nz )\r
2903                             break;\r
2904                     }\r
2905                 }\r
2906             }\r
2907         }\r
2908     }\r
2909 \r
2910     // find the default direction for the rest\r
2911     if( nz )\r
2912     {\r
2913         for( i = nr = 0; i < n; i++ )\r
2914             nr += dir[i] > 0;\r
2915         nl = n - nr - nz;\r
2916         d0 = nl > nr ? -1 : nr > nl;\r
2917     }\r
2918 \r
2919     // make sure that every sample is directed either to the left or to the right\r
2920     for( i = 0; i < n; i++ )\r
2921     {\r
2922         int d = dir[i];\r
2923         if( !d )\r
2924         {\r
2925             d = d0;\r
2926             if( !d )\r
2927                 d = d1, d1 = -d1;\r
2928         }\r
2929         d = d > 0;\r
2930         dir[i] = (char)d; // remap (-1,1) to (0,1)\r
2931     }\r
2932 }\r
2933 \r
2934 \r
2935 void CvDTree::split_node_data( CvDTreeNode* node )\r
2936 {\r
2937     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;\r
2938     char* dir = (char*)data->direction->data.ptr;\r
2939     CvDTreeNode *left = 0, *right = 0;\r
2940     int* new_idx = data->split_buf->data.i;\r
2941     int new_buf_idx = data->get_child_buf_idx( node );\r
2942     int work_var_count = data->get_work_var_count();\r
2943     CvMat* buf = data->buf;\r
2944     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));\r
2945 \r
2946     complete_node_dir(node);\r
2947 \r
2948     for( i = nl = nr = 0; i < n; i++ )\r
2949     {\r
2950         int d = dir[i];\r
2951         // initialize new indices for splitting ordered variables\r
2952         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li\r
2953         nr += d;\r
2954         nl += d^1;\r
2955     }\r
2956 \r
2957 \r
2958     bool split_input_data;\r
2959     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );\r
2960     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );\r
2961 \r
2962     split_input_data = node->depth + 1 < data->params.max_depth &&\r
2963         (node->left->sample_count > data->params.min_sample_count ||\r
2964         node->right->sample_count > data->params.min_sample_count);\r
2965 \r
2966     // split ordered variables, keep both halves sorted.\r
2967     for( vi = 0; vi < data->var_count; vi++ )\r
2968     {\r
2969         int ci = data->get_var_type(vi);\r
2970         int n1 = node->get_num_valid(vi);\r
2971         int *src_idx_buf = data->get_pred_int_buf();\r
2972         const int* src_idx = 0;\r
2973         float *src_val_buf = data->get_pred_float_buf();\r
2974         const float* src_val = 0;\r
2975         \r
2976         if( ci >= 0 || !split_input_data )\r
2977             continue;\r
2978 \r
2979         data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
2980 \r
2981         for(i = 0; i < n; i++)\r
2982             temp_buf[i] = src_idx[i];\r
2983 \r
2984         if (data->is_buf_16u)\r
2985         {\r
2986             unsigned short *ldst, *rdst, *ldst0, *rdst0;\r
2987             //unsigned short tl, tr;\r
2988             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
2989                 vi*scount + left->offset);\r
2990             rdst0 = rdst = (unsigned short*)(ldst + nl);\r
2991 \r
2992             // split sorted\r
2993             for( i = 0; i < n1; i++ )\r
2994             {\r
2995                 int idx = temp_buf[i];\r
2996                 int d = dir[idx];\r
2997                 idx = new_idx[idx];\r
2998                 if (d)\r
2999                 {\r
3000                     *rdst = (unsigned short)idx;\r
3001                     rdst++;\r
3002                 }\r
3003                 else\r
3004                 {\r
3005                     *ldst = (unsigned short)idx;\r
3006                     ldst++;\r
3007                 }\r
3008             }\r
3009 \r
3010             left->set_num_valid(vi, (int)(ldst - ldst0));\r
3011             right->set_num_valid(vi, (int)(rdst - rdst0));\r
3012 \r
3013             // split missing\r
3014             for( ; i < n; i++ )\r
3015             {\r
3016                 int idx = temp_buf[i];\r
3017                 int d = dir[idx];\r
3018                 idx = new_idx[idx];\r
3019                 if (d)\r
3020                 {\r
3021                     *rdst = (unsigned short)idx;\r
3022                     rdst++;\r
3023                 }\r
3024                 else\r
3025                 {\r
3026                     *ldst = (unsigned short)idx;\r
3027                     ldst++;\r
3028                 }\r
3029             }\r
3030         }\r
3031         else\r
3032         {\r
3033             int *ldst0, *ldst, *rdst0, *rdst;\r
3034             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + \r
3035                 vi*scount + left->offset;\r
3036             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + \r
3037                 vi*scount + right->offset;\r
3038 \r
3039             // split sorted\r
3040             for( i = 0; i < n1; i++ )\r
3041             {\r
3042                 int idx = temp_buf[i];\r
3043                 int d = dir[idx];\r
3044                 idx = new_idx[idx];\r
3045                 if (d)\r
3046                 {\r
3047                     *rdst = idx;\r
3048                     rdst++;\r
3049                 }\r
3050                 else\r
3051                 {\r
3052                     *ldst = idx;\r
3053                     ldst++;\r
3054                 }\r
3055             }\r
3056 \r
3057             left->set_num_valid(vi, (int)(ldst - ldst0));\r
3058             right->set_num_valid(vi, (int)(rdst - rdst0));\r
3059 \r
3060             // split missing\r
3061             for( ; i < n; i++ )\r
3062             {\r
3063                 int idx = temp_buf[i];\r
3064                 int d = dir[idx];\r
3065                 idx = new_idx[idx];\r
3066                 if (d)\r
3067                 {\r
3068                     *rdst = idx;\r
3069                     rdst++;\r
3070                 }\r
3071                 else\r
3072                 {\r
3073                     *ldst = idx;\r
3074                     ldst++;\r
3075                 }\r
3076             }\r
3077         }\r
3078     }\r
3079 \r
3080     // split categorical vars, responses and cv_labels using new_idx relocation table\r
3081     for( vi = 0; vi < work_var_count; vi++ )\r
3082     {\r
3083         int ci = data->get_var_type(vi);\r
3084         int n1 = node->get_num_valid(vi), nr1 = 0;\r
3085         \r
3086         if( ci < 0 || (vi < data->var_count && !split_input_data) )\r
3087             continue;\r
3088 \r
3089         int *src_lbls_buf = data->get_pred_int_buf();\r
3090         const int* src_lbls = 0;\r
3091         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);\r
3092 \r
3093         for(i = 0; i < n; i++)\r
3094             temp_buf[i] = src_lbls[i];\r
3095 \r
3096         if (data->is_buf_16u)\r
3097         {\r
3098             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + \r
3099                 vi*scount + left->offset);\r
3100             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + \r
3101                 vi*scount + right->offset);\r
3102             \r
3103             for( i = 0; i < n; i++ )\r
3104             {\r
3105                 int d = dir[i];\r
3106                 int idx = temp_buf[i];\r
3107                 if (d)\r
3108                 {\r
3109                     *rdst = (unsigned short)idx;\r
3110                     rdst++;\r
3111                     nr1 += (idx != 65535 )&d;\r
3112                 }\r
3113                 else\r
3114                 {\r
3115                     *ldst = (unsigned short)idx;\r
3116                     ldst++;\r
3117                 }\r
3118             }\r
3119 \r
3120             if( vi < data->var_count )\r
3121             {\r
3122                 left->set_num_valid(vi, n1 - nr1);\r
3123                 right->set_num_valid(vi, nr1);\r
3124             }\r
3125         }\r
3126         else\r
3127         {\r
3128             int *ldst = buf->data.i + left->buf_idx*buf->cols + \r
3129                 vi*scount + left->offset;\r
3130             int *rdst = buf->data.i + right->buf_idx*buf->cols + \r
3131                 vi*scount + right->offset;\r
3132             \r
3133             for( i = 0; i < n; i++ )\r
3134             {\r
3135                 int d = dir[i];\r
3136                 int idx = temp_buf[i];\r
3137                 if (d)\r
3138                 {\r
3139                     *rdst = idx;\r
3140                     rdst++;\r
3141                     nr1 += (idx >= 0)&d;\r
3142                 }\r
3143                 else\r
3144                 {\r
3145                     *ldst = idx;\r
3146                     ldst++;\r
3147                 }\r
3148                 \r
3149             }\r
3150 \r
3151             if( vi < data->var_count )\r
3152             {\r
3153                 left->set_num_valid(vi, n1 - nr1);\r
3154                 right->set_num_valid(vi, nr1);\r
3155             }\r
3156         }        \r
3157     }\r
3158 \r
3159 \r
3160     // split sample indices\r
3161     int *sample_idx_src_buf = data->get_sample_idx_buf();\r
3162     const int* sample_idx_src = 0;\r
3163     data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);\r
3164 \r
3165     for(i = 0; i < n; i++)\r
3166         temp_buf[i] = sample_idx_src[i];\r
3167 \r
3168     int pos = data->get_work_var_count();\r
3169     if (data->is_buf_16u)\r
3170     {\r
3171         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
3172             pos*scount + left->offset);\r
3173         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + \r
3174             pos*scount + right->offset);\r
3175         for (i = 0; i < n; i++)\r
3176         {\r
3177             int d = dir[i];\r
3178             unsigned short idx = (unsigned short)temp_buf[i];\r
3179             if (d)\r
3180             {\r
3181                 *rdst = idx;\r
3182                 rdst++;\r
3183             }\r
3184             else\r
3185             {\r
3186                 *ldst = idx;\r
3187                 ldst++;\r
3188             }\r
3189         }\r
3190     }\r
3191     else\r
3192     {\r
3193         int* ldst = buf->data.i + left->buf_idx*buf->cols + \r
3194             pos*scount + left->offset;\r
3195         int* rdst = buf->data.i + right->buf_idx*buf->cols + \r
3196             pos*scount + right->offset;\r
3197         for (i = 0; i < n; i++)\r
3198         {\r
3199             int d = dir[i];\r
3200             int idx = temp_buf[i];\r
3201             if (d)\r
3202             {\r
3203                 *rdst = idx;\r
3204                 rdst++;\r
3205             }\r
3206             else\r
3207             {\r
3208                 *ldst = idx;\r
3209                 ldst++;\r
3210             }\r
3211         }\r
3212     }\r
3213     \r
3214     // deallocate the parent node data that is not needed anymore\r
3215     data->free_node_data(node);    \r
3216 }\r
3217 \r
3218 float CvDTree::calc_error( CvMLData* _data, int type )\r
3219 {\r
3220     float err = 0;\r
3221     const CvMat* values = _data->get_values();\r
3222     const CvMat* response = _data->get_response();\r
3223     const CvMat* missing = _data->get_missing();\r
3224     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
3225     const CvMat* var_types = _data->get_var_types();\r
3226     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
3227     int r_step = CV_IS_MAT_CONT(response->type) ?\r
3228                 1 : response->step / CV_ELEM_SIZE(response->type);\r
3229     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;\r
3230     int sample_count = sample_idx ? sample_idx->cols : 0;\r
3231     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
3232     if ( is_classifier )\r
3233     {\r
3234         for( int i = 0; i < sample_count; i++ )\r
3235         {\r
3236             CvMat sample, miss;\r
3237             int si = sidx ? sidx[i] : i;\r
3238             cvGetRow( values, &sample, si ); \r
3239             if( missing ) \r
3240                 cvGetRow( missing, &miss, si );             \r
3241             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3242             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
3243             err += d;\r
3244         }\r
3245         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
3246     }\r
3247     else\r
3248     {\r
3249         for( int i = 0; i < sample_count; i++ )\r
3250         {\r
3251             CvMat sample, miss;\r
3252             int si = sidx ? sidx[i] : i;\r
3253             cvGetRow( values, &sample, si ); \r
3254             if( missing ) \r
3255                 cvGetRow( missing, &miss, si );             \r
3256             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3257             float d = r - response->data.fl[si*r_step];\r
3258             err += d*d;\r
3259         }\r
3260         err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
3261     }\r
3262     return err;\r
3263 }\r
3264 \r
3265 void CvDTree::prune_cv()\r
3266 {\r
3267     CvMat* ab = 0;\r
3268     CvMat* temp = 0;\r
3269     CvMat* err_jk = 0;\r
3270 \r
3271     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.\r
3272     // 2. choose the best tree index (if need, apply 1SE rule).\r
3273     // 3. store the best index and cut the branches.\r
3274 \r
3275     CV_FUNCNAME( "CvDTree::prune_cv" );\r
3276 \r
3277     __BEGIN__;\r
3278 \r
3279     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;\r
3280     // currently, 1SE for regression is not implemented\r
3281     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;\r
3282     double* err;\r
3283     double min_err = 0, min_err_se = 0;\r
3284     int min_idx = -1;\r
3285 \r
3286     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));\r
3287 \r
3288     // build the main tree sequence, calculate alpha's\r
3289     for(;;tree_count++)\r
3290     {\r
3291         double min_alpha = update_tree_rnc(tree_count, -1);\r
3292         if( cut_tree(tree_count, -1, min_alpha) )\r
3293             break;\r
3294 \r
3295         if( ab->cols <= tree_count )\r
3296         {\r
3297             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));\r
3298             for( ti = 0; ti < ab->cols; ti++ )\r
3299                 temp->data.db[ti] = ab->data.db[ti];\r
3300             cvReleaseMat( &ab );\r
3301             ab = temp;\r
3302             temp = 0;\r
3303         }\r
3304 \r
3305         ab->data.db[tree_count] = min_alpha;\r
3306     }\r
3307 \r
3308     ab->data.db[0] = 0.;\r
3309 \r
3310     if( tree_count > 0 )\r
3311     {\r
3312         for( ti = 1; ti < tree_count-1; ti++ )\r
3313             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);\r
3314         ab->data.db[tree_count-1] = DBL_MAX*0.5;\r
3315 \r
3316         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));\r
3317         err = err_jk->data.db;\r
3318 \r
3319         for( j = 0; j < cv_n; j++ )\r
3320         {\r
3321             int tj = 0, tk = 0;\r
3322             for( ; tk < tree_count; tj++ )\r
3323             {\r
3324                 double min_alpha = update_tree_rnc(tj, j);\r
3325                 if( cut_tree(tj, j, min_alpha) )\r
3326                     min_alpha = DBL_MAX;\r
3327 \r
3328                 for( ; tk < tree_count; tk++ )\r
3329                 {\r
3330                     if( ab->data.db[tk] > min_alpha )\r
3331                         break;\r
3332                     err[j*tree_count + tk] = root->tree_error;\r
3333                 }\r
3334             }\r
3335         }\r
3336 \r
3337         for( ti = 0; ti < tree_count; ti++ )\r
3338         {\r
3339             double sum_err = 0;\r
3340             for( j = 0; j < cv_n; j++ )\r
3341                 sum_err += err[j*tree_count + ti];\r
3342             if( ti == 0 || sum_err < min_err )\r
3343             {\r
3344                 min_err = sum_err;\r
3345                 min_idx = ti;\r
3346                 if( use_1se )\r
3347                     min_err_se = sqrt( sum_err*(n - sum_err) );\r
3348             }\r
3349             else if( sum_err < min_err + min_err_se )\r
3350                 min_idx = ti;\r
3351         }\r
3352     }\r
3353 \r
3354     pruned_tree_idx = min_idx;\r
3355     free_prune_data(data->params.truncate_pruned_tree != 0);\r
3356 \r
3357     __END__;\r
3358 \r
3359     cvReleaseMat( &err_jk );\r
3360     cvReleaseMat( &ab );\r
3361     cvReleaseMat( &temp );\r
3362 }\r
3363 \r
3364 \r
3365 double CvDTree::update_tree_rnc( int T, int fold )\r
3366 {\r
3367     CvDTreeNode* node = root;\r
3368     double min_alpha = DBL_MAX;\r
3369 \r
3370     for(;;)\r
3371     {\r
3372         CvDTreeNode* parent;\r
3373         for(;;)\r
3374         {\r
3375             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3376             if( t <= T || !node->left )\r
3377             {\r
3378                 node->complexity = 1;\r
3379                 node->tree_risk = node->node_risk;\r
3380                 node->tree_error = 0.;\r
3381                 if( fold >= 0 )\r
3382                 {\r
3383                     node->tree_risk = node->cv_node_risk[fold];\r
3384                     node->tree_error = node->cv_node_error[fold];\r
3385                 }\r
3386                 break;\r
3387             }\r
3388             node = node->left;\r
3389         }\r
3390 \r
3391         for( parent = node->parent; parent && parent->right == node;\r
3392             node = parent, parent = parent->parent )\r
3393         {\r
3394             parent->complexity += node->complexity;\r
3395             parent->tree_risk += node->tree_risk;\r
3396             parent->tree_error += node->tree_error;\r
3397 \r
3398             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)\r
3399                 - parent->tree_risk)/(parent->complexity - 1);\r
3400             min_alpha = MIN( min_alpha, parent->alpha );\r
3401         }\r
3402 \r
3403         if( !parent )\r
3404             break;\r
3405 \r
3406         parent->complexity = node->complexity;\r
3407         parent->tree_risk = node->tree_risk;\r
3408         parent->tree_error = node->tree_error;\r
3409         node = parent->right;\r
3410     }\r
3411 \r
3412     return min_alpha;\r
3413 }\r
3414 \r
3415 \r
3416 int CvDTree::cut_tree( int T, int fold, double min_alpha )\r
3417 {\r
3418     CvDTreeNode* node = root;\r
3419     if( !node->left )\r
3420         return 1;\r
3421 \r
3422     for(;;)\r
3423     {\r
3424         CvDTreeNode* parent;\r
3425         for(;;)\r
3426         {\r
3427             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3428             if( t <= T || !node->left )\r
3429                 break;\r
3430             if( node->alpha <= min_alpha + FLT_EPSILON )\r
3431             {\r
3432                 if( fold >= 0 )\r
3433                     node->cv_Tn[fold] = T;\r
3434                 else\r
3435                     node->Tn = T;\r
3436                 if( node == root )\r
3437                     return 1;\r
3438                 break;\r
3439             }\r
3440             node = node->left;\r
3441         }\r
3442 \r
3443         for( parent = node->parent; parent && parent->right == node;\r
3444             node = parent, parent = parent->parent )\r
3445             ;\r
3446 \r
3447         if( !parent )\r
3448             break;\r
3449 \r
3450         node = parent->right;\r
3451     }\r
3452 \r
3453     return 0;\r
3454 }\r
3455 \r
3456 \r
3457 void CvDTree::free_prune_data(bool cut_tree)\r
3458 {\r
3459     CvDTreeNode* node = root;\r
3460 \r
3461     for(;;)\r
3462     {\r
3463         CvDTreeNode* parent;\r
3464         for(;;)\r
3465         {\r
3466             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )\r
3467             // as we will clear the whole cross-validation heap at the end\r
3468             node->cv_Tn = 0;\r
3469             node->cv_node_error = node->cv_node_risk = 0;\r
3470             if( !node->left )\r
3471                 break;\r
3472             node = node->left;\r
3473         }\r
3474 \r
3475         for( parent = node->parent; parent && parent->right == node;\r
3476             node = parent, parent = parent->parent )\r
3477         {\r
3478             if( cut_tree && parent->Tn <= pruned_tree_idx )\r
3479             {\r
3480                 data->free_node( parent->left );\r
3481                 data->free_node( parent->right );\r
3482                 parent->left = parent->right = 0;\r
3483             }\r
3484         }\r
3485 \r
3486         if( !parent )\r
3487             break;\r
3488 \r
3489         node = parent->right;\r
3490     }\r
3491 \r
3492     if( data->cv_heap )\r
3493         cvClearSet( data->cv_heap );\r
3494 }\r
3495 \r
3496 \r
3497 void CvDTree::free_tree()\r
3498 {\r
3499     if( root && data && data->shared )\r
3500     {\r
3501         pruned_tree_idx = INT_MIN;\r
3502         free_prune_data(true);\r
3503         data->free_node(root);\r
3504         root = 0;\r
3505     }\r
3506 }\r
3507 \r
3508 CvDTreeNode* CvDTree::predict( const CvMat* _sample,\r
3509     const CvMat* _missing, bool preprocessed_input ) const\r
3510 {\r
3511     CvDTreeNode* result = 0;\r
3512     int* catbuf = 0;\r
3513 \r
3514     CV_FUNCNAME( "CvDTree::predict" );\r
3515 \r
3516     __BEGIN__;\r
3517 \r
3518     int i, step, mstep = 0;\r
3519     const float* sample;\r
3520     const uchar* m = 0;\r
3521     CvDTreeNode* node = root;\r
3522     const int* vtype;\r
3523     const int* vidx;\r
3524     const int* cmap;\r
3525     const int* cofs;\r
3526 \r
3527     if( !node )\r
3528         CV_ERROR( CV_StsError, "The tree has not been trained yet" );\r
3529 \r
3530     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||\r
3531         (_sample->cols != 1 && _sample->rows != 1) ||\r
3532         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||\r
3533         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )\r
3534             CV_ERROR( CV_StsBadArg,\r
3535         "the input sample must be 1d floating-point vector with the same "\r
3536         "number of elements as the total number of variables used for training" );\r
3537 \r
3538     sample = _sample->data.fl;\r
3539     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);\r
3540 \r
3541     if( data->cat_count && !preprocessed_input ) // cache for categorical variables\r
3542     {\r
3543         int n = data->cat_count->cols;\r
3544         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));\r
3545         for( i = 0; i < n; i++ )\r
3546             catbuf[i] = -1;\r
3547     }\r
3548 \r
3549     if( _missing )\r
3550     {\r
3551         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||\r
3552         !CV_ARE_SIZES_EQ(_missing, _sample) )\r
3553             CV_ERROR( CV_StsBadArg,\r
3554         "the missing data mask must be 8-bit vector of the same size as input sample" );\r
3555         m = _missing->data.ptr;\r
3556         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);\r
3557     }\r
3558 \r
3559     vtype = data->var_type->data.i;\r
3560     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;\r
3561     cmap = data->cat_map ? data->cat_map->data.i : 0;\r
3562     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;\r
3563 \r
3564     while( node->Tn > pruned_tree_idx && node->left )\r
3565     {\r
3566         CvDTreeSplit* split = node->split;\r
3567         int dir = 0;\r
3568         for( ; !dir && split != 0; split = split->next )\r
3569         {\r
3570             int vi = split->var_idx;\r
3571             int ci = vtype[vi];\r
3572             i = vidx ? vidx[vi] : vi;\r
3573             float val = sample[i*step];\r
3574             if( m && m[i*mstep] )\r
3575                 continue;\r
3576             if( ci < 0 ) // ordered\r
3577                 dir = val <= split->ord.c ? -1 : 1;\r
3578             else // categorical\r
3579             {\r
3580                 int c;\r
3581                 if( preprocessed_input )\r
3582                     c = cvRound(val);\r
3583                 else\r
3584                 {\r
3585                     c = catbuf[ci];\r
3586                     if( c < 0 )\r
3587                     {\r
3588                         int a = c = cofs[ci];\r
3589                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];\r
3590                         \r
3591                         int ival = cvRound(val);\r
3592                         if( ival != val )\r
3593                             CV_ERROR( CV_StsBadArg,\r
3594                             "one of input categorical variable is not an integer" );\r
3595                         \r
3596                         int sh = 0;\r
3597                         while( a < b )\r
3598                         {\r
3599                             sh++;\r
3600                             c = (a + b) >> 1;\r
3601                             if( ival < cmap[c] )\r
3602                                 b = c;\r
3603                             else if( ival > cmap[c] )\r
3604                                 a = c+1;\r
3605                             else\r
3606                                 break;\r
3607                         }\r
3608 \r
3609                         if( c < 0 || ival != cmap[c] )\r
3610                             continue;\r
3611 \r
3612                         catbuf[ci] = c -= cofs[ci];\r
3613                     }\r
3614                 }\r
3615                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;\r
3616                 dir = CV_DTREE_CAT_DIR(c, split->subset);\r
3617             }\r
3618 \r
3619             if( split->inversed )\r
3620                 dir = -dir;\r
3621         }\r
3622 \r
3623         if( !dir )\r
3624         {\r
3625             double diff = node->right->sample_count - node->left->sample_count;\r
3626             dir = diff < 0 ? -1 : 1;\r
3627         }\r
3628         node = dir < 0 ? node->left : node->right;\r
3629     }\r
3630 \r
3631     result = node;\r
3632 \r
3633     __END__;\r
3634 \r
3635     return result;\r
3636 }\r
3637 \r
3638 \r
3639 const CvMat* CvDTree::get_var_importance()\r
3640 {\r
3641     if( !var_importance )\r
3642     {\r
3643         CvDTreeNode* node = root;\r
3644         double* importance;\r
3645         if( !node )\r
3646             return 0;\r
3647         var_importance = cvCreateMat( 1, data->var_count, CV_64F );\r
3648         cvZero( var_importance );\r
3649         importance = var_importance->data.db;\r
3650 \r
3651         for(;;)\r
3652         {\r
3653             CvDTreeNode* parent;\r
3654             for( ;; node = node->left )\r
3655             {\r
3656                 CvDTreeSplit* split = node->split;\r
3657 \r
3658                 if( !node->left || node->Tn <= pruned_tree_idx )\r
3659                     break;\r
3660 \r
3661                 for( ; split != 0; split = split->next )\r
3662                     importance[split->var_idx] += split->quality;\r
3663             }\r
3664 \r
3665             for( parent = node->parent; parent && parent->right == node;\r
3666                 node = parent, parent = parent->parent )\r
3667                 ;\r
3668 \r
3669             if( !parent )\r
3670                 break;\r
3671 \r
3672             node = parent->right;\r
3673         }\r
3674 \r
3675         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );\r
3676     }\r
3677 \r
3678     return var_importance;\r
3679 }\r
3680 \r
3681 \r
3682 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )\r
3683 {\r
3684     int ci;\r
3685 \r
3686     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );\r
3687     cvWriteInt( fs, "var", split->var_idx );\r
3688     cvWriteReal( fs, "quality", split->quality );\r
3689 \r
3690     ci = data->get_var_type(split->var_idx);\r
3691     if( ci >= 0 ) // split on a categorical var\r
3692     {\r
3693         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;\r
3694         for( i = 0; i < n; i++ )\r
3695             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;\r
3696 \r
3697         // ad-hoc rule when to use inverse categorical split notation\r
3698         // to achieve more compact and clear representation\r
3699         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;\r
3700 \r
3701         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?\r
3702                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );\r
3703 \r
3704         for( i = 0; i < n; i++ )\r
3705         {\r
3706             int dir = CV_DTREE_CAT_DIR(i,split->subset);\r
3707             if( dir*default_dir < 0 )\r
3708                 cvWriteInt( fs, 0, i );\r
3709         }\r
3710         cvEndWriteStruct( fs );\r
3711     }\r
3712     else\r
3713         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );\r
3714 \r
3715     cvEndWriteStruct( fs );\r
3716 }\r
3717 \r
3718 \r
3719 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )\r
3720 {\r
3721     CvDTreeSplit* split;\r
3722 \r
3723     cvStartWriteStruct( fs, 0, CV_NODE_MAP );\r
3724 \r
3725     cvWriteInt( fs, "depth", node->depth );\r
3726     cvWriteInt( fs, "sample_count", node->sample_count );\r
3727     cvWriteReal( fs, "value", node->value );\r
3728 \r
3729     if( data->is_classifier )\r
3730         cvWriteInt( fs, "norm_class_idx", node->class_idx );\r
3731 \r
3732     cvWriteInt( fs, "Tn", node->Tn );\r
3733     cvWriteInt( fs, "complexity", node->complexity );\r
3734     cvWriteReal( fs, "alpha", node->alpha );\r
3735     cvWriteReal( fs, "node_risk", node->node_risk );\r
3736     cvWriteReal( fs, "tree_risk", node->tree_risk );\r
3737     cvWriteReal( fs, "tree_error", node->tree_error );\r
3738 \r
3739     if( node->left )\r
3740     {\r
3741         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );\r
3742 \r
3743         for( split = node->split; split != 0; split = split->next )\r
3744             write_split( fs, split );\r
3745 \r
3746         cvEndWriteStruct( fs );\r
3747     }\r
3748 \r
3749     cvEndWriteStruct( fs );\r
3750 }\r
3751 \r
3752 \r
3753 void CvDTree::write_tree_nodes( CvFileStorage* fs )\r
3754 {\r
3755     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );\r
3756 \r
3757     __BEGIN__;\r
3758 \r
3759     CvDTreeNode* node = root;\r
3760 \r
3761     // traverse the tree and save all the nodes in depth-first order\r
3762     for(;;)\r
3763     {\r
3764         CvDTreeNode* parent;\r
3765         for(;;)\r
3766         {\r
3767             write_node( fs, node );\r
3768             if( !node->left )\r
3769                 break;\r
3770             node = node->left;\r
3771         }\r
3772 \r
3773         for( parent = node->parent; parent && parent->right == node;\r
3774             node = parent, parent = parent->parent )\r
3775             ;\r
3776 \r
3777         if( !parent )\r
3778             break;\r
3779 \r
3780         node = parent->right;\r
3781     }\r
3782 \r
3783     __END__;\r
3784 }\r
3785 \r
3786 \r
3787 void CvDTree::write( CvFileStorage* fs, const char* name )\r
3788 {\r
3789     //CV_FUNCNAME( "CvDTree::write" );\r
3790 \r
3791     __BEGIN__;\r
3792 \r
3793     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );\r
3794 \r
3795     get_var_importance();\r
3796     data->write_params( fs );\r
3797     if( var_importance )\r
3798         cvWrite( fs, "var_importance", var_importance );\r
3799     write( fs );\r
3800 \r
3801     cvEndWriteStruct( fs );\r
3802 \r
3803     __END__;\r
3804 }\r
3805 \r
3806 \r
3807 void CvDTree::write( CvFileStorage* fs )\r
3808 {\r
3809     //CV_FUNCNAME( "CvDTree::write" );\r
3810 \r
3811     __BEGIN__;\r
3812 \r
3813     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );\r
3814 \r
3815     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );\r
3816     write_tree_nodes( fs );\r
3817     cvEndWriteStruct( fs );\r
3818 \r
3819     __END__;\r
3820 }\r
3821 \r
3822 \r
3823 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )\r
3824 {\r
3825     CvDTreeSplit* split = 0;\r
3826 \r
3827     CV_FUNCNAME( "CvDTree::read_split" );\r
3828 \r
3829     __BEGIN__;\r
3830 \r
3831     int vi, ci;\r
3832 \r
3833     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3834         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );\r
3835 \r
3836     vi = cvReadIntByName( fs, fnode, "var", -1 );\r
3837     if( (unsigned)vi >= (unsigned)data->var_count )\r
3838         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );\r
3839 \r
3840     ci = data->get_var_type(vi);\r
3841     if( ci >= 0 ) // split on categorical var\r
3842     {\r
3843         int i, n = data->cat_count->data.i[ci], inversed = 0, val;\r
3844         CvSeqReader reader;\r
3845         CvFileNode* inseq;\r
3846         split = data->new_split_cat( vi, 0 );\r
3847         inseq = cvGetFileNodeByName( fs, fnode, "in" );\r
3848         if( !inseq )\r
3849         {\r
3850             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );\r
3851             inversed = 1;\r
3852         }\r
3853         if( !inseq ||\r
3854             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))\r
3855             CV_ERROR( CV_StsParseError,\r
3856             "Either 'in' or 'not_in' tags should be inside a categorical split data" );\r
3857 \r
3858         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )\r
3859         {\r
3860             val = inseq->data.i;\r
3861             if( (unsigned)val >= (unsigned)n )\r
3862                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3863 \r
3864             split->subset[val >> 5] |= 1 << (val & 31);\r
3865         }\r
3866         else\r
3867         {\r
3868             cvStartReadSeq( inseq->data.seq, &reader );\r
3869 \r
3870             for( i = 0; i < reader.seq->total; i++ )\r
3871             {\r
3872                 CvFileNode* inode = (CvFileNode*)reader.ptr;\r
3873                 val = inode->data.i;\r
3874                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )\r
3875                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3876 \r
3877                 split->subset[val >> 5] |= 1 << (val & 31);\r
3878                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3879             }\r
3880         }\r
3881 \r
3882         // for categorical splits we do not use inversed splits,\r
3883         // instead we inverse the variable set in the split\r
3884         if( inversed )\r
3885             for( i = 0; i < (n + 31) >> 5; i++ )\r
3886                 split->subset[i] ^= -1;\r
3887     }\r
3888     else\r
3889     {\r
3890         CvFileNode* cmp_node;\r
3891         split = data->new_split_ord( vi, 0, 0, 0, 0 );\r
3892 \r
3893         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );\r
3894         if( !cmp_node )\r
3895         {\r
3896             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );\r
3897             split->inversed = 1;\r
3898         }\r
3899 \r
3900         split->ord.c = (float)cvReadReal( cmp_node );\r
3901     }\r
3902 \r
3903     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );\r
3904 \r
3905     __END__;\r
3906 \r
3907     return split;\r
3908 }\r
3909 \r
3910 \r
3911 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )\r
3912 {\r
3913     CvDTreeNode* node = 0;\r
3914 \r
3915     CV_FUNCNAME( "CvDTree::read_node" );\r
3916 \r
3917     __BEGIN__;\r
3918 \r
3919     CvFileNode* splits;\r
3920     int i, depth;\r
3921 \r
3922     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3923         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );\r
3924 \r
3925     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));\r
3926     depth = cvReadIntByName( fs, fnode, "depth", -1 );\r
3927     if( depth != node->depth )\r
3928         CV_ERROR( CV_StsParseError, "incorrect node depth" );\r
3929 \r
3930     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );\r
3931     node->value = cvReadRealByName( fs, fnode, "value" );\r
3932     if( data->is_classifier )\r
3933         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );\r
3934 \r
3935     node->Tn = cvReadIntByName( fs, fnode, "Tn" );\r
3936     node->complexity = cvReadIntByName( fs, fnode, "complexity" );\r
3937     node->alpha = cvReadRealByName( fs, fnode, "alpha" );\r
3938     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );\r
3939     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );\r
3940     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );\r
3941 \r
3942     splits = cvGetFileNodeByName( fs, fnode, "splits" );\r
3943     if( splits )\r
3944     {\r
3945         CvSeqReader reader;\r
3946         CvDTreeSplit* last_split = 0;\r
3947 \r
3948         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )\r
3949             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );\r
3950 \r
3951         cvStartReadSeq( splits->data.seq, &reader );\r
3952         for( i = 0; i < reader.seq->total; i++ )\r
3953         {\r
3954             CvDTreeSplit* split;\r
3955             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));\r
3956             if( !last_split )\r
3957                 node->split = last_split = split;\r
3958             else\r
3959                 last_split = last_split->next = split;\r
3960 \r
3961             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3962         }\r
3963     }\r
3964 \r
3965     __END__;\r
3966 \r
3967     return node;\r
3968 }\r
3969 \r
3970 \r
3971 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )\r
3972 {\r
3973     CV_FUNCNAME( "CvDTree::read_tree_nodes" );\r
3974 \r
3975     __BEGIN__;\r
3976 \r
3977     CvSeqReader reader;\r
3978     CvDTreeNode _root;\r
3979     CvDTreeNode* parent = &_root;\r
3980     int i;\r
3981     parent->left = parent->right = parent->parent = 0;\r
3982 \r
3983     cvStartReadSeq( fnode->data.seq, &reader );\r
3984 \r
3985     for( i = 0; i < reader.seq->total; i++ )\r
3986     {\r
3987         CvDTreeNode* node;\r
3988 \r
3989         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));\r
3990         if( !parent->left )\r
3991             parent->left = node;\r
3992         else\r
3993             parent->right = node;\r
3994         if( node->split )\r
3995             parent = node;\r
3996         else\r
3997         {\r
3998             while( parent && parent->right )\r
3999                 parent = parent->parent;\r
4000         }\r
4001 \r
4002         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
4003     }\r
4004 \r
4005     root = _root.left;\r
4006 \r
4007     __END__;\r
4008 }\r
4009 \r
4010 \r
4011 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )\r
4012 {\r
4013     CvDTreeTrainData* _data = new CvDTreeTrainData();\r
4014     _data->read_params( fs, fnode );\r
4015 \r
4016     read( fs, fnode, _data );\r
4017     get_var_importance();\r
4018 }\r
4019 \r
4020 \r
4021 // a special entry point for reading weak decision trees from the tree ensembles\r
4022 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )\r
4023 {\r
4024     CV_FUNCNAME( "CvDTree::read" );\r
4025 \r
4026     __BEGIN__;\r
4027 \r
4028     CvFileNode* tree_nodes;\r
4029 \r
4030     clear();\r
4031     data = _data;\r
4032 \r
4033     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );\r
4034     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )\r
4035         CV_ERROR( CV_StsParseError, "nodes tag is missing" );\r
4036 \r
4037     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );\r
4038     read_tree_nodes( fs, tree_nodes );\r
4039 \r
4040     __END__;\r
4041 }\r
4042 \r
4043 /* End of file. */\r