]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
fixed some warnings
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 //\r
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 //\r
5 //  By downloading, copying, installing or using the software you agree to this license.\r
6 //  If you do not agree to this license, do not download, install,\r
7 //  copy or use the software.\r
8 //\r
9 //\r
10 //                        Intel License Agreement\r
11 //\r
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.\r
13 // Third party copyrights are property of their respective owners.\r
14 //\r
15 // Redistribution and use in source and binary forms, with or without modification,\r
16 // are permitted provided that the following conditions are met:\r
17 //\r
18 //   * Redistribution's of source code must retain the above copyright notice,\r
19 //     this list of conditions and the following disclaimer.\r
20 //\r
21 //   * Redistribution's in binary form must reproduce the above copyright notice,\r
22 //     this list of conditions and the following disclaimer in the documentation\r
23 //     and/or other materials provided with the distribution.\r
24 //\r
25 //   * The name of Intel Corporation may not be used to endorse or promote products\r
26 //     derived from this software without specific prior written permission.\r
27 //\r
28 // This software is provided by the copyright holders and contributors "as is" and\r
29 // any express or implied warranties, including, but not limited to, the implied\r
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.\r
31 // In no event shall the Intel Corporation or contributors be liable for any direct,\r
32 // indirect, incidental, special, exemplary, or consequential damages\r
33 // (including, but not limited to, procurement of substitute goods or services;\r
34 // loss of use, data, or profits; or business interruption) however caused\r
35 // and on any theory of liability, whether in contract, strict liability,\r
36 // or tort (including negligence or otherwise) arising in any way out of\r
37 // the use of this software, even if advised of the possibility of such damage.\r
38 //\r
39 //M*/\r
40 \r
41 #include "_ml.h"\r
42 #include <ctype.h>\r
43 \r
44 static const float ord_nan = FLT_MAX*0.5f;\r
45 static const int min_block_size = 1 << 16;\r
46 static const int block_size_delta = 1 << 10;\r
47 \r
48 CvDTreeTrainData::CvDTreeTrainData()\r
49 {\r
50     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
51         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
52     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
53     pred_float_buf = resp_float_buf = 0;\r
54     tree_storage = temp_storage = 0;\r
55 \r
56     clear();\r
57 }\r
58 \r
59 \r
60 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,\r
61                       const CvMat* _responses, const CvMat* _var_idx,\r
62                       const CvMat* _sample_idx, const CvMat* _var_type,\r
63                       const CvMat* _missing_mask, const CvDTreeParams& _params,\r
64                       bool _shared, bool _add_labels )\r
65 {\r
66     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
67         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
68 \r
69     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
70     pred_float_buf = resp_float_buf = 0;\r
71 \r
72     tree_storage = temp_storage = 0;\r
73 \r
74     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,\r
75               _var_type, _missing_mask, _params, _shared, _add_labels );\r
76 }\r
77 \r
78 \r
79 CvDTreeTrainData::~CvDTreeTrainData()\r
80 {\r
81     clear();\r
82 }\r
83 \r
84 \r
85 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )\r
86 {\r
87     bool ok = false;\r
88 \r
89     CV_FUNCNAME( "CvDTreeTrainData::set_params" );\r
90 \r
91     __BEGIN__;\r
92 \r
93     // set parameters\r
94     params = _params;\r
95 \r
96     if( params.max_categories < 2 )\r
97         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );\r
98     params.max_categories = MIN( params.max_categories, 15 );\r
99 \r
100     if( params.max_depth < 0 )\r
101         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );\r
102     params.max_depth = MIN( params.max_depth, 25 );\r
103 \r
104     params.min_sample_count = MAX(params.min_sample_count,1);\r
105 \r
106     if( params.cv_folds < 0 )\r
107         CV_ERROR( CV_StsOutOfRange,\r
108         "params.cv_folds should be =0 (the tree is not pruned) "\r
109         "or n>0 (tree is pruned using n-fold cross-validation)" );\r
110 \r
111     if( params.cv_folds == 1 )\r
112         params.cv_folds = 0;\r
113 \r
114     if( params.regression_accuracy < 0 )\r
115         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );\r
116 \r
117     ok = true;\r
118 \r
119     __END__;\r
120 \r
121     return ok;\r
122 }\r
123 \r
124 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))\r
125 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )\r
126 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )\r
127 \r
128 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])\r
129 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )\r
130 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )\r
131 \r
132 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))\r
133 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )\r
134 \r
135 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
136     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
137     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
138     bool _shared, bool _add_labels, bool _update_data )\r
139 {\r
140     CvMat* sample_indices = 0;\r
141     CvMat* var_type0 = 0;\r
142     CvMat* tmp_map = 0;\r
143     int** int_ptr = 0;\r
144     CvPair16u32s* pair16u32s_ptr = 0;\r
145     CvDTreeTrainData* data = 0;\r
146     float *_fdst = 0;\r
147     int *_idst = 0;\r
148     unsigned short* udst = 0;\r
149     int* idst = 0;\r
150 \r
151     CV_FUNCNAME( "CvDTreeTrainData::set_data" );\r
152 \r
153     __BEGIN__;\r
154 \r
155     int sample_all = 0, r_type = 0, cv_n;\r
156     int total_c_count = 0;\r
157     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
158     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step\r
159     int vi, i, size;\r
160     char err[100];\r
161     const int *sidx = 0, *vidx = 0;\r
162     \r
163     \r
164     if( _update_data && data_root )\r
165     {\r
166         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,\r
167             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );\r
168 \r
169         // compare new and old train data\r
170         if( !(data->var_count == var_count &&\r
171             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
172             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&\r
173             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )\r
174             CV_ERROR( CV_StsBadArg,\r
175             "The new training data must have the same types and the input and output variables "\r
176             "and the same categories for categorical variables" );\r
177 \r
178         cvReleaseMat( &priors );\r
179         cvReleaseMat( &priors_mult );\r
180         cvReleaseMat( &buf );\r
181         cvReleaseMat( &direction );\r
182         cvReleaseMat( &split_buf );\r
183         cvReleaseMemStorage( &temp_storage );\r
184 \r
185         priors = data->priors; data->priors = 0;\r
186         priors_mult = data->priors_mult; data->priors_mult = 0;\r
187         buf = data->buf; data->buf = 0;\r
188         buf_count = data->buf_count; buf_size = data->buf_size;\r
189         sample_count = data->sample_count;\r
190 \r
191         direction = data->direction; data->direction = 0;\r
192         split_buf = data->split_buf; data->split_buf = 0;\r
193         temp_storage = data->temp_storage; data->temp_storage = 0;\r
194         nv_heap = data->nv_heap; cv_heap = data->cv_heap;\r
195 \r
196         data_root = new_node( 0, sample_count, 0, 0 );\r
197         EXIT;\r
198     }\r
199 \r
200     clear();\r
201 \r
202     var_all = 0;\r
203     rng = cvRNG(-1);\r
204 \r
205     CV_CALL( set_params( _params ));\r
206 \r
207     // check parameter types and sizes\r
208     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));\r
209 \r
210     train_data = _train_data;\r
211     responses = _responses;\r
212 \r
213     if( _tflag == CV_ROW_SAMPLE )\r
214     {\r
215         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
216         dv_step = 1;\r
217         if( _missing_mask )\r
218             ms_step = _missing_mask->step, mv_step = 1;\r
219     }\r
220     else\r
221     {\r
222         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
223         ds_step = 1;\r
224         if( _missing_mask )\r
225             mv_step = _missing_mask->step, ms_step = 1;\r
226     }\r
227     tflag = _tflag;\r
228 \r
229     sample_count = sample_all;\r
230     var_count = var_all;\r
231     \r
232     if( _sample_idx )\r
233     {\r
234         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));\r
235         sidx = sample_indices->data.i;\r
236         sample_count = sample_indices->rows + sample_indices->cols - 1;\r
237     }\r
238 \r
239     if( _var_idx )\r
240     {\r
241         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));\r
242         vidx = var_idx->data.i;\r
243         var_count = var_idx->rows + var_idx->cols - 1;\r
244     }\r
245 \r
246     is_buf_16u = false;     \r
247     if ( sample_count < 65536 ) \r
248         is_buf_16u = true;                                \r
249     \r
250     if( !CV_IS_MAT(_responses) ||\r
251         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
252          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
253         (_responses->rows != 1 && _responses->cols != 1) ||\r
254         _responses->rows + _responses->cols - 1 != sample_all )\r
255         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
256                   "floating-point vector containing as many elements as "\r
257                   "the total number of samples in the training data matrix" );\r
258    \r
259   \r
260     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));\r
261 \r
262     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));\r
263    \r
264     \r
265     cat_var_count = 0;\r
266     ord_var_count = -1;\r
267 \r
268     is_classifier = r_type == CV_VAR_CATEGORICAL;\r
269 \r
270     // step 0. calc the number of categorical vars\r
271     for( vi = 0; vi < var_count; vi++ )\r
272     {\r
273         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?\r
274             cat_var_count++ : ord_var_count--;\r
275     }\r
276 \r
277     ord_var_count = ~ord_var_count;\r
278     cv_n = params.cv_folds;\r
279     // set the two last elements of var_type array to be able\r
280     // to locate responses and cross-validation labels using\r
281     // the corresponding get_* functions.\r
282     var_type->data.i[var_count] = cat_var_count;\r
283     var_type->data.i[var_count+1] = cat_var_count+1;\r
284 \r
285     // in case of single ordered predictor we need dummy cv_labels\r
286     // for safe split_node_data() operation\r
287     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;\r
288 \r
289     work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);\r
290     buf_size = (work_var_count + 1)*sample_count;\r
291     shared = _shared;\r
292     buf_count = shared ? 2 : 1;\r
293     \r
294     if ( is_buf_16u )\r
295     {\r
296         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));\r
297         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));\r
298     }\r
299     else\r
300     {\r
301         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));\r
302         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));\r
303     }    \r
304 \r
305     size = is_classifier ? (cat_var_count+1) : cat_var_count;\r
306     size = !size ? 1 : size;\r
307     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));\r
308     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));\r
309         \r
310     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;\r
311     size = !size ? 1 : size;\r
312     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));\r
313 \r
314     // now calculate the maximum size of split,\r
315     // create memory storage that will keep nodes and splits of the decision tree\r
316     // allocate root node and the buffer for the whole training data\r
317     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
318         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));\r
319     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
320     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
321     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
322     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));\r
323 \r
324     nv_size = var_count*sizeof(int);\r
325     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));\r
326 \r
327     temp_block_size = nv_size;\r
328 \r
329     if( cv_n )\r
330     {\r
331         if( sample_count < cv_n*MAX(params.min_sample_count,10) )\r
332             CV_ERROR( CV_StsOutOfRange,\r
333                 "The many folds in cross-validation for such a small dataset" );\r
334 \r
335         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );\r
336         temp_block_size = MAX(temp_block_size, cv_size);\r
337     }\r
338 \r
339     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );\r
340     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));\r
341     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));\r
342     if( cv_size )\r
343         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));\r
344 \r
345     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));\r
346 \r
347     max_c_count = 1;\r
348 \r
349     _fdst = 0;\r
350     _idst = 0;\r
351     if (ord_var_count)\r
352         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));\r
353     if (is_buf_16u && (cat_var_count || is_classifier))\r
354         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));\r
355 \r
356     // transform the training data to convenient representation\r
357     for( vi = 0; vi <= var_count; vi++ )\r
358     {\r
359         int ci;\r
360         const uchar* mask = 0;\r
361         int m_step = 0, step;\r
362         const int* idata = 0;\r
363         const float* fdata = 0;\r
364         int num_valid = 0;\r
365 \r
366         if( vi < var_count ) // analyze i-th input variable\r
367         {\r
368             int vi0 = vidx ? vidx[vi] : vi;\r
369             ci = get_var_type(vi);\r
370             step = ds_step; m_step = ms_step;\r
371             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
372                 idata = _train_data->data.i + vi0*dv_step;\r
373             else\r
374                 fdata = _train_data->data.fl + vi0*dv_step;\r
375             if( _missing_mask )\r
376                 mask = _missing_mask->data.ptr + vi0*mv_step;\r
377         }\r
378         else // analyze _responses\r
379         {\r
380             ci = cat_var_count;\r
381             step = CV_IS_MAT_CONT(_responses->type) ?\r
382                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);\r
383             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )\r
384                 idata = _responses->data.i;\r
385             else\r
386                 fdata = _responses->data.fl;\r
387         }\r
388 \r
389         if( (vi < var_count && ci>=0) ||\r
390             (vi == var_count && is_classifier) ) // process categorical variable or response\r
391         {\r
392             int c_count, prev_label;\r
393             int* c_map;\r
394             \r
395             if (is_buf_16u)\r
396                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
397             else\r
398                 idst = buf->data.i + vi*sample_count;\r
399             \r
400             // copy data\r
401             for( i = 0; i < sample_count; i++ )\r
402             {\r
403                 int val = INT_MAX, si = sidx ? sidx[i] : i;\r
404                 if( !mask || !mask[si*m_step] )\r
405                 {\r
406                     if( idata )\r
407                         val = idata[si*step];\r
408                     else\r
409                     {\r
410                         float t = fdata[si*step];\r
411                         val = cvRound(t);\r
412                         if( val != t )\r
413                         {\r
414                             sprintf( err, "%d-th value of %d-th (categorical) "\r
415                                 "variable is not an integer", i, vi );\r
416                             CV_ERROR( CV_StsBadArg, err );\r
417                         }\r
418                     }\r
419 \r
420                     if( val == INT_MAX )\r
421                     {\r
422                         sprintf( err, "%d-th value of %d-th (categorical) "\r
423                             "variable is too large", i, vi );\r
424                         CV_ERROR( CV_StsBadArg, err );\r
425                     }\r
426                     num_valid++;\r
427                 }\r
428                 if (is_buf_16u)\r
429                 {\r
430                     _idst[i] = val;\r
431                     pair16u32s_ptr[i].u = udst + i;\r
432                     pair16u32s_ptr[i].i = _idst + i;\r
433                 }   \r
434                 else\r
435                 {\r
436                     idst[i] = val;\r
437                     int_ptr[i] = idst + i;\r
438                 }\r
439             }\r
440 \r
441             c_count = num_valid > 0;\r
442 \r
443             if (is_buf_16u)\r
444             {\r
445                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );\r
446                 // count the categories\r
447                 for( i = 1; i < num_valid; i++ )\r
448                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)\r
449                         c_count ++ ;\r
450             }\r
451             else\r
452             {\r
453                 icvSortIntPtr( int_ptr, sample_count, 0 );\r
454                 // count the categories\r
455                 for( i = 1; i < num_valid; i++ )\r
456                     c_count += *int_ptr[i] != *int_ptr[i-1];\r
457             }\r
458 \r
459             if( vi > 0 )\r
460                 max_c_count = MAX( max_c_count, c_count );\r
461             cat_count->data.i[ci] = c_count;\r
462             cat_ofs->data.i[ci] = total_c_count;\r
463 \r
464             // resize cat_map, if need\r
465             if( cat_map->cols < total_c_count + c_count )\r
466             {\r
467                 tmp_map = cat_map;\r
468                 CV_CALL( cat_map = cvCreateMat( 1,\r
469                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));\r
470                 for( i = 0; i < total_c_count; i++ )\r
471                     cat_map->data.i[i] = tmp_map->data.i[i];\r
472                 cvReleaseMat( &tmp_map );\r
473             }\r
474 \r
475             c_map = cat_map->data.i + total_c_count;\r
476             total_c_count += c_count;\r
477 \r
478             c_count = -1;\r
479             if (is_buf_16u)\r
480             {\r
481                 // compact the class indices and build the map\r
482                 prev_label = ~*pair16u32s_ptr[0].i;\r
483                 for( i = 0; i < num_valid; i++ )\r
484                 {\r
485                     int cur_label = *pair16u32s_ptr[i].i;\r
486                     if( cur_label != prev_label )\r
487                         c_map[++c_count] = prev_label = cur_label;\r
488                     *pair16u32s_ptr[i].u = (unsigned short)c_count;\r
489                 }\r
490                 // replace labels for missing values with -1\r
491                 for( ; i < sample_count; i++ )\r
492                     *pair16u32s_ptr[i].u = 65535;\r
493             }\r
494             else\r
495             {\r
496                 // compact the class indices and build the map\r
497                 prev_label = ~*int_ptr[0];\r
498                 for( i = 0; i < num_valid; i++ )\r
499                 {\r
500                     int cur_label = *int_ptr[i];\r
501                     if( cur_label != prev_label )\r
502                         c_map[++c_count] = prev_label = cur_label;\r
503                     *int_ptr[i] = c_count;\r
504                 }\r
505                 // replace labels for missing values with -1\r
506                 for( ; i < sample_count; i++ )\r
507                     *int_ptr[i] = -1;\r
508             }           \r
509         }\r
510         else if( ci < 0 ) // process ordered variable\r
511         {\r
512             if (is_buf_16u)\r
513                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
514             else\r
515                 idst = buf->data.i + vi*sample_count;\r
516 \r
517             for( i = 0; i < sample_count; i++ )\r
518             {\r
519                 float val = ord_nan;\r
520                 int si = sidx ? sidx[i] : i;\r
521                 if( !mask || !mask[si*m_step] )\r
522                 {\r
523                     if( idata )\r
524                         val = (float)idata[si*step];\r
525                     else\r
526                         val = fdata[si*step];\r
527 \r
528                     if( fabs(val) >= ord_nan )\r
529                     {\r
530                         sprintf( err, "%d-th value of %d-th (ordered) "\r
531                             "variable (=%g) is too large", i, vi, val );\r
532                         CV_ERROR( CV_StsBadArg, err );\r
533                     }\r
534                 }\r
535                 num_valid++;\r
536                 if (is_buf_16u)\r
537                     udst[i] = (unsigned short)i;\r
538                 else\r
539                     idst[i] = i; // Ã¯Ã¥Ã°Ã¥Ã­Ã¥Ã±Ã²Ã¨ Ã¢Ã»Ã¸Ã¥ Ã¢ if( idata )\r
540                 _fdst[i] = val;\r
541                 \r
542             }\r
543             if (is_buf_16u)\r
544                 icvSortUShAux( udst, num_valid, _fdst);\r
545             else\r
546                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );\r
547         }\r
548        \r
549         if( vi < var_count )\r
550             data_root->set_num_valid(vi, num_valid);\r
551     }\r
552 \r
553     // set sample labels\r
554     if (is_buf_16u)\r
555         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);\r
556     else\r
557         idst = buf->data.i + work_var_count*sample_count;\r
558 \r
559     for (i = 0; i < sample_count; i++)\r
560     {\r
561         if (udst)\r
562             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;\r
563         else\r
564             idst[i] = sidx ? sidx[i] : i;\r
565     }\r
566 \r
567     if( cv_n )\r
568     {\r
569         unsigned short* udst = 0;\r
570         int* idst = 0;\r
571         CvRNG* r = &rng;\r
572 \r
573         if (is_buf_16u)\r
574         {\r
575             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);\r
576             for( i = vi = 0; i < sample_count; i++ )\r
577             {\r
578                 udst[i] = (unsigned short)vi++;\r
579                 vi &= vi < cv_n ? -1 : 0;\r
580             }\r
581 \r
582             for( i = 0; i < sample_count; i++ )\r
583             {\r
584                 int a = cvRandInt(r) % sample_count;\r
585                 int b = cvRandInt(r) % sample_count;\r
586                 unsigned short unsh = (unsigned short)vi;\r
587                 CV_SWAP( udst[a], udst[b], unsh );\r
588             }\r
589         }\r
590         else\r
591         {\r
592             idst = buf->data.i + (get_work_var_count()-1)*sample_count;\r
593             for( i = vi = 0; i < sample_count; i++ )\r
594             {\r
595                 idst[i] = vi++;\r
596                 vi &= vi < cv_n ? -1 : 0;\r
597             }\r
598 \r
599             for( i = 0; i < sample_count; i++ )\r
600             {\r
601                 int a = cvRandInt(r) % sample_count;\r
602                 int b = cvRandInt(r) % sample_count;\r
603                 CV_SWAP( idst[a], idst[b], vi );\r
604             }\r
605         }\r
606     }\r
607 \r
608     if ( cat_map ) \r
609         cat_map->cols = MAX( total_c_count, 1 );\r
610 \r
611     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
612         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
613     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));\r
614 \r
615     have_priors = is_classifier && params.priors;\r
616     if( is_classifier )\r
617     {\r
618         int m = get_num_classes();\r
619         double sum = 0;\r
620         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));\r
621         for( i = 0; i < m; i++ )\r
622         {\r
623             double val = have_priors ? params.priors[i] : 1.;\r
624             if( val <= 0 )\r
625                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );\r
626             priors->data.db[i] = val;\r
627             sum += val;\r
628         }\r
629 \r
630         // normalize weights\r
631         if( have_priors )\r
632             cvScale( priors, priors, 1./sum );\r
633 \r
634         CV_CALL( priors_mult = cvCloneMat( priors ));\r
635         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));\r
636     }\r
637 \r
638 \r
639     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));\r
640     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));\r
641 \r
642     CV_CALL( pred_float_buf = (float*)cvAlloc(sample_count*sizeof(pred_float_buf[0])) );\r
643     CV_CALL( pred_int_buf = (int*)cvAlloc(sample_count*sizeof(pred_int_buf[0])) );\r
644     CV_CALL( resp_float_buf = (float*)cvAlloc(sample_count*sizeof(resp_float_buf[0])) );\r
645     CV_CALL( resp_int_buf = (int*)cvAlloc(sample_count*sizeof(resp_int_buf[0])) );\r
646     CV_CALL( cv_lables_buf = (int*)cvAlloc(sample_count*sizeof(cv_lables_buf[0])) );\r
647     CV_CALL( sample_idx_buf = (int*)cvAlloc(sample_count*sizeof(sample_idx_buf[0])) );\r
648 \r
649     __END__;\r
650 \r
651     if( data )\r
652         delete data;\r
653 \r
654     if (_fdst)\r
655         cvFree( &_fdst );\r
656     if (_idst)\r
657         cvFree( &_idst );\r
658     cvFree( &int_ptr );\r
659     cvReleaseMat( &var_type0 );\r
660     cvReleaseMat( &sample_indices );\r
661     cvReleaseMat( &tmp_map );\r
662 }\r
663 \r
664 \r
665 \r
666 void CvDTreeTrainData::do_responses_copy()\r
667 {\r
668     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );\r
669     cvCopy( responses, responses_copy);\r
670     responses = responses_copy;\r
671 }\r
672 \r
673 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )\r
674 {\r
675     CvDTreeNode* root = 0;\r
676     CvMat* isubsample_idx = 0;\r
677     CvMat* subsample_co = 0;\r
678 \r
679     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );\r
680 \r
681     __BEGIN__;\r
682 \r
683     if( !data_root )\r
684         CV_ERROR( CV_StsError, "No training data has been set" );\r
685 \r
686     if( _subsample_idx )\r
687         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
688 \r
689     if( !isubsample_idx )\r
690     {\r
691         // make a copy of the root node\r
692         CvDTreeNode temp;\r
693         int i;\r
694         root = new_node( 0, 1, 0, 0 );\r
695         temp = *root;\r
696         *root = *data_root;\r
697         root->num_valid = temp.num_valid;\r
698         if( root->num_valid )\r
699         {\r
700             for( i = 0; i < var_count; i++ )\r
701                 root->num_valid[i] = data_root->num_valid[i];\r
702         }\r
703         root->cv_Tn = temp.cv_Tn;\r
704         root->cv_node_risk = temp.cv_node_risk;\r
705         root->cv_node_error = temp.cv_node_error;\r
706     }\r
707     else\r
708     {\r
709         int* sidx = isubsample_idx->data.i;\r
710         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)\r
711         int* co, cur_ofs = 0;\r
712         int vi, i;\r
713         int work_var_count = get_work_var_count();\r
714         int count = isubsample_idx->rows + isubsample_idx->cols - 1;\r
715 \r
716         root = new_node( 0, count, 1, 0 );\r
717 \r
718         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
719         cvZero( subsample_co );\r
720         co = subsample_co->data.i;\r
721         for( i = 0; i < count; i++ )\r
722             co[sidx[i]*2]++;\r
723         for( i = 0; i < sample_count; i++ )\r
724         {\r
725             if( co[i*2] )\r
726             {\r
727                 co[i*2+1] = cur_ofs;\r
728                 cur_ofs += co[i*2];\r
729             }\r
730             else\r
731                 co[i*2+1] = -1;\r
732         }\r
733 \r
734         for( vi = 0; vi < work_var_count; vi++ )\r
735         {\r
736             int ci = get_var_type(vi);\r
737 \r
738             if( ci >= 0 || vi >= var_count )\r
739             {\r
740                 int* src_buf = pred_int_buf;\r
741                 const int* src = 0;\r
742                 int num_valid = 0;\r
743                 \r
744                 get_cat_var_data( data_root, vi, src_buf, &src );\r
745 \r
746                 if (is_buf_16u)\r
747                 {\r
748                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
749                         vi*sample_count + root->offset);\r
750                     for( i = 0; i < count; i++ )\r
751                     {\r
752                         int val = src[sidx[i]];\r
753                         udst[i] = (unsigned short)val;\r
754                         num_valid += val >= 0;\r
755                     }\r
756                 }\r
757                 else\r
758                 {\r
759                     int* idst = buf->data.i + root->buf_idx*buf->cols + \r
760                         vi*sample_count + root->offset;\r
761                     for( i = 0; i < count; i++ )\r
762                     {\r
763                         int val = src[sidx[i]];\r
764                         idst[i] = val;\r
765                         num_valid += val >= 0;\r
766                     }\r
767                 }\r
768 \r
769                 if( vi < var_count )\r
770                     root->set_num_valid(vi, num_valid);\r
771             }\r
772             else\r
773             {\r
774                 int *src_idx_buf = pred_int_buf;\r
775                 const int* src_idx = 0;\r
776                 float *src_val_buf = pred_float_buf;\r
777                 const float* src_val = 0;\r
778                 int j = 0, idx, count_i;\r
779                 int num_valid = data_root->get_num_valid(vi);\r
780 \r
781                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );\r
782                 if (is_buf_16u)\r
783                 {\r
784                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
785                         vi*sample_count + data_root->offset);\r
786                     for( i = 0; i < num_valid; i++ )\r
787                     {\r
788                         idx = src_idx[i];\r
789                         count_i = co[idx*2];\r
790                         if( count_i )\r
791                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
792                                 udst_idx[j] = (unsigned short)cur_ofs;\r
793                     }\r
794 \r
795                     root->set_num_valid(vi, j);\r
796 \r
797                     for( ; i < sample_count; i++ )\r
798                     {\r
799                         idx = src_idx[i];\r
800                         count_i = co[idx*2];\r
801                         if( count_i )\r
802                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
803                                 udst_idx[j] = (unsigned short)cur_ofs;\r
804                     }\r
805                 }\r
806                 else\r
807                 {\r
808                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + \r
809                         vi*sample_count + root->offset;\r
810                     for( i = 0; i < num_valid; i++ )\r
811                     {\r
812                         idx = src_idx[i];\r
813                         count_i = co[idx*2];\r
814                         if( count_i )\r
815                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
816                                 idst_idx[j] = cur_ofs;\r
817                     }\r
818 \r
819                     root->set_num_valid(vi, j);\r
820 \r
821                     for( ; i < sample_count; i++ )\r
822                     {\r
823                         idx = src_idx[i];\r
824                         count_i = co[idx*2];\r
825                         if( count_i )\r
826                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
827                                 idst_idx[j] = cur_ofs;\r
828                     }\r
829                 }\r
830             }\r
831         }\r
832         // sample indices subsampling\r
833         int* sample_idx_src_buf = sample_idx_buf;\r
834         const int* sample_idx_src = 0;\r
835         get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);\r
836         if (is_buf_16u)\r
837         {\r
838             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
839                 get_work_var_count()*sample_count + root->offset);            \r
840             for (i = 0; i < count; i++)\r
841                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];\r
842         }\r
843         else\r
844         {\r
845             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + \r
846                 get_work_var_count()*sample_count + root->offset;            \r
847             for (i = 0; i < count; i++)\r
848                 sample_idx_dst[i] = sample_idx_src[sidx[i]];\r
849         }\r
850     }\r
851 \r
852     __END__;\r
853 \r
854     cvReleaseMat( &isubsample_idx );\r
855     cvReleaseMat( &subsample_co );\r
856 \r
857     return root;\r
858 }\r
859 \r
860 \r
861 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,\r
862                                     float* values, uchar* missing,\r
863                                     float* responses, bool get_class_idx )\r
864 {\r
865     CvMat* subsample_idx = 0;\r
866     CvMat* subsample_co = 0;\r
867 \r
868     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );\r
869 \r
870     __BEGIN__;\r
871 \r
872     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
873     int* sidx = 0;\r
874     int* co = 0;\r
875 \r
876     if( _subsample_idx )\r
877     {\r
878         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
879         sidx = subsample_idx->data.i;\r
880         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
881         co = subsample_co->data.i;\r
882         cvZero( subsample_co );\r
883         count = subsample_idx->cols + subsample_idx->rows - 1;\r
884         for( i = 0; i < count; i++ )\r
885             co[sidx[i]*2]++;\r
886         for( i = 0; i < total; i++ )\r
887         {\r
888             int count_i = co[i*2];\r
889             if( count_i )\r
890             {\r
891                 co[i*2+1] = cur_ofs*var_count;\r
892                 cur_ofs += count_i;\r
893             }\r
894         }\r
895     }\r
896 \r
897     if( missing )\r
898         memset( missing, 1, count*var_count );\r
899 \r
900     for( vi = 0; vi < var_count; vi++ )\r
901     {\r
902         int ci = get_var_type(vi);\r
903         if( ci >= 0 ) // categorical\r
904         {\r
905             float* dst = values + vi;\r
906             uchar* m = missing ? missing + vi : 0;\r
907             int* src_buf = pred_int_buf;\r
908             const int* src = 0; \r
909             get_cat_var_data(data_root, vi, src_buf, &src);\r
910 \r
911             for( i = 0; i < count; i++, dst += var_count )\r
912             {\r
913                 int idx = sidx ? sidx[i] : i;\r
914                 int val = src[idx];\r
915                 *dst = (float)val;\r
916                 if( m )\r
917                 {\r
918                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));\r
919                     m += var_count;\r
920                 }\r
921             }\r
922         }\r
923         else // ordered\r
924         {\r
925             float* dst = values + vi;\r
926             uchar* m = missing ? missing + vi : 0;\r
927             int count1 = data_root->get_num_valid(vi);\r
928             float *src_val_buf = pred_float_buf;\r
929             const float *src_val = 0;\r
930             int* src_idx_buf = pred_int_buf;\r
931             const int* src_idx = 0;\r
932             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
933 \r
934             for( i = 0; i < count1; i++ )\r
935             {\r
936                 int idx = src_idx[i];\r
937                 int count_i = 1;\r
938                 if( co )\r
939                 {\r
940                     count_i = co[idx*2];\r
941                     cur_ofs = co[idx*2+1];\r
942                 }\r
943                 else\r
944                     cur_ofs = idx*var_count;\r
945                 if( count_i )\r
946                 {\r
947                     float val = src_val[i];\r
948                     for( ; count_i > 0; count_i--, cur_ofs += var_count )\r
949                     {\r
950                         dst[cur_ofs] = val;\r
951                         if( m )\r
952                             m[cur_ofs] = 0;\r
953                     }\r
954                 }\r
955             }\r
956         }\r
957     }\r
958 \r
959     // copy responses\r
960     if( responses )\r
961     {\r
962         if( is_classifier )\r
963         {\r
964             int* src_buf = resp_int_buf;\r
965             const int* src = 0;\r
966             get_class_labels(data_root, src_buf, &src);\r
967             for( i = 0; i < count; i++ )\r
968             {\r
969                 int idx = sidx ? sidx[i] : i;\r
970                 int val = get_class_idx ? src[idx] :\r
971                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];\r
972                 responses[i] = (float)val;\r
973             }\r
974         }\r
975         else\r
976         {\r
977             float *_values_buf = resp_float_buf;\r
978             const float* _values = 0;\r
979             get_ord_responses(data_root, _values_buf, &_values);\r
980             for( i = 0; i < count; i++ )\r
981             {\r
982                 int idx = sidx ? sidx[i] : i;\r
983                 responses[i] = _values[idx];\r
984             }\r
985         }\r
986     }\r
987 \r
988     __END__;\r
989 \r
990     cvReleaseMat( &subsample_idx );\r
991     cvReleaseMat( &subsample_co );\r
992 }\r
993 \r
994 \r
995 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,\r
996                                          int storage_idx, int offset )\r
997 {\r
998     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );\r
999 \r
1000     node->sample_count = count;\r
1001     node->depth = parent ? parent->depth + 1 : 0;\r
1002     node->parent = parent;\r
1003     node->left = node->right = 0;\r
1004     node->split = 0;\r
1005     node->value = 0;\r
1006     node->class_idx = 0;\r
1007     node->maxlr = 0.;\r
1008 \r
1009     node->buf_idx = storage_idx;\r
1010     node->offset = offset;\r
1011     if( nv_heap )\r
1012         node->num_valid = (int*)cvSetNew( nv_heap );\r
1013     else\r
1014         node->num_valid = 0;\r
1015     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;\r
1016     node->complexity = 0;\r
1017 \r
1018     if( params.cv_folds > 0 && cv_heap )\r
1019     {\r
1020         int cv_n = params.cv_folds;\r
1021         node->Tn = INT_MAX;\r
1022         node->cv_Tn = (int*)cvSetNew( cv_heap );\r
1023         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));\r
1024         node->cv_node_error = node->cv_node_risk + cv_n;\r
1025     }\r
1026     else\r
1027     {\r
1028         node->Tn = 0;\r
1029         node->cv_Tn = 0;\r
1030         node->cv_node_risk = 0;\r
1031         node->cv_node_error = 0;\r
1032     }\r
1033 \r
1034     return node;\r
1035 }\r
1036 \r
1037 \r
1038 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,\r
1039                 int split_point, int inversed, float quality )\r
1040 {\r
1041     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1042     split->var_idx = vi;\r
1043     split->condensed_idx = INT_MIN;\r
1044     split->ord.c = cmp_val;\r
1045     split->ord.split_point = split_point;\r
1046     split->inversed = inversed;\r
1047     split->quality = quality;\r
1048     split->next = 0;\r
1049 \r
1050     return split;\r
1051 }\r
1052 \r
1053 \r
1054 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )\r
1055 {\r
1056     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1057     int i, n = (max_c_count + 31)/32;\r
1058 \r
1059     split->var_idx = vi;\r
1060     split->condensed_idx = INT_MIN;\r
1061     split->inversed = 0;\r
1062     split->quality = quality;\r
1063     for( i = 0; i < n; i++ )\r
1064         split->subset[i] = 0;\r
1065     split->next = 0;\r
1066 \r
1067     return split;\r
1068 }\r
1069 \r
1070 \r
1071 void CvDTreeTrainData::free_node( CvDTreeNode* node )\r
1072 {\r
1073     CvDTreeSplit* split = node->split;\r
1074     free_node_data( node );\r
1075     while( split )\r
1076     {\r
1077         CvDTreeSplit* next = split->next;\r
1078         cvSetRemoveByPtr( split_heap, split );\r
1079         split = next;\r
1080     }\r
1081     node->split = 0;\r
1082     cvSetRemoveByPtr( node_heap, node );\r
1083 }\r
1084 \r
1085 \r
1086 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )\r
1087 {\r
1088     if( node->num_valid )\r
1089     {\r
1090         cvSetRemoveByPtr( nv_heap, node->num_valid );\r
1091         node->num_valid = 0;\r
1092     }\r
1093     // do not free cv_* fields, as all the cross-validation related data is released at once.\r
1094 }\r
1095 \r
1096 \r
1097 void CvDTreeTrainData::free_train_data()\r
1098 {\r
1099     cvReleaseMat( &counts );\r
1100     cvReleaseMat( &buf );\r
1101     cvReleaseMat( &direction );\r
1102     cvReleaseMat( &split_buf );\r
1103     cvReleaseMemStorage( &temp_storage );\r
1104     cvReleaseMat( &responses_copy );\r
1105     cvFree( &pred_float_buf );\r
1106     cvFree( &pred_int_buf );\r
1107     cvFree( &resp_float_buf );\r
1108     cvFree( &resp_int_buf );\r
1109     cvFree( &cv_lables_buf );\r
1110     cvFree( &sample_idx_buf );\r
1111 \r
1112     cv_heap = nv_heap = 0;\r
1113 }\r
1114 \r
1115 \r
1116 void CvDTreeTrainData::clear()\r
1117 {\r
1118     free_train_data();\r
1119 \r
1120     cvReleaseMemStorage( &tree_storage );\r
1121 \r
1122     cvReleaseMat( &var_idx );\r
1123     cvReleaseMat( &var_type );\r
1124     cvReleaseMat( &cat_count );\r
1125     cvReleaseMat( &cat_ofs );\r
1126     cvReleaseMat( &cat_map );\r
1127     cvReleaseMat( &priors );\r
1128     cvReleaseMat( &priors_mult );\r
1129     \r
1130     node_heap = split_heap = 0;\r
1131 \r
1132     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;\r
1133     have_labels = have_priors = is_classifier = false;\r
1134 \r
1135     buf_count = buf_size = 0;\r
1136     shared = false;\r
1137     \r
1138     data_root = 0;\r
1139 \r
1140     rng = cvRNG(-1);\r
1141 }\r
1142 \r
1143 \r
1144 int CvDTreeTrainData::get_num_classes() const\r
1145 {\r
1146     return is_classifier ? cat_count->data.i[cat_var_count] : 0;\r
1147 }\r
1148 \r
1149 \r
1150 int CvDTreeTrainData::get_var_type(int vi) const\r
1151 {\r
1152     return var_type->data.i[vi];\r
1153 }\r
1154 \r
1155 int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )\r
1156 {\r
1157     int vidx = var_idx ? var_idx->data.i[vi] : vi;\r
1158     int node_sample_count = n->sample_count; \r
1159     int* sample_indices_buf = sample_idx_buf;\r
1160     const int* sample_indices = 0;\r
1161     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);\r
1162 \r
1163     get_sample_indices(n, sample_indices_buf, &sample_indices);\r
1164 \r
1165     if( !is_buf_16u )\r
1166         *indices = buf->data.i + n->buf_idx*buf->cols + \r
1167         vi*sample_count + n->offset;\r
1168     else {\r
1169         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1170             vi*sample_count + n->offset );\r
1171         for( int i = 0; i < node_sample_count; i++ )\r
1172             indices_buf[i] = short_indices[i];\r
1173         *indices = indices_buf;\r
1174     }\r
1175     \r
1176     if( tflag == CV_ROW_SAMPLE )\r
1177     {\r
1178         for( int i = 0; i < node_sample_count && \r
1179             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1180         {\r
1181             int idx = (*indices)[i];\r
1182             idx = sample_indices[idx];\r
1183             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);\r
1184         }\r
1185     }\r
1186     else\r
1187         for( int i = 0; i < node_sample_count && \r
1188             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1189         {\r
1190             int idx = (*indices)[i];\r
1191             idx = sample_indices[idx];\r
1192             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);\r
1193         }\r
1194     \r
1195     *ord_values = ord_values_buf;\r
1196     return 0; //TODO: return the number of non-missing values\r
1197 }\r
1198 \r
1199 \r
1200 void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1201 {\r
1202     if (is_classifier)\r
1203         get_cat_var_data( n, var_count, labels_buf, labels );\r
1204 }\r
1205 \r
1206 void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )\r
1207 {\r
1208     get_cat_var_data( n, get_work_var_count(), indices_buf, indices );\r
1209 }\r
1210 \r
1211 void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)\r
1212 {\r
1213     int sample_count = n->sample_count;\r
1214     int* indices_buf = sample_idx_buf;\r
1215     const int* indices = 0;\r
1216 \r
1217     int r_step = responses->step/CV_ELEM_SIZE(responses->type);\r
1218 \r
1219     get_sample_indices(n, indices_buf, &indices);\r
1220 \r
1221     \r
1222     for( int i = 0; i < sample_count && \r
1223         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )\r
1224     {\r
1225         int idx = indices[i];\r
1226         values_buf[i] = *(responses->data.fl + idx * r_step);\r
1227     }\r
1228     \r
1229     *values = values_buf;    \r
1230 }\r
1231 \r
1232 \r
1233 void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1234 {\r
1235     if (have_labels)\r
1236         get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );\r
1237 }\r
1238 \r
1239 \r
1240 int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )\r
1241 {\r
1242     if( !is_buf_16u )\r
1243         *cat_values = buf->data.i + n->buf_idx*buf->cols + \r
1244         vi*sample_count + n->offset;\r
1245     else {\r
1246         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1247             vi*sample_count + n->offset);\r
1248         for( int i = 0; i < n->sample_count; i++ )\r
1249             cat_values_buf[i] = short_values[i];\r
1250         *cat_values = cat_values_buf;\r
1251     }\r
1252 \r
1253     return 0; //TODO: return the number of non-missing values\r
1254 }\r
1255 \r
1256 \r
1257 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )\r
1258 {\r
1259     int idx = n->buf_idx + 1;\r
1260     if( idx >= buf_count )\r
1261         idx = shared ? 1 : 0;\r
1262     return idx;\r
1263 }\r
1264 \r
1265 \r
1266 void CvDTreeTrainData::write_params( CvFileStorage* fs )\r
1267 {\r
1268     CV_FUNCNAME( "CvDTreeTrainData::write_params" );\r
1269 \r
1270     __BEGIN__;\r
1271 \r
1272     int vi, vcount = var_count;\r
1273 \r
1274     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );\r
1275     cvWriteInt( fs, "var_all", var_all );\r
1276     cvWriteInt( fs, "var_count", var_count );\r
1277     cvWriteInt( fs, "ord_var_count", ord_var_count );\r
1278     cvWriteInt( fs, "cat_var_count", cat_var_count );\r
1279 \r
1280     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );\r
1281     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );\r
1282 \r
1283     if( is_classifier )\r
1284     {\r
1285         cvWriteInt( fs, "max_categories", params.max_categories );\r
1286     }\r
1287     else\r
1288     {\r
1289         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );\r
1290     }\r
1291 \r
1292     cvWriteInt( fs, "max_depth", params.max_depth );\r
1293     cvWriteInt( fs, "min_sample_count", params.min_sample_count );\r
1294     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );\r
1295 \r
1296     if( params.cv_folds > 1 )\r
1297     {\r
1298         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );\r
1299         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );\r
1300     }\r
1301 \r
1302     if( priors )\r
1303         cvWrite( fs, "priors", priors );\r
1304 \r
1305     cvEndWriteStruct( fs );\r
1306 \r
1307     if( var_idx )\r
1308         cvWrite( fs, "var_idx", var_idx );\r
1309 \r
1310     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );\r
1311 \r
1312     for( vi = 0; vi < vcount; vi++ )\r
1313         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );\r
1314 \r
1315     cvEndWriteStruct( fs );\r
1316 \r
1317     if( cat_count && (cat_var_count > 0 || is_classifier) )\r
1318     {\r
1319         CV_ASSERT( cat_count != 0 );\r
1320         cvWrite( fs, "cat_count", cat_count );\r
1321         cvWrite( fs, "cat_map", cat_map );\r
1322     }\r
1323 \r
1324     __END__;\r
1325 }\r
1326 \r
1327 \r
1328 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )\r
1329 {\r
1330     CV_FUNCNAME( "CvDTreeTrainData::read_params" );\r
1331 \r
1332     __BEGIN__;\r
1333 \r
1334     CvFileNode *tparams_node, *vartype_node;\r
1335     CvSeqReader reader;\r
1336     int vi, max_split_size, tree_block_size;\r
1337 \r
1338     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);\r
1339     var_all = cvReadIntByName( fs, node, "var_all" );\r
1340     var_count = cvReadIntByName( fs, node, "var_count", var_all );\r
1341     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );\r
1342     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );\r
1343 \r
1344     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );\r
1345 \r
1346     if( tparams_node ) // training parameters are not necessary\r
1347     {\r
1348         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;\r
1349 \r
1350         if( is_classifier )\r
1351         {\r
1352             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );\r
1353         }\r
1354         else\r
1355         {\r
1356             params.regression_accuracy =\r
1357                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );\r
1358         }\r
1359 \r
1360         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );\r
1361         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );\r
1362         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );\r
1363 \r
1364         if( params.cv_folds > 1 )\r
1365         {\r
1366             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;\r
1367             params.truncate_pruned_tree =\r
1368                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;\r
1369         }\r
1370 \r
1371         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );\r
1372         if( priors )\r
1373         {\r
1374             if( !CV_IS_MAT(priors) )\r
1375                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );\r
1376             priors_mult = cvCloneMat( priors );\r
1377         }\r
1378     }\r
1379 \r
1380     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));\r
1381     if( var_idx )\r
1382     {\r
1383         if( !CV_IS_MAT(var_idx) ||\r
1384             (var_idx->cols != 1 && var_idx->rows != 1) ||\r
1385             var_idx->cols + var_idx->rows - 1 != var_count ||\r
1386             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )\r
1387             CV_ERROR( CV_StsParseError,\r
1388                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );\r
1389 \r
1390         for( vi = 0; vi < var_count; vi++ )\r
1391             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )\r
1392                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );\r
1393     }\r
1394 \r
1395     ////// read var type\r
1396     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));\r
1397 \r
1398     cat_var_count = 0;\r
1399     ord_var_count = -1;\r
1400     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );\r
1401 \r
1402     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )\r
1403         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;\r
1404     else\r
1405     {\r
1406         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||\r
1407             vartype_node->data.seq->total != var_count )\r
1408             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1409 \r
1410         cvStartReadSeq( vartype_node->data.seq, &reader );\r
1411 \r
1412         for( vi = 0; vi < var_count; vi++ )\r
1413         {\r
1414             CvFileNode* n = (CvFileNode*)reader.ptr;\r
1415             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )\r
1416                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1417             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;\r
1418             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
1419         }\r
1420     }\r
1421     var_type->data.i[var_count] = cat_var_count;\r
1422 \r
1423     ord_var_count = ~ord_var_count;\r
1424     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )\r
1425         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );\r
1426     //////\r
1427 \r
1428     if( cat_var_count > 0 || is_classifier )\r
1429     {\r
1430         int ccount, total_c_count = 0;\r
1431         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));\r
1432         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));\r
1433 \r
1434         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||\r
1435             (cat_count->cols != 1 && cat_count->rows != 1) ||\r
1436             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||\r
1437             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||\r
1438             (cat_map->cols != 1 && cat_map->rows != 1) ||\r
1439             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )\r
1440             CV_ERROR( CV_StsParseError,\r
1441             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );\r
1442 \r
1443         ccount = cat_var_count + is_classifier;\r
1444 \r
1445         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));\r
1446         cat_ofs->data.i[0] = 0;\r
1447         max_c_count = 1;\r
1448 \r
1449         for( vi = 0; vi < ccount; vi++ )\r
1450         {\r
1451             int val = cat_count->data.i[vi];\r
1452             if( val <= 0 )\r
1453                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );\r
1454             max_c_count = MAX( max_c_count, val );\r
1455             cat_ofs->data.i[vi+1] = total_c_count += val;\r
1456         }\r
1457 \r
1458         if( cat_map->cols + cat_map->rows - 1 != total_c_count )\r
1459             CV_ERROR( CV_StsBadSize,\r
1460             "cat_map vector length is not equal to the total number of categories in all categorical vars" );\r
1461     }\r
1462 \r
1463     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
1464         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
1465 \r
1466     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
1467     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
1468     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
1469     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),\r
1470             sizeof(CvDTreeNode), tree_storage ));\r
1471     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),\r
1472             max_split_size, tree_storage ));\r
1473 \r
1474     __END__;\r
1475 }\r
1476 /////////////////////// Decision Tree /////////////////////////\r
1477 \r
1478 CvDTree::CvDTree()\r
1479 {\r
1480     data = 0;\r
1481     var_importance = 0;\r
1482     default_model_name = "my_tree";\r
1483 \r
1484     clear();\r
1485 }\r
1486 \r
1487 \r
1488 void CvDTree::clear()\r
1489 {\r
1490     cvReleaseMat( &var_importance );\r
1491     if( data )\r
1492     {\r
1493         if( !data->shared )\r
1494             delete data;\r
1495         else\r
1496             free_tree();\r
1497         data = 0;\r
1498     }\r
1499     root = 0;\r
1500     pruned_tree_idx = -1;\r
1501 }\r
1502 \r
1503 \r
1504 CvDTree::~CvDTree()\r
1505 {\r
1506     clear();\r
1507 }\r
1508 \r
1509 \r
1510 const CvDTreeNode* CvDTree::get_root() const\r
1511 {\r
1512     return root;\r
1513 }\r
1514 \r
1515 \r
1516 int CvDTree::get_pruned_tree_idx() const\r
1517 {\r
1518     return pruned_tree_idx;\r
1519 }\r
1520 \r
1521 \r
1522 CvDTreeTrainData* CvDTree::get_data()\r
1523 {\r
1524     return data;\r
1525 }\r
1526 \r
1527 \r
1528 bool CvDTree::train( const CvMat* _train_data, int _tflag,\r
1529                      const CvMat* _responses, const CvMat* _var_idx,\r
1530                      const CvMat* _sample_idx, const CvMat* _var_type,\r
1531                      const CvMat* _missing_mask, CvDTreeParams _params )\r
1532 {\r
1533     bool result = false;\r
1534 \r
1535     CV_FUNCNAME( "CvDTree::train" );\r
1536 \r
1537     __BEGIN__;\r
1538 \r
1539     clear();\r
1540     data = new CvDTreeTrainData( _train_data, _tflag, _responses,\r
1541                                  _var_idx, _sample_idx, _var_type,\r
1542                                  _missing_mask, _params, false );\r
1543     CV_CALL( result = do_train(0) );\r
1544 \r
1545     __END__;\r
1546 \r
1547     return result;\r
1548 }\r
1549 \r
1550 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )\r
1551 {\r
1552    bool result = false;\r
1553 \r
1554     CV_FUNCNAME( "CvDTree::train" );\r
1555 \r
1556     __BEGIN__;\r
1557 \r
1558     const CvMat* values = _data->get_values();
1559     const CvMat* response = _data->get_response();
1560     const CvMat* missing = _data->get_missing();
1561     const CvMat* var_types = _data->get_var_types();
1562     const CvMat* train_sidx = _data->get_train_sample_idx();
1563     const CvMat* var_idx = _data->get_var_idx();\r
1564 \r
1565     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,\r
1566         train_sidx, var_types, missing, _params ) );\r
1567 \r
1568     __END__;\r
1569 \r
1570     return result;\r
1571 }\r
1572 \r
1573 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )\r
1574 {\r
1575     bool result = false;\r
1576 \r
1577     CV_FUNCNAME( "CvDTree::train" );\r
1578 \r
1579     __BEGIN__;\r
1580 \r
1581     clear();\r
1582     data = _data;\r
1583     data->shared = true;\r
1584     CV_CALL( result = do_train(_subsample_idx));\r
1585 \r
1586     __END__;\r
1587 \r
1588     return result;\r
1589 }\r
1590 \r
1591 \r
1592 bool CvDTree::do_train( const CvMat* _subsample_idx )\r
1593 {\r
1594     bool result = false;\r
1595 \r
1596     CV_FUNCNAME( "CvDTree::do_train" );\r
1597 \r
1598     __BEGIN__;\r
1599 \r
1600     root = data->subsample_data( _subsample_idx );\r
1601 \r
1602     CV_CALL( try_split_node(root));\r
1603 \r
1604     if( data->params.cv_folds > 0 )\r
1605         CV_CALL( prune_cv());\r
1606 \r
1607     if( !data->shared )\r
1608         data->free_train_data();\r
1609 \r
1610     result = true;\r
1611 \r
1612     __END__;\r
1613 \r
1614     return result;\r
1615 }\r
1616 \r
1617 \r
1618 void CvDTree::try_split_node( CvDTreeNode* node )\r
1619 {\r
1620     CvDTreeSplit* best_split = 0;\r
1621     int i, n = node->sample_count, vi;\r
1622     bool can_split = true;\r
1623     double quality_scale;\r
1624 \r
1625     calc_node_value( node );\r
1626 \r
1627     if( node->sample_count <= data->params.min_sample_count ||\r
1628         node->depth >= data->params.max_depth )\r
1629         can_split = false;\r
1630 \r
1631     if( can_split && data->is_classifier )\r
1632     {\r
1633         // check if we have a "pure" node,\r
1634         // we assume that cls_count is filled by calc_node_value()\r
1635         int* cls_count = data->counts->data.i;\r
1636         int nz = 0, m = data->get_num_classes();\r
1637         for( i = 0; i < m; i++ )\r
1638             nz += cls_count[i] != 0;\r
1639         if( nz == 1 ) // there is only one class\r
1640             can_split = false;\r
1641     }\r
1642     else if( can_split )\r
1643     {\r
1644         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )\r
1645             can_split = false;\r
1646     }\r
1647 \r
1648     if( can_split )\r
1649     {\r
1650         best_split = find_best_split(node);\r
1651         // TODO: check the split quality ...\r
1652         node->split = best_split;\r
1653     }\r
1654 \r
1655     if( !can_split || !best_split )\r
1656     {\r
1657         data->free_node_data(node);\r
1658         return;\r
1659     }\r
1660 \r
1661     quality_scale = calc_node_dir( node );\r
1662 \r
1663     if( data->params.use_surrogates )\r
1664     {\r
1665         // find all the surrogate splits\r
1666         // and sort them by their similarity to the primary one\r
1667         for( vi = 0; vi < data->var_count; vi++ )\r
1668         {\r
1669             CvDTreeSplit* split;\r
1670             int ci = data->get_var_type(vi);\r
1671 \r
1672             if( vi == best_split->var_idx )\r
1673                 continue;\r
1674 \r
1675             if( ci >= 0 )\r
1676                 split = find_surrogate_split_cat( node, vi );\r
1677             else\r
1678                 split = find_surrogate_split_ord( node, vi );\r
1679 \r
1680             if( split )\r
1681             {\r
1682                 // insert the split\r
1683                 CvDTreeSplit* prev_split = node->split;\r
1684                 split->quality = (float)(split->quality*quality_scale);\r
1685 \r
1686                 while( prev_split->next &&\r
1687                        prev_split->next->quality > split->quality )\r
1688                     prev_split = prev_split->next;\r
1689                 split->next = prev_split->next;\r
1690                 prev_split->next = split;\r
1691             }\r
1692         }\r
1693     }\r
1694 \r
1695     split_node_data( node );\r
1696     try_split_node( node->left );\r
1697     try_split_node( node->right );\r
1698 }\r
1699 \r
1700 \r
1701 // calculate direction (left(-1),right(1),missing(0))\r
1702 // for each sample using the best split\r
1703 // the function returns scale coefficients for surrogate split quality factors.\r
1704 // the scale is applied to normalize surrogate split quality relatively to the\r
1705 // best (primary) split quality. That is, if a surrogate split is absolutely\r
1706 // identical to the primary split, its quality will be set to the maximum value =\r
1707 // quality of the primary split; otherwise, it will be lower.\r
1708 // besides, the function compute node->maxlr,\r
1709 // minimum possible quality (w/o considering the above mentioned scale)\r
1710 // for a surrogate split. Surrogate splits with quality less than node->maxlr\r
1711 // are not discarded.\r
1712 double CvDTree::calc_node_dir( CvDTreeNode* node )\r
1713 {\r
1714     char* dir = (char*)data->direction->data.ptr;\r
1715     int i, n = node->sample_count, vi = node->split->var_idx;\r
1716     double L, R;\r
1717 \r
1718     assert( !node->split->inversed );\r
1719 \r
1720     if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
1721     {\r
1722         int* labels_buf = data->pred_int_buf;\r
1723         const int* labels = 0;\r
1724         const int* subset = node->split->subset;\r
1725         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1726         if( !data->have_priors )\r
1727         {\r
1728             int sum = 0, sum_abs = 0;\r
1729 \r
1730             for( i = 0; i < n; i++ )\r
1731             {\r
1732                 int idx = labels[i];\r
1733                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?\r
1734                     CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1735                 sum += d; sum_abs += d & 1;\r
1736                 dir[i] = (char)d;\r
1737             }\r
1738 \r
1739             R = (sum_abs + sum) >> 1;\r
1740             L = (sum_abs - sum) >> 1;\r
1741         }\r
1742         else\r
1743         {\r
1744             const double* priors = data->priors_mult->data.db;\r
1745             double sum = 0, sum_abs = 0;\r
1746             int *responses_buf = data->resp_int_buf;\r
1747             const int* responses;\r
1748             data->get_class_labels(node, responses_buf, &responses);\r
1749 \r
1750             for( i = 0; i < n; i++ )\r
1751             {\r
1752                 int idx = labels[i];\r
1753                 double w = priors[responses[i]];\r
1754                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1755                 sum += d*w; sum_abs += (d & 1)*w;\r
1756                 dir[i] = (char)d;\r
1757             }\r
1758 \r
1759             R = (sum_abs + sum) * 0.5;\r
1760             L = (sum_abs - sum) * 0.5;\r
1761         }\r
1762     }\r
1763     else // split on ordered var\r
1764     {\r
1765         int split_point = node->split->ord.split_point;\r
1766         int n1 = node->get_num_valid(vi);\r
1767         \r
1768         float* val_buf = data->pred_float_buf;\r
1769         const float* val = 0;\r
1770         int* sorted_buf = data->pred_int_buf;\r
1771         const int* sorted = 0;\r
1772         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);\r
1773 \r
1774         assert( 0 <= split_point && split_point < n1-1 );\r
1775 \r
1776         if( !data->have_priors )\r
1777         {\r
1778             for( i = 0; i <= split_point; i++ )\r
1779                 dir[sorted[i]] = (char)-1;\r
1780             for( ; i < n1; i++ )\r
1781                 dir[sorted[i]] = (char)1;\r
1782             for( ; i < n; i++ )\r
1783                 dir[sorted[i]] = (char)0;\r
1784 \r
1785             L = split_point-1;\r
1786             R = n1 - split_point + 1;\r
1787         }\r
1788         else\r
1789         {\r
1790             const double* priors = data->priors_mult->data.db;\r
1791             int* responses_buf = data->resp_int_buf;\r
1792             const int* responses = 0;\r
1793             data->get_class_labels(node, responses_buf, &responses);\r
1794             L = R = 0;\r
1795 \r
1796             for( i = 0; i <= split_point; i++ )\r
1797             {\r
1798                 int idx = sorted[i];\r
1799                 double w = priors[responses[idx]];\r
1800                 dir[idx] = (char)-1;\r
1801                 L += w;\r
1802             }\r
1803 \r
1804             for( ; i < n1; i++ )\r
1805             {\r
1806                 int idx = sorted[i];\r
1807                 double w = priors[responses[idx]];\r
1808                 dir[idx] = (char)1;\r
1809                 R += w;\r
1810             }\r
1811 \r
1812             for( ; i < n; i++ )\r
1813                 dir[sorted[i]] = (char)0;\r
1814         }\r
1815     }\r
1816 \r
1817     node->maxlr = MAX( L, R );\r
1818     return node->split->quality/(L + R);\r
1819 }\r
1820 \r
1821 \r
1822 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )\r
1823 {\r
1824     int vi;\r
1825     CvDTreeSplit *best_split = 0, *split = 0, *t;\r
1826 \r
1827     for( vi = 0; vi < data->var_count; vi++ )\r
1828     {\r
1829         int ci = data->get_var_type(vi);\r
1830         if( node->get_num_valid(vi) <= 1 )\r
1831             continue;\r
1832 \r
1833         if( data->is_classifier )\r
1834         {\r
1835             if( ci >= 0 )\r
1836                 split = find_split_cat_class( node, vi );\r
1837             else\r
1838                 split = find_split_ord_class( node, vi );\r
1839         }\r
1840         else\r
1841         {\r
1842             if( ci >= 0 )\r
1843                 split = find_split_cat_reg( node, vi );\r
1844             else\r
1845                 split = find_split_ord_reg( node, vi );\r
1846         }\r
1847 \r
1848         if( split )\r
1849         {\r
1850             if( !best_split || best_split->quality < split->quality )\r
1851                 CV_SWAP( best_split, split, t );\r
1852             if( split )\r
1853                 cvSetRemoveByPtr( data->split_heap, split );\r
1854         }\r
1855     }\r
1856 \r
1857     return best_split;\r
1858 }\r
1859 \r
1860 \r
1861 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )\r
1862 {\r
1863     const float epsilon = FLT_EPSILON*2;\r
1864     int n = node->sample_count;\r
1865     int n1 = node->get_num_valid(vi);\r
1866     int m = data->get_num_classes();\r
1867 \r
1868     float* values_buf = data->pred_float_buf;\r
1869     const float* values = 0;\r
1870     int* indices_buf = data->pred_int_buf;\r
1871     const int* indices = 0;\r
1872     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
1873     int* responses_buf =  data->resp_int_buf;\r
1874     const int* responses = 0;\r
1875     data->get_class_labels( node, responses_buf, &responses );\r
1876 \r
1877     const int* rc0 = data->counts->data.i;\r
1878     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
1879     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
1880     int i, best_i = -1;\r
1881     double lsum2 = 0, rsum2 = 0, best_val = 0;\r
1882     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
1883 \r
1884     // init arrays of class instance counters on both sides of the split\r
1885     for( i = 0; i < m; i++ )\r
1886     {\r
1887         lc[i] = 0;\r
1888         rc[i] = rc0[i];\r
1889     }\r
1890 \r
1891     // compensate for missing values\r
1892     for( i = n1; i < n; i++ )\r
1893     {\r
1894         rc[responses[indices[i]]]--;\r
1895     }\r
1896 \r
1897     if( !priors )\r
1898     {\r
1899         int L = 0, R = n1;\r
1900 \r
1901         for( i = 0; i < m; i++ )\r
1902             rsum2 += (double)rc[i]*rc[i];\r
1903 \r
1904         for( i = 0; i < n1 - 1; i++ )\r
1905         {\r
1906             int idx = responses[indices[i]];\r
1907             int lv, rv;\r
1908             L++; R--;\r
1909             lv = lc[idx]; rv = rc[idx];\r
1910             lsum2 += lv*2 + 1;\r
1911             rsum2 -= rv*2 - 1;\r
1912             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1913 \r
1914             if( values[i] + epsilon < values[i+1] )\r
1915             {\r
1916                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1917                 if( best_val < val )\r
1918                 {\r
1919                     best_val = val;\r
1920                     best_i = i;\r
1921                 }\r
1922             }\r
1923         }\r
1924     }\r
1925     else\r
1926     {\r
1927         double L = 0, R = 0;\r
1928         for( i = 0; i < m; i++ )\r
1929         {\r
1930             double wv = rc[i]*priors[i];\r
1931             R += wv;\r
1932             rsum2 += wv*wv;\r
1933         }\r
1934 \r
1935         for( i = 0; i < n1 - 1; i++ )\r
1936         {\r
1937             int idx = responses[indices[i]];\r
1938             int lv, rv;\r
1939             double p = priors[idx], p2 = p*p;\r
1940             L += p; R -= p;\r
1941             lv = lc[idx]; rv = rc[idx];\r
1942             lsum2 += p2*(lv*2 + 1);\r
1943             rsum2 -= p2*(rv*2 - 1);\r
1944             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1945 \r
1946             if( values[i] + epsilon < values[i+1] )\r
1947             {\r
1948                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1949                 if( best_val < val )\r
1950                 {\r
1951                     best_val = val;\r
1952                     best_i = i;\r
1953                 }\r
1954             }\r
1955         }\r
1956     }\r
1957 \r
1958     return best_i >= 0 ? data->new_split_ord( vi,\r
1959         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
1960         0, (float)best_val ) : 0;\r
1961 }\r
1962 \r
1963 \r
1964 void CvDTree::cluster_categories( const int* vectors, int n, int m,\r
1965                                 int* csums, int k, int* labels )\r
1966 {\r
1967     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm\r
1968     int iters = 0, max_iters = 100;\r
1969     int i, j, idx;\r
1970     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );\r
1971     double *v_weights = buf, *c_weights = buf + k;\r
1972     bool modified = true;\r
1973     CvRNG* r = &data->rng;\r
1974 \r
1975     // assign labels randomly\r
1976     for( i = idx = 0; i < n; i++ )\r
1977     {\r
1978         int sum = 0;\r
1979         const int* v = vectors + i*m;\r
1980         labels[i] = idx++;\r
1981         idx &= idx < k ? -1 : 0;\r
1982 \r
1983         // compute weight of each vector\r
1984         for( j = 0; j < m; j++ )\r
1985             sum += v[j];\r
1986         v_weights[i] = sum ? 1./sum : 0.;\r
1987     }\r
1988 \r
1989     for( i = 0; i < n; i++ )\r
1990     {\r
1991         int i1 = cvRandInt(r) % n;\r
1992         int i2 = cvRandInt(r) % n;\r
1993         CV_SWAP( labels[i1], labels[i2], j );\r
1994     }\r
1995 \r
1996     for( iters = 0; iters <= max_iters; iters++ )\r
1997     {\r
1998         // calculate csums\r
1999         for( i = 0; i < k; i++ )\r
2000         {\r
2001             for( j = 0; j < m; j++ )\r
2002                 csums[i*m + j] = 0;\r
2003         }\r
2004 \r
2005         for( i = 0; i < n; i++ )\r
2006         {\r
2007             const int* v = vectors + i*m;\r
2008             int* s = csums + labels[i]*m;\r
2009             for( j = 0; j < m; j++ )\r
2010                 s[j] += v[j];\r
2011         }\r
2012 \r
2013         // exit the loop here, when we have up-to-date csums\r
2014         if( iters == max_iters || !modified )\r
2015             break;\r
2016 \r
2017         modified = false;\r
2018 \r
2019         // calculate weight of each cluster\r
2020         for( i = 0; i < k; i++ )\r
2021         {\r
2022             const int* s = csums + i*m;\r
2023             int sum = 0;\r
2024             for( j = 0; j < m; j++ )\r
2025                 sum += s[j];\r
2026             c_weights[i] = sum ? 1./sum : 0;\r
2027         }\r
2028 \r
2029         // now for each vector determine the closest cluster\r
2030         for( i = 0; i < n; i++ )\r
2031         {\r
2032             const int* v = vectors + i*m;\r
2033             double alpha = v_weights[i];\r
2034             double min_dist2 = DBL_MAX;\r
2035             int min_idx = -1;\r
2036 \r
2037             for( idx = 0; idx < k; idx++ )\r
2038             {\r
2039                 const int* s = csums + idx*m;\r
2040                 double dist2 = 0., beta = c_weights[idx];\r
2041                 for( j = 0; j < m; j++ )\r
2042                 {\r
2043                     double t = v[j]*alpha - s[j]*beta;\r
2044                     dist2 += t*t;\r
2045                 }\r
2046                 if( min_dist2 > dist2 )\r
2047                 {\r
2048                     min_dist2 = dist2;\r
2049                     min_idx = idx;\r
2050                 }\r
2051             }\r
2052 \r
2053             if( min_idx != labels[i] )\r
2054                 modified = true;\r
2055             labels[i] = min_idx;\r
2056         }\r
2057     }\r
2058 }\r
2059 \r
2060 \r
2061 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )\r
2062 {\r
2063     CvDTreeSplit* split = 0;\r
2064     int ci = data->get_var_type(vi);\r
2065     int n = node->sample_count;\r
2066     int m = data->get_num_classes();\r
2067     int _mi = data->cat_count->data.i[ci], mi = _mi;\r
2068 \r
2069     int* labels_buf = data->pred_int_buf;\r
2070     const int* labels = 0;\r
2071     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2072     int *responses_buf = data->resp_int_buf;\r
2073     const int* responses = 0;\r
2074     data->get_class_labels(node, responses_buf, &responses);\r
2075 \r
2076     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
2077     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
2078     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;\r
2079     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );\r
2080     int* cluster_labels = 0;\r
2081     int** int_ptr = 0;\r
2082     int i, j, k, idx;\r
2083     double L = 0, R = 0;\r
2084     double best_val = 0;\r
2085     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;\r
2086     const double* priors = data->priors_mult->data.db;\r
2087 \r
2088     // init array of counters:\r
2089     // c_{jk} - number of samples that have vi-th input variable = j and response = k.\r
2090     for( j = -1; j < mi; j++ )\r
2091         for( k = 0; k < m; k++ )\r
2092             cjk[j*m + k] = 0;\r
2093 \r
2094     for( i = 0; i < n; i++ )\r
2095     {\r
2096        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];\r
2097        k = responses[i];\r
2098        cjk[j*m + k]++;\r
2099     }\r
2100 \r
2101     if( m > 2 )\r
2102     {\r
2103         if( mi > data->params.max_categories )\r
2104         {\r
2105             mi = MIN(data->params.max_categories, n);\r
2106             cjk += _mi*m;\r
2107             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));\r
2108             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );\r
2109         }\r
2110         subset_i = 1;\r
2111         subset_n = 1 << mi;\r
2112     }\r
2113     else\r
2114     {\r
2115         assert( m == 2 );\r
2116         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );\r
2117         for( j = 0; j < mi; j++ )\r
2118             int_ptr[j] = cjk + j*2 + 1;\r
2119         icvSortIntPtr( int_ptr, mi, 0 );\r
2120         subset_i = 0;\r
2121         subset_n = mi;\r
2122     }\r
2123 \r
2124     for( k = 0; k < m; k++ )\r
2125     {\r
2126         int sum = 0;\r
2127         for( j = 0; j < mi; j++ )\r
2128             sum += cjk[j*m + k];\r
2129         rc[k] = sum;\r
2130         lc[k] = 0;\r
2131     }\r
2132 \r
2133     for( j = 0; j < mi; j++ )\r
2134     {\r
2135         double sum = 0;\r
2136         for( k = 0; k < m; k++ )\r
2137             sum += cjk[j*m + k]*priors[k];\r
2138         c_weights[j] = sum;\r
2139         R += c_weights[j];\r
2140     }\r
2141 \r
2142     for( ; subset_i < subset_n; subset_i++ )\r
2143     {\r
2144         double weight;\r
2145         int* crow;\r
2146         double lsum2 = 0, rsum2 = 0;\r
2147 \r
2148         if( m == 2 )\r
2149             idx = (int)(int_ptr[subset_i] - cjk)/2;\r
2150         else\r
2151         {\r
2152             int graycode = (subset_i>>1)^subset_i;\r
2153             int diff = graycode ^ prevcode;\r
2154 \r
2155             // determine index of the changed bit.\r
2156             Cv32suf u;\r
2157             idx = diff >= (1 << 16) ? 16 : 0;\r
2158             u.f = (float)(((diff >> 16) | diff) & 65535);\r
2159             idx += (u.i >> 23) - 127;\r
2160             subtract = graycode < prevcode;\r
2161             prevcode = graycode;\r
2162         }\r
2163 \r
2164         crow = cjk + idx*m;\r
2165         weight = c_weights[idx];\r
2166         if( weight < FLT_EPSILON )\r
2167             continue;\r
2168 \r
2169         if( !subtract )\r
2170         {\r
2171             for( k = 0; k < m; k++ )\r
2172             {\r
2173                 int t = crow[k];\r
2174                 int lval = lc[k] + t;\r
2175                 int rval = rc[k] - t;\r
2176                 double p = priors[k], p2 = p*p;\r
2177                 lsum2 += p2*lval*lval;\r
2178                 rsum2 += p2*rval*rval;\r
2179                 lc[k] = lval; rc[k] = rval;\r
2180             }\r
2181             L += weight;\r
2182             R -= weight;\r
2183         }\r
2184         else\r
2185         {\r
2186             for( k = 0; k < m; k++ )\r
2187             {\r
2188                 int t = crow[k];\r
2189                 int lval = lc[k] - t;\r
2190                 int rval = rc[k] + t;\r
2191                 double p = priors[k], p2 = p*p;\r
2192                 lsum2 += p2*lval*lval;\r
2193                 rsum2 += p2*rval*rval;\r
2194                 lc[k] = lval; rc[k] = rval;\r
2195             }\r
2196             L -= weight;\r
2197             R += weight;\r
2198         }\r
2199 \r
2200         if( L > FLT_EPSILON && R > FLT_EPSILON )\r
2201         {\r
2202             double val = (lsum2*R + rsum2*L)/((double)L*R);\r
2203             if( best_val < val )\r
2204             {\r
2205                 best_val = val;\r
2206                 best_subset = subset_i;\r
2207             }\r
2208         }\r
2209     }\r
2210 \r
2211     if( best_subset < 0 )\r
2212         return 0;\r
2213 \r
2214     split = data->new_split_cat( vi, (float)best_val );\r
2215 \r
2216     if( m == 2 )\r
2217     {\r
2218         for( i = 0; i <= best_subset; i++ )\r
2219         {\r
2220             idx = (int)(int_ptr[i] - cjk) >> 1;\r
2221             split->subset[idx >> 5] |= 1 << (idx & 31);\r
2222         }\r
2223     }\r
2224     else\r
2225     {\r
2226         for( i = 0; i < _mi; i++ )\r
2227         {\r
2228             idx = cluster_labels ? cluster_labels[i] : i;\r
2229             if( best_subset & (1 << idx) )\r
2230                 split->subset[i >> 5] |= 1 << (i & 31);\r
2231         }\r
2232     }\r
2233 \r
2234     return split;\r
2235 }\r
2236 \r
2237 \r
2238 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )\r
2239 {\r
2240     const float epsilon = FLT_EPSILON*2;\r
2241     int n = node->sample_count;\r
2242     int n1 = node->get_num_valid(vi);\r
2243 \r
2244     float* values_buf = data->pred_float_buf;\r
2245     const float* values = 0;\r
2246     int* indices_buf = data->pred_int_buf;\r
2247     const int* indices = 0;\r
2248     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2249     float* responses_buf =  data->resp_float_buf;\r
2250     const float* responses = 0;\r
2251     data->get_ord_responses( node, responses_buf, &responses );\r
2252 \r
2253     int i, best_i = -1;\r
2254     double best_val = 0, lsum = 0, rsum = node->value*n;\r
2255     int L = 0, R = n1;\r
2256 \r
2257     // compensate for missing values\r
2258     for( i = n1; i < n; i++ )\r
2259         rsum -= responses[indices[i]];\r
2260 \r
2261     // find the optimal split\r
2262     for( i = 0; i < n1 - 1; i++ )\r
2263     {\r
2264         float t = responses[indices[i]];\r
2265         L++; R--;\r
2266         lsum += t;\r
2267         rsum -= t;\r
2268 \r
2269         if( values[i] + epsilon < values[i+1] )\r
2270         {\r
2271             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2272             if( best_val < val )\r
2273             {\r
2274                 best_val = val;\r
2275                 best_i = i;\r
2276             }\r
2277         }\r
2278     }\r
2279 \r
2280     return best_i >= 0 ? data->new_split_ord( vi,\r
2281         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2282         0, (float)best_val ) : 0;\r
2283 }\r
2284 \r
2285 \r
2286 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )\r
2287 {\r
2288     CvDTreeSplit* split;\r
2289     int ci = data->get_var_type(vi);\r
2290     int n = node->sample_count;\r
2291     int mi = data->cat_count->data.i[ci];\r
2292     int* labels_buf = data->pred_int_buf;\r
2293     const int* labels = 0;\r
2294     float* responses_buf = data->resp_float_buf;\r
2295     const float* responses = 0;\r
2296     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2297     data->get_ord_responses(node, responses_buf, &responses);\r
2298 \r
2299     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;\r
2300     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;\r
2301     double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );\r
2302     int i, L = 0, R = 0;\r
2303     double best_val = 0, lsum = 0, rsum = 0;\r
2304     int best_subset = -1, subset_i;\r
2305 \r
2306     for( i = -1; i < mi; i++ )\r
2307         sum[i] = counts[i] = 0;\r
2308 \r
2309     // calculate sum response and weight of each category of the input var\r
2310     for( i = 0; i < n; i++ )\r
2311     {\r
2312         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];\r
2313         double s = sum[idx] + responses[i];\r
2314         int nc = counts[idx] + 1;\r
2315         sum[idx] = s;\r
2316         counts[idx] = nc;\r
2317     }\r
2318 \r
2319     // calculate average response in each category\r
2320     for( i = 0; i < mi; i++ )\r
2321     {\r
2322         R += counts[i];\r
2323         rsum += sum[i];\r
2324         sum[i] /= MAX(counts[i],1);\r
2325         sum_ptr[i] = sum + i;\r
2326     }\r
2327 \r
2328     icvSortDblPtr( sum_ptr, mi, 0 );\r
2329 \r
2330     // revert back to unnormalized sums\r
2331     // (there should be a very little loss of accuracy)\r
2332     for( i = 0; i < mi; i++ )\r
2333         sum[i] *= counts[i];\r
2334 \r
2335     for( subset_i = 0; subset_i < mi-1; subset_i++ )\r
2336     {\r
2337         int idx = (int)(sum_ptr[subset_i] - sum);\r
2338         int ni = counts[idx];\r
2339 \r
2340         if( ni )\r
2341         {\r
2342             double s = sum[idx];\r
2343             lsum += s; L += ni;\r
2344             rsum -= s; R -= ni;\r
2345 \r
2346             if( L && R )\r
2347             {\r
2348                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2349                 if( best_val < val )\r
2350                 {\r
2351                     best_val = val;\r
2352                     best_subset = subset_i;\r
2353                 }\r
2354             }\r
2355         }\r
2356     }\r
2357 \r
2358     if( best_subset < 0 )\r
2359         return 0;\r
2360 \r
2361     split = data->new_split_cat( vi, (float)best_val );\r
2362     for( i = 0; i <= best_subset; i++ )\r
2363     {\r
2364         int idx = (int)(sum_ptr[i] - sum);\r
2365         split->subset[idx >> 5] |= 1 << (idx & 31);\r
2366     }\r
2367 \r
2368     return split;\r
2369 }\r
2370 \r
2371 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )\r
2372 {\r
2373     const float epsilon = FLT_EPSILON*2;\r
2374     const char* dir = (char*)data->direction->data.ptr;\r
2375     int n1 = node->get_num_valid(vi);\r
2376     float* values_buf = data->pred_float_buf;\r
2377     const float* values = 0;\r
2378     int* indices_buf = data->pred_int_buf;\r
2379     const int* indices = 0;\r
2380     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2381     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2382     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2383     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2384     // RR - ... both send to the right\r
2385     int i, best_i = -1, best_inversed = 0;\r
2386     double best_val;\r
2387 \r
2388     if( !data->have_priors )\r
2389     {\r
2390         int LL = 0, RL = 0, LR, RR;\r
2391         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;\r
2392         int sum = 0, sum_abs = 0;\r
2393 \r
2394         for( i = 0; i < n1; i++ )\r
2395         {\r
2396             int d = dir[indices[i]];\r
2397             sum += d; sum_abs += d & 1;\r
2398         }\r
2399 \r
2400         // sum_abs = R + L; sum = R - L\r
2401         RR = (sum_abs + sum) >> 1;\r
2402         LR = (sum_abs - sum) >> 1;\r
2403 \r
2404         // initially all the samples are sent to the right by the surrogate split,\r
2405         // LR of them are sent to the left by primary split, and RR - to the right.\r
2406         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2407         for( i = 0; i < n1 - 1; i++ )\r
2408         {\r
2409             int d = dir[indices[i]];\r
2410 \r
2411             if( d < 0 )\r
2412             {\r
2413                 LL++; LR--;\r
2414                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )\r
2415                 {\r
2416                     best_val = LL + RR;\r
2417                     best_i = i; best_inversed = 0;\r
2418                 }\r
2419             }\r
2420             else if( d > 0 )\r
2421             {\r
2422                 RL++; RR--;\r
2423                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )\r
2424                 {\r
2425                     best_val = RL + LR;\r
2426                     best_i = i; best_inversed = 1;\r
2427                 }\r
2428             }\r
2429         }\r
2430         best_val = _best_val;\r
2431     }\r
2432     else\r
2433     {\r
2434         double LL = 0, RL = 0, LR, RR;\r
2435         double worst_val = node->maxlr;\r
2436         double sum = 0, sum_abs = 0;\r
2437         const double* priors = data->priors_mult->data.db;\r
2438         int* responses_buf = data->resp_int_buf;\r
2439         const int* responses = 0;\r
2440         data->get_class_labels(node, responses_buf, &responses);\r
2441         best_val = worst_val;\r
2442 \r
2443         for( i = 0; i < n1; i++ )\r
2444         {\r
2445             int idx = indices[i];\r
2446             double w = priors[responses[idx]];\r
2447             int d = dir[idx];\r
2448             sum += d*w; sum_abs += (d & 1)*w;\r
2449         }\r
2450 \r
2451         // sum_abs = R + L; sum = R - L\r
2452         RR = (sum_abs + sum)*0.5;\r
2453         LR = (sum_abs - sum)*0.5;\r
2454 \r
2455         // initially all the samples are sent to the right by the surrogate split,\r
2456         // LR of them are sent to the left by primary split, and RR - to the right.\r
2457         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2458         for( i = 0; i < n1 - 1; i++ )\r
2459         {\r
2460             int idx = indices[i];\r
2461             double w = priors[responses[idx]];\r
2462             int d = dir[idx];\r
2463 \r
2464             if( d < 0 )\r
2465             {\r
2466                 LL += w; LR -= w;\r
2467                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )\r
2468                 {\r
2469                     best_val = LL + RR;\r
2470                     best_i = i; best_inversed = 0;\r
2471                 }\r
2472             }\r
2473             else if( d > 0 )\r
2474             {\r
2475                 RL += w; RR -= w;\r
2476                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )\r
2477                 {\r
2478                     best_val = RL + LR;\r
2479                     best_i = i; best_inversed = 1;\r
2480                 }\r
2481             }\r
2482         }\r
2483     }\r
2484 \r
2485     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,\r
2486         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2487         best_inversed, (float)best_val ) : 0;\r
2488 }\r
2489 \r
2490 \r
2491 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )\r
2492 {\r
2493     const char* dir = (char*)data->direction->data.ptr;\r
2494     int n = node->sample_count;\r
2495     int* labels_buf = data->pred_int_buf;\r
2496     const int* labels = 0;\r
2497     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2498     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2499     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2500     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2501     // RR - ... both send to the right\r
2502     CvDTreeSplit* split = data->new_split_cat( vi, 0 );\r
2503     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;\r
2504     double best_val = 0;\r
2505     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;\r
2506     double* rc = lc + mi + 1;\r
2507 \r
2508     for( i = -1; i < mi; i++ )\r
2509         lc[i] = rc[i] = 0;\r
2510 \r
2511     // for each category calculate the weight of samples\r
2512     // sent to the left (lc) and to the right (rc) by the primary split\r
2513     if( !data->have_priors )\r
2514     {\r
2515         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;\r
2516         int* _rc = _lc + mi + 1;\r
2517 \r
2518         for( i = -1; i < mi; i++ )\r
2519             _lc[i] = _rc[i] = 0;\r
2520 \r
2521         for( i = 0; i < n; i++ )\r
2522         {\r
2523             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2524             int d = dir[i];\r
2525             int sum = _lc[idx] + d;\r
2526             int sum_abs = _rc[idx] + (d & 1);\r
2527             _lc[idx] = sum; _rc[idx] = sum_abs;\r
2528         }\r
2529 \r
2530         for( i = 0; i < mi; i++ )\r
2531         {\r
2532             int sum = _lc[i];\r
2533             int sum_abs = _rc[i];\r
2534             lc[i] = (sum_abs - sum) >> 1;\r
2535             rc[i] = (sum_abs + sum) >> 1;\r
2536         }\r
2537     }\r
2538     else\r
2539     {\r
2540         const double* priors = data->priors_mult->data.db;\r
2541         int* responses_buf = data->resp_int_buf;\r
2542         const int* responses = 0;\r
2543         data->get_class_labels(node, responses_buf, &responses);\r
2544 \r
2545         for( i = 0; i < n; i++ )\r
2546         {\r
2547             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2548             double w = priors[responses[i]];\r
2549             int d = dir[i];\r
2550             double sum = lc[idx] + d*w;\r
2551             double sum_abs = rc[idx] + (d & 1)*w;\r
2552             lc[idx] = sum; rc[idx] = sum_abs;\r
2553         }\r
2554 \r
2555         for( i = 0; i < mi; i++ )\r
2556         {\r
2557             double sum = lc[i];\r
2558             double sum_abs = rc[i];\r
2559             lc[i] = (sum_abs - sum) * 0.5;\r
2560             rc[i] = (sum_abs + sum) * 0.5;\r
2561         }\r
2562     }\r
2563 \r
2564     // 2. now form the split.\r
2565     // in each category send all the samples to the same direction as majority\r
2566     for( i = 0; i < mi; i++ )\r
2567     {\r
2568         double lval = lc[i], rval = rc[i];\r
2569         if( lval > rval )\r
2570         {\r
2571             split->subset[i >> 5] |= 1 << (i & 31);\r
2572             best_val += lval;\r
2573             l_win++;\r
2574         }\r
2575         else\r
2576             best_val += rval;\r
2577     }\r
2578 \r
2579     split->quality = (float)best_val;\r
2580     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )\r
2581         cvSetRemoveByPtr( data->split_heap, split ), split = 0;\r
2582 \r
2583     return split;\r
2584 }\r
2585 \r
2586 \r
2587 void CvDTree::calc_node_value( CvDTreeNode* node )\r
2588 {\r
2589     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;\r
2590     int* cv_labels_buf = data->cv_lables_buf;\r
2591     const int* cv_labels = 0;\r
2592     data->get_cv_labels(node, cv_labels_buf, &cv_labels);\r
2593 \r
2594     if( data->is_classifier )\r
2595     {\r
2596         // in case of classification tree:\r
2597         //  * node value is the label of the class that has the largest weight in the node.\r
2598         //  * node risk is the weighted number of misclassified samples,\r
2599         //  * j-th cross-validation fold value and risk are calculated as above,\r
2600         //    but using the samples with cv_labels(*)!=j.\r
2601         //  * j-th cross-validation fold error is calculated as the weighted number of\r
2602         //    misclassified samples with cv_labels(*)==j.\r
2603 \r
2604         // compute the number of instances of each class\r
2605         int* cls_count = data->counts->data.i;\r
2606         int* responses_buf = data->resp_int_buf;\r
2607         const int* responses = 0;\r
2608         data->get_class_labels(node, responses_buf, &responses);\r
2609         int m = data->get_num_classes();\r
2610         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));\r
2611         double max_val = -1, total_weight = 0;\r
2612         int max_k = -1;\r
2613         double* priors = data->priors_mult->data.db;\r
2614 \r
2615         for( k = 0; k < m; k++ )\r
2616             cls_count[k] = 0;\r
2617 \r
2618         if( cv_n == 0 )\r
2619         {\r
2620             for( i = 0; i < n; i++ )\r
2621                 cls_count[responses[i]]++;\r
2622         }\r
2623         else\r
2624         {\r
2625             for( j = 0; j < cv_n; j++ )\r
2626                 for( k = 0; k < m; k++ )\r
2627                     cv_cls_count[j*m + k] = 0;\r
2628 \r
2629             for( i = 0; i < n; i++ )\r
2630             {\r
2631                 j = cv_labels[i]; k = responses[i];\r
2632                 cv_cls_count[j*m + k]++;\r
2633             }\r
2634 \r
2635             for( j = 0; j < cv_n; j++ )\r
2636                 for( k = 0; k < m; k++ )\r
2637                     cls_count[k] += cv_cls_count[j*m + k];\r
2638         }\r
2639 \r
2640         if( data->have_priors && node->parent == 0 )\r
2641         {\r
2642             // compute priors_mult from priors, take the sample ratio into account.\r
2643             double sum = 0;\r
2644             for( k = 0; k < m; k++ )\r
2645             {\r
2646                 int n_k = cls_count[k];\r
2647                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);\r
2648                 sum += priors[k];\r
2649             }\r
2650             sum = 1./sum;\r
2651             for( k = 0; k < m; k++ )\r
2652                 priors[k] *= sum;\r
2653         }\r
2654 \r
2655         for( k = 0; k < m; k++ )\r
2656         {\r
2657             double val = cls_count[k]*priors[k];\r
2658             total_weight += val;\r
2659             if( max_val < val )\r
2660             {\r
2661                 max_val = val;\r
2662                 max_k = k;\r
2663             }\r
2664         }\r
2665 \r
2666         node->class_idx = max_k;\r
2667         node->value = data->cat_map->data.i[\r
2668             data->cat_ofs->data.i[data->cat_var_count] + max_k];\r
2669         node->node_risk = total_weight - max_val;\r
2670 \r
2671         for( j = 0; j < cv_n; j++ )\r
2672         {\r
2673             double sum_k = 0, sum = 0, max_val_k = 0;\r
2674             max_val = -1; max_k = -1;\r
2675 \r
2676             for( k = 0; k < m; k++ )\r
2677             {\r
2678                 double w = priors[k];\r
2679                 double val_k = cv_cls_count[j*m + k]*w;\r
2680                 double val = cls_count[k]*w - val_k;\r
2681                 sum_k += val_k;\r
2682                 sum += val;\r
2683                 if( max_val < val )\r
2684                 {\r
2685                     max_val = val;\r
2686                     max_val_k = val_k;\r
2687                     max_k = k;\r
2688                 }\r
2689             }\r
2690 \r
2691             node->cv_Tn[j] = INT_MAX;\r
2692             node->cv_node_risk[j] = sum - max_val;\r
2693             node->cv_node_error[j] = sum_k - max_val_k;\r
2694         }\r
2695     }\r
2696     else\r
2697     {\r
2698         // in case of regression tree:\r
2699         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,\r
2700         //    n is the number of samples in the node.\r
2701         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)\r
2702         //  * j-th cross-validation fold value and risk are calculated as above,\r
2703         //    but using the samples with cv_labels(*)!=j.\r
2704         //  * j-th cross-validation fold error is calculated\r
2705         //    using samples with cv_labels(*)==j as the test subset:\r
2706         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),\r
2707         //    where node_value_j is the node value calculated\r
2708         //    as described in the previous bullet, and summation is done\r
2709         //    over the samples with cv_labels(*)==j.\r
2710 \r
2711         double sum = 0, sum2 = 0;\r
2712         float* values_buf = data->resp_float_buf;\r
2713         const float* values = 0;\r
2714         data->get_ord_responses(node, values_buf, &values);\r
2715         double *cv_sum = 0, *cv_sum2 = 0;\r
2716         int* cv_count = 0;\r
2717 \r
2718         if( cv_n == 0 )\r
2719         {\r
2720             for( i = 0; i < n; i++ )\r
2721             {\r
2722                 double t = values[i];\r
2723                 sum += t;\r
2724                 sum2 += t*t;\r
2725             }\r
2726         }\r
2727         else\r
2728         {\r
2729             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );\r
2730             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );\r
2731             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );\r
2732 \r
2733             for( j = 0; j < cv_n; j++ )\r
2734             {\r
2735                 cv_sum[j] = cv_sum2[j] = 0.;\r
2736                 cv_count[j] = 0;\r
2737             }\r
2738 \r
2739             for( i = 0; i < n; i++ )\r
2740             {\r
2741                 j = cv_labels[i];\r
2742                 double t = values[i];\r
2743                 double s = cv_sum[j] + t;\r
2744                 double s2 = cv_sum2[j] + t*t;\r
2745                 int nc = cv_count[j] + 1;\r
2746                 cv_sum[j] = s;\r
2747                 cv_sum2[j] = s2;\r
2748                 cv_count[j] = nc;\r
2749             }\r
2750 \r
2751             for( j = 0; j < cv_n; j++ )\r
2752             {\r
2753                 sum += cv_sum[j];\r
2754                 sum2 += cv_sum2[j];\r
2755             }\r
2756         }\r
2757 \r
2758         node->node_risk = sum2 - (sum/n)*sum;\r
2759         node->value = sum/n;\r
2760 \r
2761         for( j = 0; j < cv_n; j++ )\r
2762         {\r
2763             double s = cv_sum[j], si = sum - s;\r
2764             double s2 = cv_sum2[j], s2i = sum2 - s2;\r
2765             int c = cv_count[j], ci = n - c;\r
2766             double r = si/MAX(ci,1);\r
2767             node->cv_node_risk[j] = s2i - r*r*ci;\r
2768             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;\r
2769             node->cv_Tn[j] = INT_MAX;\r
2770         }\r
2771     }\r
2772 }\r
2773 \r
2774 \r
2775 void CvDTree::complete_node_dir( CvDTreeNode* node )\r
2776 {\r
2777     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;\r
2778     int nz = n - node->get_num_valid(node->split->var_idx);\r
2779     char* dir = (char*)data->direction->data.ptr;\r
2780 \r
2781     // try to complete direction using surrogate splits\r
2782     if( nz && data->params.use_surrogates )\r
2783     {\r
2784         CvDTreeSplit* split = node->split->next;\r
2785         for( ; split != 0 && nz; split = split->next )\r
2786         {\r
2787             int inversed_mask = split->inversed ? -1 : 0;\r
2788             vi = split->var_idx;\r
2789 \r
2790             if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
2791             {\r
2792                 int* labels_buf = data->pred_int_buf;\r
2793                 const int* labels = 0;\r
2794                 data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2795                 const int* subset = split->subset;\r
2796 \r
2797                 for( i = 0; i < n; i++ )\r
2798                 {\r
2799                     int idx = labels[i];\r
2800                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))\r
2801                         \r
2802                     {\r
2803                         int d = CV_DTREE_CAT_DIR(idx,subset);\r
2804                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
2805                         if( --nz )\r
2806                             break;\r
2807                     }\r
2808                 }\r
2809             }\r
2810             else // split on ordered var\r
2811             {\r
2812                 float* values_buf = data->pred_float_buf;\r
2813                 const float* values = 0;\r
2814                 int* indices_buf = data->pred_int_buf;\r
2815                 const int* indices = 0;\r
2816                 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2817                 int split_point = split->ord.split_point;\r
2818                 int n1 = node->get_num_valid(vi);\r
2819 \r
2820                 assert( 0 <= split_point && split_point < n-1 );\r
2821 \r
2822                 for( i = 0; i < n1; i++ )\r
2823                 {\r
2824                     int idx = indices[i];\r
2825                     if( !dir[idx] )\r
2826                     {\r
2827                         int d = i <= split_point ? -1 : 1;\r
2828                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);\r
2829                         if( --nz )\r
2830                             break;\r
2831                     }\r
2832                 }\r
2833             }\r
2834         }\r
2835     }\r
2836 \r
2837     // find the default direction for the rest\r
2838     if( nz )\r
2839     {\r
2840         for( i = nr = 0; i < n; i++ )\r
2841             nr += dir[i] > 0;\r
2842         nl = n - nr - nz;\r
2843         d0 = nl > nr ? -1 : nr > nl;\r
2844     }\r
2845 \r
2846     // make sure that every sample is directed either to the left or to the right\r
2847     for( i = 0; i < n; i++ )\r
2848     {\r
2849         int d = dir[i];\r
2850         if( !d )\r
2851         {\r
2852             d = d0;\r
2853             if( !d )\r
2854                 d = d1, d1 = -d1;\r
2855         }\r
2856         d = d > 0;\r
2857         dir[i] = (char)d; // remap (-1,1) to (0,1)\r
2858     }\r
2859 }\r
2860 \r
2861 \r
2862 void CvDTree::split_node_data( CvDTreeNode* node )\r
2863 {\r
2864     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;\r
2865     char* dir = (char*)data->direction->data.ptr;\r
2866     CvDTreeNode *left = 0, *right = 0;\r
2867     int* new_idx = data->split_buf->data.i;\r
2868     int new_buf_idx = data->get_child_buf_idx( node );\r
2869     int work_var_count = data->get_work_var_count();\r
2870     CvMat* buf = data->buf;\r
2871     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));\r
2872 \r
2873     complete_node_dir(node);\r
2874 \r
2875     for( i = nl = nr = 0; i < n; i++ )\r
2876     {\r
2877         int d = dir[i];\r
2878         // initialize new indices for splitting ordered variables\r
2879         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li\r
2880         nr += d;\r
2881         nl += d^1;\r
2882     }\r
2883 \r
2884 \r
2885     bool split_input_data;\r
2886     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );\r
2887     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );\r
2888 \r
2889     split_input_data = node->depth + 1 < data->params.max_depth &&\r
2890         (node->left->sample_count > data->params.min_sample_count ||\r
2891         node->right->sample_count > data->params.min_sample_count);\r
2892 \r
2893     // split ordered variables, keep both halves sorted.\r
2894     for( vi = 0; vi < data->var_count; vi++ )\r
2895     {\r
2896         int ci = data->get_var_type(vi);\r
2897         int n1 = node->get_num_valid(vi);\r
2898         int *src_idx_buf = data->pred_int_buf;\r
2899         const int* src_idx = 0;\r
2900         float *src_val_buf = data->pred_float_buf;\r
2901         const float* src_val = 0;\r
2902         \r
2903         if( ci >= 0 || !split_input_data )\r
2904             continue;\r
2905 \r
2906         data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
2907 \r
2908         for(i = 0; i < n; i++)\r
2909             temp_buf[i] = src_idx[i];\r
2910 \r
2911         if (data->is_buf_16u)\r
2912         {\r
2913             unsigned short *ldst, *rdst, *ldst0, *rdst0;\r
2914             //unsigned short tl, tr;\r
2915             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
2916                 vi*scount + left->offset);\r
2917             rdst0 = rdst = (unsigned short*)(ldst + nl);\r
2918 \r
2919             // split sorted\r
2920             for( i = 0; i < n1; i++ )\r
2921             {\r
2922                 int idx = temp_buf[i];\r
2923                 int d = dir[idx];\r
2924                 idx = new_idx[idx];\r
2925                 if (d)\r
2926                 {\r
2927                     *rdst = (unsigned short)idx;\r
2928                     rdst++;\r
2929                 }\r
2930                 else\r
2931                 {\r
2932                     *ldst = (unsigned short)idx;\r
2933                     ldst++;\r
2934                 }\r
2935             }\r
2936 \r
2937             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2938             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2939 \r
2940             // split missing\r
2941             for( ; i < n; i++ )\r
2942             {\r
2943                 int idx = temp_buf[i];\r
2944                 int d = dir[idx];\r
2945                 idx = new_idx[idx];\r
2946                 if (d)\r
2947                 {\r
2948                     *rdst = (unsigned short)idx;\r
2949                     rdst++;\r
2950                 }\r
2951                 else\r
2952                 {\r
2953                     *ldst = (unsigned short)idx;\r
2954                     ldst++;\r
2955                 }\r
2956             }\r
2957         }\r
2958         else\r
2959         {\r
2960             int *ldst0, *ldst, *rdst0, *rdst;\r
2961             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + \r
2962                 vi*scount + left->offset;\r
2963             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + \r
2964                 vi*scount + right->offset;\r
2965 \r
2966             // split sorted\r
2967             for( i = 0; i < n1; i++ )\r
2968             {\r
2969                 int idx = temp_buf[i];\r
2970                 int d = dir[idx];\r
2971                 idx = new_idx[idx];\r
2972                 if (d)\r
2973                 {\r
2974                     *rdst = idx;\r
2975                     rdst++;\r
2976                 }\r
2977                 else\r
2978                 {\r
2979                     *ldst = idx;\r
2980                     ldst++;\r
2981                 }\r
2982             }\r
2983 \r
2984             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2985             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2986 \r
2987             // split missing\r
2988             for( ; i < n; i++ )\r
2989             {\r
2990                 int idx = temp_buf[i];\r
2991                 int d = dir[idx];\r
2992                 idx = new_idx[idx];\r
2993                 if (d)\r
2994                 {\r
2995                     *rdst = idx;\r
2996                     rdst++;\r
2997                 }\r
2998                 else\r
2999                 {\r
3000                     *ldst = idx;\r
3001                     ldst++;\r
3002                 }\r
3003             }\r
3004         }\r
3005     }\r
3006 \r
3007     // split categorical vars, responses and cv_labels using new_idx relocation table\r
3008     for( vi = 0; vi < work_var_count; vi++ )\r
3009     {\r
3010         int ci = data->get_var_type(vi);\r
3011         int n1 = node->get_num_valid(vi), nr1 = 0;\r
3012         \r
3013         if( ci < 0 || (vi < data->var_count && !split_input_data) )\r
3014             continue;\r
3015 \r
3016         int *src_lbls_buf = data->pred_int_buf;\r
3017         const int* src_lbls = 0;\r
3018         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);\r
3019 \r
3020         for(i = 0; i < n; i++)\r
3021             temp_buf[i] = src_lbls[i];\r
3022 \r
3023         if (data->is_buf_16u)\r
3024         {\r
3025             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + \r
3026                 vi*scount + left->offset);\r
3027             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + \r
3028                 vi*scount + right->offset);\r
3029             \r
3030             for( i = 0; i < n; i++ )\r
3031             {\r
3032                 int d = dir[i];\r
3033                 int idx = temp_buf[i];\r
3034                 if (d)\r
3035                 {\r
3036                     *rdst = (unsigned short)idx;\r
3037                     rdst++;\r
3038                     nr1 += (idx != 65535 )&d;\r
3039                 }\r
3040                 else\r
3041                 {\r
3042                     *ldst = (unsigned short)idx;\r
3043                     ldst++;\r
3044                 }\r
3045             }\r
3046 \r
3047             if( vi < data->var_count )\r
3048             {\r
3049                 left->set_num_valid(vi, n1 - nr1);\r
3050                 right->set_num_valid(vi, nr1);\r
3051             }\r
3052         }\r
3053         else\r
3054         {\r
3055             int *ldst = buf->data.i + left->buf_idx*buf->cols + \r
3056                 vi*scount + left->offset;\r
3057             int *rdst = buf->data.i + right->buf_idx*buf->cols + \r
3058                 vi*scount + right->offset;\r
3059             \r
3060             for( i = 0; i < n; i++ )\r
3061             {\r
3062                 int d = dir[i];\r
3063                 int idx = temp_buf[i];\r
3064                 if (d)\r
3065                 {\r
3066                     *rdst = idx;\r
3067                     rdst++;\r
3068                     nr1 += (idx >= 0)&d;\r
3069                 }\r
3070                 else\r
3071                 {\r
3072                     *ldst = idx;\r
3073                     ldst++;\r
3074                 }\r
3075                 \r
3076             }\r
3077 \r
3078             if( vi < data->var_count )\r
3079             {\r
3080                 left->set_num_valid(vi, n1 - nr1);\r
3081                 right->set_num_valid(vi, nr1);\r
3082             }\r
3083         }        \r
3084     }\r
3085 \r
3086 \r
3087     // split sample indices\r
3088     int *sample_idx_src_buf = data->sample_idx_buf;\r
3089     const int* sample_idx_src = 0;\r
3090     data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);\r
3091 \r
3092     for(i = 0; i < n; i++)\r
3093         temp_buf[i] = sample_idx_src[i];\r
3094 \r
3095     int pos = data->get_work_var_count();\r
3096     if (data->is_buf_16u)\r
3097     {\r
3098         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
3099             pos*scount + left->offset);\r
3100         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + \r
3101             pos*scount + right->offset);\r
3102         for (i = 0; i < n; i++)\r
3103         {\r
3104             int d = dir[i];\r
3105             unsigned short idx = (unsigned short)temp_buf[i];\r
3106             if (d)\r
3107             {\r
3108                 *rdst = idx;\r
3109                 rdst++;\r
3110             }\r
3111             else\r
3112             {\r
3113                 *ldst = idx;\r
3114                 ldst++;\r
3115             }\r
3116         }\r
3117     }\r
3118     else\r
3119     {\r
3120         int* ldst = buf->data.i + left->buf_idx*buf->cols + \r
3121             pos*scount + left->offset;\r
3122         int* rdst = buf->data.i + right->buf_idx*buf->cols + \r
3123             pos*scount + right->offset;\r
3124         for (i = 0; i < n; i++)\r
3125         {\r
3126             int d = dir[i];\r
3127             int idx = temp_buf[i];\r
3128             if (d)\r
3129             {\r
3130                 *rdst = idx;\r
3131                 rdst++;\r
3132             }\r
3133             else\r
3134             {\r
3135                 *ldst = idx;\r
3136                 ldst++;\r
3137             }\r
3138         }\r
3139     }\r
3140     \r
3141     // deallocate the parent node data that is not needed anymore\r
3142     data->free_node_data(node);    \r
3143 }\r
3144 \r
3145 float CvDTree::calc_error( CvMLData* _data, int type )\r
3146 {\r
3147     float err = 0;\r
3148     const CvMat* values = _data->get_values();
3149     const CvMat* response = _data->get_response();
3150     const CvMat* missing = _data->get_missing();
3151     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
3152     const CvMat* var_types = _data->get_var_types();\r
3153     int* sidx = sample_idx ? sample_idx->data.i : 0;\r
3154     int r_step = CV_IS_MAT_CONT(response->type) ?\r
3155                 1 : response->step / CV_ELEM_SIZE(response->type);\r
3156     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;\r
3157     int sample_count = sample_idx ? sample_idx->cols : 0;\r
3158     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
3159     if ( is_classifier )\r
3160     {\r
3161         for( int i = 0; i < sample_count; i++ )\r
3162         {\r
3163             CvMat sample, miss;\r
3164             int si = sidx ? sidx[i] : i;\r
3165             cvGetRow( values, &sample, si ); \r
3166             if( missing ) \r
3167                 cvGetRow( missing, &miss, si );             \r
3168             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3169             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
3170             err += d;\r
3171         }\r
3172         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
3173     }\r
3174     else\r
3175     {\r
3176         for( int i = 0; i < sample_count; i++ )\r
3177         {\r
3178             CvMat sample, miss;\r
3179             int si = sidx ? sidx[i] : i;\r
3180             cvGetRow( values, &sample, si ); \r
3181             if( missing ) \r
3182                 cvGetRow( missing, &miss, si );             \r
3183             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3184             float d = r - response->data.fl[si*r_step];\r
3185             err += d*d;\r
3186         }\r
3187         err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
3188     }\r
3189     return err;\r
3190 }\r
3191 \r
3192 void CvDTree::prune_cv()\r
3193 {\r
3194     CvMat* ab = 0;\r
3195     CvMat* temp = 0;\r
3196     CvMat* err_jk = 0;\r
3197 \r
3198     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.\r
3199     // 2. choose the best tree index (if need, apply 1SE rule).\r
3200     // 3. store the best index and cut the branches.\r
3201 \r
3202     CV_FUNCNAME( "CvDTree::prune_cv" );\r
3203 \r
3204     __BEGIN__;\r
3205 \r
3206     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;\r
3207     // currently, 1SE for regression is not implemented\r
3208     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;\r
3209     double* err;\r
3210     double min_err = 0, min_err_se = 0;\r
3211     int min_idx = -1;\r
3212 \r
3213     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));\r
3214 \r
3215     // build the main tree sequence, calculate alpha's\r
3216     for(;;tree_count++)\r
3217     {\r
3218         double min_alpha = update_tree_rnc(tree_count, -1);\r
3219         if( cut_tree(tree_count, -1, min_alpha) )\r
3220             break;\r
3221 \r
3222         if( ab->cols <= tree_count )\r
3223         {\r
3224             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));\r
3225             for( ti = 0; ti < ab->cols; ti++ )\r
3226                 temp->data.db[ti] = ab->data.db[ti];\r
3227             cvReleaseMat( &ab );\r
3228             ab = temp;\r
3229             temp = 0;\r
3230         }\r
3231 \r
3232         ab->data.db[tree_count] = min_alpha;\r
3233     }\r
3234 \r
3235     ab->data.db[0] = 0.;\r
3236 \r
3237     if( tree_count > 0 )\r
3238     {\r
3239         for( ti = 1; ti < tree_count-1; ti++ )\r
3240             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);\r
3241         ab->data.db[tree_count-1] = DBL_MAX*0.5;\r
3242 \r
3243         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));\r
3244         err = err_jk->data.db;\r
3245 \r
3246         for( j = 0; j < cv_n; j++ )\r
3247         {\r
3248             int tj = 0, tk = 0;\r
3249             for( ; tk < tree_count; tj++ )\r
3250             {\r
3251                 double min_alpha = update_tree_rnc(tj, j);\r
3252                 if( cut_tree(tj, j, min_alpha) )\r
3253                     min_alpha = DBL_MAX;\r
3254 \r
3255                 for( ; tk < tree_count; tk++ )\r
3256                 {\r
3257                     if( ab->data.db[tk] > min_alpha )\r
3258                         break;\r
3259                     err[j*tree_count + tk] = root->tree_error;\r
3260                 }\r
3261             }\r
3262         }\r
3263 \r
3264         for( ti = 0; ti < tree_count; ti++ )\r
3265         {\r
3266             double sum_err = 0;\r
3267             for( j = 0; j < cv_n; j++ )\r
3268                 sum_err += err[j*tree_count + ti];\r
3269             if( ti == 0 || sum_err < min_err )\r
3270             {\r
3271                 min_err = sum_err;\r
3272                 min_idx = ti;\r
3273                 if( use_1se )\r
3274                     min_err_se = sqrt( sum_err*(n - sum_err) );\r
3275             }\r
3276             else if( sum_err < min_err + min_err_se )\r
3277                 min_idx = ti;\r
3278         }\r
3279     }\r
3280 \r
3281     pruned_tree_idx = min_idx;\r
3282     free_prune_data(data->params.truncate_pruned_tree != 0);\r
3283 \r
3284     __END__;\r
3285 \r
3286     cvReleaseMat( &err_jk );\r
3287     cvReleaseMat( &ab );\r
3288     cvReleaseMat( &temp );\r
3289 }\r
3290 \r
3291 \r
3292 double CvDTree::update_tree_rnc( int T, int fold )\r
3293 {\r
3294     CvDTreeNode* node = root;\r
3295     double min_alpha = DBL_MAX;\r
3296 \r
3297     for(;;)\r
3298     {\r
3299         CvDTreeNode* parent;\r
3300         for(;;)\r
3301         {\r
3302             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3303             if( t <= T || !node->left )\r
3304             {\r
3305                 node->complexity = 1;\r
3306                 node->tree_risk = node->node_risk;\r
3307                 node->tree_error = 0.;\r
3308                 if( fold >= 0 )\r
3309                 {\r
3310                     node->tree_risk = node->cv_node_risk[fold];\r
3311                     node->tree_error = node->cv_node_error[fold];\r
3312                 }\r
3313                 break;\r
3314             }\r
3315             node = node->left;\r
3316         }\r
3317 \r
3318         for( parent = node->parent; parent && parent->right == node;\r
3319             node = parent, parent = parent->parent )\r
3320         {\r
3321             parent->complexity += node->complexity;\r
3322             parent->tree_risk += node->tree_risk;\r
3323             parent->tree_error += node->tree_error;\r
3324 \r
3325             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)\r
3326                 - parent->tree_risk)/(parent->complexity - 1);\r
3327             min_alpha = MIN( min_alpha, parent->alpha );\r
3328         }\r
3329 \r
3330         if( !parent )\r
3331             break;\r
3332 \r
3333         parent->complexity = node->complexity;\r
3334         parent->tree_risk = node->tree_risk;\r
3335         parent->tree_error = node->tree_error;\r
3336         node = parent->right;\r
3337     }\r
3338 \r
3339     return min_alpha;\r
3340 }\r
3341 \r
3342 \r
3343 int CvDTree::cut_tree( int T, int fold, double min_alpha )\r
3344 {\r
3345     CvDTreeNode* node = root;\r
3346     if( !node->left )\r
3347         return 1;\r
3348 \r
3349     for(;;)\r
3350     {\r
3351         CvDTreeNode* parent;\r
3352         for(;;)\r
3353         {\r
3354             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3355             if( t <= T || !node->left )\r
3356                 break;\r
3357             if( node->alpha <= min_alpha + FLT_EPSILON )\r
3358             {\r
3359                 if( fold >= 0 )\r
3360                     node->cv_Tn[fold] = T;\r
3361                 else\r
3362                     node->Tn = T;\r
3363                 if( node == root )\r
3364                     return 1;\r
3365                 break;\r
3366             }\r
3367             node = node->left;\r
3368         }\r
3369 \r
3370         for( parent = node->parent; parent && parent->right == node;\r
3371             node = parent, parent = parent->parent )\r
3372             ;\r
3373 \r
3374         if( !parent )\r
3375             break;\r
3376 \r
3377         node = parent->right;\r
3378     }\r
3379 \r
3380     return 0;\r
3381 }\r
3382 \r
3383 \r
3384 void CvDTree::free_prune_data(bool cut_tree)\r
3385 {\r
3386     CvDTreeNode* node = root;\r
3387 \r
3388     for(;;)\r
3389     {\r
3390         CvDTreeNode* parent;\r
3391         for(;;)\r
3392         {\r
3393             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )\r
3394             // as we will clear the whole cross-validation heap at the end\r
3395             node->cv_Tn = 0;\r
3396             node->cv_node_error = node->cv_node_risk = 0;\r
3397             if( !node->left )\r
3398                 break;\r
3399             node = node->left;\r
3400         }\r
3401 \r
3402         for( parent = node->parent; parent && parent->right == node;\r
3403             node = parent, parent = parent->parent )\r
3404         {\r
3405             if( cut_tree && parent->Tn <= pruned_tree_idx )\r
3406             {\r
3407                 data->free_node( parent->left );\r
3408                 data->free_node( parent->right );\r
3409                 parent->left = parent->right = 0;\r
3410             }\r
3411         }\r
3412 \r
3413         if( !parent )\r
3414             break;\r
3415 \r
3416         node = parent->right;\r
3417     }\r
3418 \r
3419     if( data->cv_heap )\r
3420         cvClearSet( data->cv_heap );\r
3421 }\r
3422 \r
3423 \r
3424 void CvDTree::free_tree()\r
3425 {\r
3426     if( root && data && data->shared )\r
3427     {\r
3428         pruned_tree_idx = INT_MIN;\r
3429         free_prune_data(true);\r
3430         data->free_node(root);\r
3431         root = 0;\r
3432     }\r
3433 }\r
3434 \r
3435 CvDTreeNode* CvDTree::predict( const CvMat* _sample,\r
3436     const CvMat* _missing, bool preprocessed_input ) const\r
3437 {\r
3438     CvDTreeNode* result = 0;\r
3439     int* catbuf = 0;\r
3440 \r
3441     CV_FUNCNAME( "CvDTree::predict" );\r
3442 \r
3443     __BEGIN__;\r
3444 \r
3445     int i, step, mstep = 0;\r
3446     const float* sample;\r
3447     const uchar* m = 0;\r
3448     CvDTreeNode* node = root;\r
3449     const int* vtype;\r
3450     const int* vidx;\r
3451     const int* cmap;\r
3452     const int* cofs;\r
3453 \r
3454     if( !node )\r
3455         CV_ERROR( CV_StsError, "The tree has not been trained yet" );\r
3456 \r
3457     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||\r
3458         (_sample->cols != 1 && _sample->rows != 1) ||\r
3459         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||\r
3460         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )\r
3461             CV_ERROR( CV_StsBadArg,\r
3462         "the input sample must be 1d floating-point vector with the same "\r
3463         "number of elements as the total number of variables used for training" );\r
3464 \r
3465     sample = _sample->data.fl;\r
3466     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);\r
3467 \r
3468     if( data->cat_count && !preprocessed_input ) // cache for categorical variables\r
3469     {\r
3470         int n = data->cat_count->cols;\r
3471         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));\r
3472         for( i = 0; i < n; i++ )\r
3473             catbuf[i] = -1;\r
3474     }\r
3475 \r
3476     if( _missing )\r
3477     {\r
3478         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||\r
3479         !CV_ARE_SIZES_EQ(_missing, _sample) )\r
3480             CV_ERROR( CV_StsBadArg,\r
3481         "the missing data mask must be 8-bit vector of the same size as input sample" );\r
3482         m = _missing->data.ptr;\r
3483         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);\r
3484     }\r
3485 \r
3486     vtype = data->var_type->data.i;\r
3487     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;\r
3488     cmap = data->cat_map ? data->cat_map->data.i : 0;\r
3489     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;\r
3490 \r
3491     while( node->Tn > pruned_tree_idx && node->left )\r
3492     {\r
3493         CvDTreeSplit* split = node->split;\r
3494         int dir = 0;\r
3495         for( ; !dir && split != 0; split = split->next )\r
3496         {\r
3497             int vi = split->var_idx;\r
3498             int ci = vtype[vi];\r
3499             i = vidx ? vidx[vi] : vi;\r
3500             float val = sample[i*step];\r
3501             if( m && m[i*mstep] )\r
3502                 continue;\r
3503             if( ci < 0 ) // ordered\r
3504                 dir = val <= split->ord.c ? -1 : 1;\r
3505             else // categorical\r
3506             {\r
3507                 int c;\r
3508                 if( preprocessed_input )\r
3509                     c = cvRound(val);\r
3510                 else\r
3511                 {\r
3512                     c = catbuf[ci];\r
3513                     if( c < 0 )\r
3514                     {\r
3515                         int a = c = cofs[ci];\r
3516                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];\r
3517                         \r
3518                         int ival = cvRound(val);\r
3519                         if( ival != val )\r
3520                             CV_ERROR( CV_StsBadArg,\r
3521                             "one of input categorical variable is not an integer" );\r
3522                         \r
3523                         int sh = 0;\r
3524                         while( a < b )\r
3525                         {\r
3526                             sh++;\r
3527                             c = (a + b) >> 1;\r
3528                             if( ival < cmap[c] )\r
3529                                 b = c;\r
3530                             else if( ival > cmap[c] )\r
3531                                 a = c+1;\r
3532                             else\r
3533                                 break;\r
3534                         }\r
3535 \r
3536                         if( c < 0 || ival != cmap[c] )\r
3537                             continue;\r
3538 \r
3539                         catbuf[ci] = c -= cofs[ci];\r
3540                     }\r
3541                 }\r
3542                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;\r
3543                 dir = CV_DTREE_CAT_DIR(c, split->subset);\r
3544             }\r
3545 \r
3546             if( split->inversed )\r
3547                 dir = -dir;\r
3548         }\r
3549 \r
3550         if( !dir )\r
3551         {\r
3552             double diff = node->right->sample_count - node->left->sample_count;\r
3553             dir = diff < 0 ? -1 : 1;\r
3554         }\r
3555         node = dir < 0 ? node->left : node->right;\r
3556     }\r
3557 \r
3558     result = node;\r
3559 \r
3560     __END__;\r
3561 \r
3562     return result;\r
3563 }\r
3564 \r
3565 \r
3566 const CvMat* CvDTree::get_var_importance()\r
3567 {\r
3568     if( !var_importance )\r
3569     {\r
3570         CvDTreeNode* node = root;\r
3571         double* importance;\r
3572         if( !node )\r
3573             return 0;\r
3574         var_importance = cvCreateMat( 1, data->var_count, CV_64F );\r
3575         cvZero( var_importance );\r
3576         importance = var_importance->data.db;\r
3577 \r
3578         for(;;)\r
3579         {\r
3580             CvDTreeNode* parent;\r
3581             for( ;; node = node->left )\r
3582             {\r
3583                 CvDTreeSplit* split = node->split;\r
3584 \r
3585                 if( !node->left || node->Tn <= pruned_tree_idx )\r
3586                     break;\r
3587 \r
3588                 for( ; split != 0; split = split->next )\r
3589                     importance[split->var_idx] += split->quality;\r
3590             }\r
3591 \r
3592             for( parent = node->parent; parent && parent->right == node;\r
3593                 node = parent, parent = parent->parent )\r
3594                 ;\r
3595 \r
3596             if( !parent )\r
3597                 break;\r
3598 \r
3599             node = parent->right;\r
3600         }\r
3601 \r
3602         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );\r
3603     }\r
3604 \r
3605     return var_importance;\r
3606 }\r
3607 \r
3608 \r
3609 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )\r
3610 {\r
3611     int ci;\r
3612 \r
3613     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );\r
3614     cvWriteInt( fs, "var", split->var_idx );\r
3615     cvWriteReal( fs, "quality", split->quality );\r
3616 \r
3617     ci = data->get_var_type(split->var_idx);\r
3618     if( ci >= 0 ) // split on a categorical var\r
3619     {\r
3620         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;\r
3621         for( i = 0; i < n; i++ )\r
3622             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;\r
3623 \r
3624         // ad-hoc rule when to use inverse categorical split notation\r
3625         // to achieve more compact and clear representation\r
3626         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;\r
3627 \r
3628         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?\r
3629                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );\r
3630 \r
3631         for( i = 0; i < n; i++ )\r
3632         {\r
3633             int dir = CV_DTREE_CAT_DIR(i,split->subset);\r
3634             if( dir*default_dir < 0 )\r
3635                 cvWriteInt( fs, 0, i );\r
3636         }\r
3637         cvEndWriteStruct( fs );\r
3638     }\r
3639     else\r
3640         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );\r
3641 \r
3642     cvEndWriteStruct( fs );\r
3643 }\r
3644 \r
3645 \r
3646 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )\r
3647 {\r
3648     CvDTreeSplit* split;\r
3649 \r
3650     cvStartWriteStruct( fs, 0, CV_NODE_MAP );\r
3651 \r
3652     cvWriteInt( fs, "depth", node->depth );\r
3653     cvWriteInt( fs, "sample_count", node->sample_count );\r
3654     cvWriteReal( fs, "value", node->value );\r
3655 \r
3656     if( data->is_classifier )\r
3657         cvWriteInt( fs, "norm_class_idx", node->class_idx );\r
3658 \r
3659     cvWriteInt( fs, "Tn", node->Tn );\r
3660     cvWriteInt( fs, "complexity", node->complexity );\r
3661     cvWriteReal( fs, "alpha", node->alpha );\r
3662     cvWriteReal( fs, "node_risk", node->node_risk );\r
3663     cvWriteReal( fs, "tree_risk", node->tree_risk );\r
3664     cvWriteReal( fs, "tree_error", node->tree_error );\r
3665 \r
3666     if( node->left )\r
3667     {\r
3668         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );\r
3669 \r
3670         for( split = node->split; split != 0; split = split->next )\r
3671             write_split( fs, split );\r
3672 \r
3673         cvEndWriteStruct( fs );\r
3674     }\r
3675 \r
3676     cvEndWriteStruct( fs );\r
3677 }\r
3678 \r
3679 \r
3680 void CvDTree::write_tree_nodes( CvFileStorage* fs )\r
3681 {\r
3682     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );\r
3683 \r
3684     __BEGIN__;\r
3685 \r
3686     CvDTreeNode* node = root;\r
3687 \r
3688     // traverse the tree and save all the nodes in depth-first order\r
3689     for(;;)\r
3690     {\r
3691         CvDTreeNode* parent;\r
3692         for(;;)\r
3693         {\r
3694             write_node( fs, node );\r
3695             if( !node->left )\r
3696                 break;\r
3697             node = node->left;\r
3698         }\r
3699 \r
3700         for( parent = node->parent; parent && parent->right == node;\r
3701             node = parent, parent = parent->parent )\r
3702             ;\r
3703 \r
3704         if( !parent )\r
3705             break;\r
3706 \r
3707         node = parent->right;\r
3708     }\r
3709 \r
3710     __END__;\r
3711 }\r
3712 \r
3713 \r
3714 void CvDTree::write( CvFileStorage* fs, const char* name )\r
3715 {\r
3716     //CV_FUNCNAME( "CvDTree::write" );\r
3717 \r
3718     __BEGIN__;\r
3719 \r
3720     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );\r
3721 \r
3722     get_var_importance();\r
3723     data->write_params( fs );\r
3724     if( var_importance )\r
3725         cvWrite( fs, "var_importance", var_importance );\r
3726     write( fs );\r
3727 \r
3728     cvEndWriteStruct( fs );\r
3729 \r
3730     __END__;\r
3731 }\r
3732 \r
3733 \r
3734 void CvDTree::write( CvFileStorage* fs )\r
3735 {\r
3736     //CV_FUNCNAME( "CvDTree::write" );\r
3737 \r
3738     __BEGIN__;\r
3739 \r
3740     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );\r
3741 \r
3742     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );\r
3743     write_tree_nodes( fs );\r
3744     cvEndWriteStruct( fs );\r
3745 \r
3746     __END__;\r
3747 }\r
3748 \r
3749 \r
3750 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )\r
3751 {\r
3752     CvDTreeSplit* split = 0;\r
3753 \r
3754     CV_FUNCNAME( "CvDTree::read_split" );\r
3755 \r
3756     __BEGIN__;\r
3757 \r
3758     int vi, ci;\r
3759 \r
3760     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3761         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );\r
3762 \r
3763     vi = cvReadIntByName( fs, fnode, "var", -1 );\r
3764     if( (unsigned)vi >= (unsigned)data->var_count )\r
3765         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );\r
3766 \r
3767     ci = data->get_var_type(vi);\r
3768     if( ci >= 0 ) // split on categorical var\r
3769     {\r
3770         int i, n = data->cat_count->data.i[ci], inversed = 0, val;\r
3771         CvSeqReader reader;\r
3772         CvFileNode* inseq;\r
3773         split = data->new_split_cat( vi, 0 );\r
3774         inseq = cvGetFileNodeByName( fs, fnode, "in" );\r
3775         if( !inseq )\r
3776         {\r
3777             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );\r
3778             inversed = 1;\r
3779         }\r
3780         if( !inseq ||\r
3781             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))\r
3782             CV_ERROR( CV_StsParseError,\r
3783             "Either 'in' or 'not_in' tags should be inside a categorical split data" );\r
3784 \r
3785         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )\r
3786         {\r
3787             val = inseq->data.i;\r
3788             if( (unsigned)val >= (unsigned)n )\r
3789                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3790 \r
3791             split->subset[val >> 5] |= 1 << (val & 31);\r
3792         }\r
3793         else\r
3794         {\r
3795             cvStartReadSeq( inseq->data.seq, &reader );\r
3796 \r
3797             for( i = 0; i < reader.seq->total; i++ )\r
3798             {\r
3799                 CvFileNode* inode = (CvFileNode*)reader.ptr;\r
3800                 val = inode->data.i;\r
3801                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )\r
3802                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3803 \r
3804                 split->subset[val >> 5] |= 1 << (val & 31);\r
3805                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3806             }\r
3807         }\r
3808 \r
3809         // for categorical splits we do not use inversed splits,\r
3810         // instead we inverse the variable set in the split\r
3811         if( inversed )\r
3812             for( i = 0; i < (n + 31) >> 5; i++ )\r
3813                 split->subset[i] ^= -1;\r
3814     }\r
3815     else\r
3816     {\r
3817         CvFileNode* cmp_node;\r
3818         split = data->new_split_ord( vi, 0, 0, 0, 0 );\r
3819 \r
3820         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );\r
3821         if( !cmp_node )\r
3822         {\r
3823             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );\r
3824             split->inversed = 1;\r
3825         }\r
3826 \r
3827         split->ord.c = (float)cvReadReal( cmp_node );\r
3828     }\r
3829 \r
3830     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );\r
3831 \r
3832     __END__;\r
3833 \r
3834     return split;\r
3835 }\r
3836 \r
3837 \r
3838 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )\r
3839 {\r
3840     CvDTreeNode* node = 0;\r
3841 \r
3842     CV_FUNCNAME( "CvDTree::read_node" );\r
3843 \r
3844     __BEGIN__;\r
3845 \r
3846     CvFileNode* splits;\r
3847     int i, depth;\r
3848 \r
3849     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3850         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );\r
3851 \r
3852     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));\r
3853     depth = cvReadIntByName( fs, fnode, "depth", -1 );\r
3854     if( depth != node->depth )\r
3855         CV_ERROR( CV_StsParseError, "incorrect node depth" );\r
3856 \r
3857     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );\r
3858     node->value = cvReadRealByName( fs, fnode, "value" );\r
3859     if( data->is_classifier )\r
3860         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );\r
3861 \r
3862     node->Tn = cvReadIntByName( fs, fnode, "Tn" );\r
3863     node->complexity = cvReadIntByName( fs, fnode, "complexity" );\r
3864     node->alpha = cvReadRealByName( fs, fnode, "alpha" );\r
3865     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );\r
3866     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );\r
3867     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );\r
3868 \r
3869     splits = cvGetFileNodeByName( fs, fnode, "splits" );\r
3870     if( splits )\r
3871     {\r
3872         CvSeqReader reader;\r
3873         CvDTreeSplit* last_split = 0;\r
3874 \r
3875         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )\r
3876             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );\r
3877 \r
3878         cvStartReadSeq( splits->data.seq, &reader );\r
3879         for( i = 0; i < reader.seq->total; i++ )\r
3880         {\r
3881             CvDTreeSplit* split;\r
3882             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));\r
3883             if( !last_split )\r
3884                 node->split = last_split = split;\r
3885             else\r
3886                 last_split = last_split->next = split;\r
3887 \r
3888             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3889         }\r
3890     }\r
3891 \r
3892     __END__;\r
3893 \r
3894     return node;\r
3895 }\r
3896 \r
3897 \r
3898 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )\r
3899 {\r
3900     CV_FUNCNAME( "CvDTree::read_tree_nodes" );\r
3901 \r
3902     __BEGIN__;\r
3903 \r
3904     CvSeqReader reader;\r
3905     CvDTreeNode _root;\r
3906     CvDTreeNode* parent = &_root;\r
3907     int i;\r
3908     parent->left = parent->right = parent->parent = 0;\r
3909 \r
3910     cvStartReadSeq( fnode->data.seq, &reader );\r
3911 \r
3912     for( i = 0; i < reader.seq->total; i++ )\r
3913     {\r
3914         CvDTreeNode* node;\r
3915 \r
3916         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));\r
3917         if( !parent->left )\r
3918             parent->left = node;\r
3919         else\r
3920             parent->right = node;\r
3921         if( node->split )\r
3922             parent = node;\r
3923         else\r
3924         {\r
3925             while( parent && parent->right )\r
3926                 parent = parent->parent;\r
3927         }\r
3928 \r
3929         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3930     }\r
3931 \r
3932     root = _root.left;\r
3933 \r
3934     __END__;\r
3935 }\r
3936 \r
3937 \r
3938 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )\r
3939 {\r
3940     CvDTreeTrainData* _data = new CvDTreeTrainData();\r
3941     _data->read_params( fs, fnode );\r
3942 \r
3943     read( fs, fnode, _data );\r
3944     get_var_importance();\r
3945 }\r
3946 \r
3947 \r
3948 // a special entry point for reading weak decision trees from the tree ensembles\r
3949 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )\r
3950 {\r
3951     CV_FUNCNAME( "CvDTree::read" );\r
3952 \r
3953     __BEGIN__;\r
3954 \r
3955     CvFileNode* tree_nodes;\r
3956 \r
3957     clear();\r
3958     data = _data;\r
3959 \r
3960     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );\r
3961     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )\r
3962         CV_ERROR( CV_StsParseError, "nodes tag is missing" );\r
3963 \r
3964     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );\r
3965     read_tree_nodes( fs, tree_nodes );\r
3966 \r
3967     __END__;\r
3968 }\r
3969 \r
3970 /* End of file. */\r