]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
- fixed some bugs in decision tree engine
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 //\r
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 //\r
5 //  By downloading, copying, installing or using the software you agree to this license.\r
6 //  If you do not agree to this license, do not download, install,\r
7 //  copy or use the software.\r
8 //\r
9 //\r
10 //                        Intel License Agreement\r
11 //\r
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.\r
13 // Third party copyrights are property of their respective owners.\r
14 //\r
15 // Redistribution and use in source and binary forms, with or without modification,\r
16 // are permitted provided that the following conditions are met:\r
17 //\r
18 //   * Redistribution's of source code must retain the above copyright notice,\r
19 //     this list of conditions and the following disclaimer.\r
20 //\r
21 //   * Redistribution's in binary form must reproduce the above copyright notice,\r
22 //     this list of conditions and the following disclaimer in the documentation\r
23 //     and/or other materials provided with the distribution.\r
24 //\r
25 //   * The name of Intel Corporation may not be used to endorse or promote products\r
26 //     derived from this software without specific prior written permission.\r
27 //\r
28 // This software is provided by the copyright holders and contributors "as is" and\r
29 // any express or implied warranties, including, but not limited to, the implied\r
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.\r
31 // In no event shall the Intel Corporation or contributors be liable for any direct,\r
32 // indirect, incidental, special, exemplary, or consequential damages\r
33 // (including, but not limited to, procurement of substitute goods or services;\r
34 // loss of use, data, or profits; or business interruption) however caused\r
35 // and on any theory of liability, whether in contract, strict liability,\r
36 // or tort (including negligence or otherwise) arising in any way out of\r
37 // the use of this software, even if advised of the possibility of such damage.\r
38 //\r
39 //M*/\r
40 \r
41 #include "_ml.h"\r
42 #include <ctype.h>\r
43 \r
44 static const float ord_nan = FLT_MAX*0.5f;\r
45 static const int min_block_size = 1 << 16;\r
46 static const int block_size_delta = 1 << 10;\r
47 \r
48 CvDTreeTrainData::CvDTreeTrainData()\r
49 {\r
50     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
51         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
52     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
53     pred_float_buf = resp_float_buf = 0;\r
54     tree_storage = temp_storage = 0;\r
55 \r
56     clear();\r
57 }\r
58 \r
59 \r
60 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,\r
61                       const CvMat* _responses, const CvMat* _var_idx,\r
62                       const CvMat* _sample_idx, const CvMat* _var_type,\r
63                       const CvMat* _missing_mask, const CvDTreeParams& _params,\r
64                       bool _shared, bool _add_labels )\r
65 {\r
66     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
67         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
68 \r
69     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
70     pred_float_buf = resp_float_buf = 0;\r
71 \r
72     tree_storage = temp_storage = 0;\r
73 \r
74     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,\r
75               _var_type, _missing_mask, _params, _shared, _add_labels );\r
76 }\r
77 \r
78 \r
79 CvDTreeTrainData::~CvDTreeTrainData()\r
80 {\r
81     clear();\r
82 }\r
83 \r
84 \r
85 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )\r
86 {\r
87     bool ok = false;\r
88 \r
89     CV_FUNCNAME( "CvDTreeTrainData::set_params" );\r
90 \r
91     __BEGIN__;\r
92 \r
93     // set parameters\r
94     params = _params;\r
95 \r
96     if( params.max_categories < 2 )\r
97         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );\r
98     params.max_categories = MIN( params.max_categories, 15 );\r
99 \r
100     if( params.max_depth < 0 )\r
101         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );\r
102     params.max_depth = MIN( params.max_depth, 25 );\r
103 \r
104     params.min_sample_count = MAX(params.min_sample_count,1);\r
105 \r
106     if( params.cv_folds < 0 )\r
107         CV_ERROR( CV_StsOutOfRange,\r
108         "params.cv_folds should be =0 (the tree is not pruned) "\r
109         "or n>0 (tree is pruned using n-fold cross-validation)" );\r
110 \r
111     if( params.cv_folds == 1 )\r
112         params.cv_folds = 0;\r
113 \r
114     if( params.regression_accuracy < 0 )\r
115         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );\r
116 \r
117     ok = true;\r
118 \r
119     __END__;\r
120 \r
121     return ok;\r
122 }\r
123 \r
124 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))\r
125 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )\r
126 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )\r
127 \r
128 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])\r
129 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )\r
130 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )\r
131 \r
132 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))\r
133 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )\r
134 \r
135 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
136     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
137     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
138     bool _shared, bool _add_labels, bool _update_data )\r
139 {\r
140     CvMat* sample_indices = 0;\r
141     CvMat* var_type0 = 0;\r
142     CvMat* tmp_map = 0;\r
143     int** int_ptr = 0;\r
144     CvPair16u32s* pair16u32s_ptr = 0;\r
145     CvDTreeTrainData* data = 0;\r
146     float *_fdst = 0;\r
147     int *_idst = 0;\r
148     unsigned short* udst = 0;\r
149     int* idst = 0;\r
150 \r
151     CV_FUNCNAME( "CvDTreeTrainData::set_data" );\r
152 \r
153     __BEGIN__;\r
154 \r
155     int sample_all = 0, r_type = 0, cv_n;\r
156     int total_c_count = 0;\r
157     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
158     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step\r
159     int vi, i, size;\r
160     char err[100];\r
161     const int *sidx = 0, *vidx = 0;\r
162     \r
163     \r
164     if( _update_data && data_root )\r
165     {\r
166         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,\r
167             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );\r
168 \r
169         // compare new and old train data\r
170         if( !(data->var_count == var_count &&\r
171             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
172             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&\r
173             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )\r
174             CV_ERROR( CV_StsBadArg,\r
175             "The new training data must have the same types and the input and output variables "\r
176             "and the same categories for categorical variables" );\r
177 \r
178         cvReleaseMat( &priors );\r
179         cvReleaseMat( &priors_mult );\r
180         cvReleaseMat( &buf );\r
181         cvReleaseMat( &direction );\r
182         cvReleaseMat( &split_buf );\r
183         cvReleaseMemStorage( &temp_storage );\r
184 \r
185         priors = data->priors; data->priors = 0;\r
186         priors_mult = data->priors_mult; data->priors_mult = 0;\r
187         buf = data->buf; data->buf = 0;\r
188         buf_count = data->buf_count; buf_size = data->buf_size;\r
189         sample_count = data->sample_count;\r
190 \r
191         direction = data->direction; data->direction = 0;\r
192         split_buf = data->split_buf; data->split_buf = 0;\r
193         temp_storage = data->temp_storage; data->temp_storage = 0;\r
194         nv_heap = data->nv_heap; cv_heap = data->cv_heap;\r
195 \r
196         data_root = new_node( 0, sample_count, 0, 0 );\r
197         EXIT;\r
198     }\r
199 \r
200     clear();\r
201 \r
202     var_all = 0;\r
203     rng = cvRNG(-1);\r
204 \r
205     CV_CALL( set_params( _params ));\r
206 \r
207     // check parameter types and sizes\r
208     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));\r
209 \r
210     train_data = _train_data;\r
211     responses = _responses;\r
212 \r
213     if( _tflag == CV_ROW_SAMPLE )\r
214     {\r
215         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
216         dv_step = 1;\r
217         if( _missing_mask )\r
218             ms_step = _missing_mask->step, mv_step = 1;\r
219     }\r
220     else\r
221     {\r
222         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
223         ds_step = 1;\r
224         if( _missing_mask )\r
225             mv_step = _missing_mask->step, ms_step = 1;\r
226     }\r
227     tflag = _tflag;\r
228 \r
229     sample_count = sample_all;\r
230     var_count = var_all;\r
231     \r
232     if (_train_data->rows + _train_data->cols -1 < 65536) \r
233         is_buf_16u = true;                                \r
234     \r
235     if( _sample_idx )\r
236     {\r
237         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));\r
238         sidx = sample_indices->data.i;\r
239         sample_count = sample_indices->rows + sample_indices->cols - 1;\r
240     }\r
241 \r
242     if( _var_idx )\r
243     {\r
244         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));\r
245         vidx = var_idx->data.i;\r
246         var_count = var_idx->rows + var_idx->cols - 1;\r
247     }\r
248 \r
249     if( !CV_IS_MAT(_responses) ||\r
250         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
251          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
252         (_responses->rows != 1 && _responses->cols != 1) ||\r
253         _responses->rows + _responses->cols - 1 != sample_all )\r
254         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
255                   "floating-point vector containing as many elements as "\r
256                   "the total number of samples in the training data matrix" );\r
257    \r
258   \r
259     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));\r
260 \r
261     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));\r
262    \r
263     \r
264     cat_var_count = 0;\r
265     ord_var_count = -1;\r
266 \r
267     is_classifier = r_type == CV_VAR_CATEGORICAL;\r
268 \r
269     // step 0. calc the number of categorical vars\r
270     for( vi = 0; vi < var_count; vi++ )\r
271     {\r
272         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?\r
273             cat_var_count++ : ord_var_count--;\r
274     }\r
275 \r
276     ord_var_count = ~ord_var_count;\r
277     cv_n = params.cv_folds;\r
278     // set the two last elements of var_type array to be able\r
279     // to locate responses and cross-validation labels using\r
280     // the corresponding get_* functions.\r
281     var_type->data.i[var_count] = cat_var_count;\r
282     var_type->data.i[var_count+1] = cat_var_count+1;\r
283 \r
284     // in case of single ordered predictor we need dummy cv_labels\r
285     // for safe split_node_data() operation\r
286     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;\r
287 \r
288     work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);\r
289     buf_size = (work_var_count + 1)*sample_count;\r
290     shared = _shared;\r
291     buf_count = shared ? 2 : 1;\r
292     \r
293     if ( is_buf_16u )\r
294     {\r
295         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));\r
296         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));\r
297     }\r
298     else\r
299     {\r
300         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));\r
301         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));\r
302     }    \r
303 \r
304     size = is_classifier ? (cat_var_count+1) : cat_var_count;\r
305     size = !size ? 1 : size;\r
306     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));\r
307     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));\r
308         \r
309     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;\r
310     size = !size ? 1 : size;\r
311     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));\r
312 \r
313     // now calculate the maximum size of split,\r
314     // create memory storage that will keep nodes and splits of the decision tree\r
315     // allocate root node and the buffer for the whole training data\r
316     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
317         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));\r
318     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
319     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
320     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
321     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));\r
322 \r
323     nv_size = var_count*sizeof(int);\r
324     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));\r
325 \r
326     temp_block_size = nv_size;\r
327 \r
328     if( cv_n )\r
329     {\r
330         if( sample_count < cv_n*MAX(params.min_sample_count,10) )\r
331             CV_ERROR( CV_StsOutOfRange,\r
332                 "The many folds in cross-validation for such a small dataset" );\r
333 \r
334         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );\r
335         temp_block_size = MAX(temp_block_size, cv_size);\r
336     }\r
337 \r
338     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );\r
339     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));\r
340     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));\r
341     if( cv_size )\r
342         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));\r
343 \r
344     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));\r
345 \r
346     max_c_count = 1;\r
347 \r
348     _fdst = 0;\r
349     _idst = 0;\r
350     if (ord_var_count)\r
351         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));\r
352     if (is_buf_16u && (cat_var_count || is_classifier))\r
353         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));\r
354 \r
355     // transform the training data to convenient representation\r
356     for( vi = 0; vi <= var_count; vi++ )\r
357     {\r
358         int ci;\r
359         const uchar* mask = 0;\r
360         int m_step = 0, step;\r
361         const int* idata = 0;\r
362         const float* fdata = 0;\r
363         int num_valid = 0;\r
364 \r
365         if( vi < var_count ) // analyze i-th input variable\r
366         {\r
367             int vi0 = vidx ? vidx[vi] : vi;\r
368             ci = get_var_type(vi);\r
369             step = ds_step; m_step = ms_step;\r
370             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
371                 idata = _train_data->data.i + vi0*dv_step;\r
372             else\r
373                 fdata = _train_data->data.fl + vi0*dv_step;\r
374             if( _missing_mask )\r
375                 mask = _missing_mask->data.ptr + vi0*mv_step;\r
376         }\r
377         else // analyze _responses\r
378         {\r
379             ci = cat_var_count;\r
380             step = CV_IS_MAT_CONT(_responses->type) ?\r
381                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);\r
382             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )\r
383                 idata = _responses->data.i;\r
384             else\r
385                 fdata = _responses->data.fl;\r
386         }\r
387 \r
388         if( (vi < var_count && ci>=0) ||\r
389             (vi == var_count && is_classifier) ) // process categorical variable or response\r
390         {\r
391             int c_count, prev_label;\r
392             int* c_map;\r
393             \r
394             if (is_buf_16u)\r
395                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
396             else\r
397                 idst = buf->data.i + vi*sample_count;\r
398             \r
399             // copy data\r
400             for( i = 0; i < sample_count; i++ )\r
401             {\r
402                 int val = INT_MAX, si = sidx ? sidx[i] : i;\r
403                 if( !mask || !mask[si*m_step] )\r
404                 {\r
405                     if( idata )\r
406                         val = idata[si*step];\r
407                     else\r
408                     {\r
409                         float t = fdata[si*step];\r
410                         val = cvRound(t);\r
411                         if( val != t )\r
412                         {\r
413                             sprintf( err, "%d-th value of %d-th (categorical) "\r
414                                 "variable is not an integer", i, vi );\r
415                             CV_ERROR( CV_StsBadArg, err );\r
416                         }\r
417                     }\r
418 \r
419                     if( val == INT_MAX )\r
420                     {\r
421                         sprintf( err, "%d-th value of %d-th (categorical) "\r
422                             "variable is too large", i, vi );\r
423                         CV_ERROR( CV_StsBadArg, err );\r
424                     }\r
425                     num_valid++;\r
426                 }\r
427                 if (is_buf_16u)\r
428                 {\r
429                     _idst[i] = val;\r
430                     pair16u32s_ptr[i].u = udst + i;\r
431                     pair16u32s_ptr[i].i = _idst + i;\r
432                 }   \r
433                 else\r
434                 {\r
435                     idst[i] = val;\r
436                     int_ptr[i] = idst + i;\r
437                 }\r
438             }\r
439 \r
440             c_count = num_valid > 0;\r
441 \r
442             if (is_buf_16u)\r
443             {\r
444                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );\r
445                 // count the categories\r
446                 for( i = 1; i < num_valid; i++ )\r
447                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)\r
448                         c_count ++ ;\r
449             }\r
450             else\r
451             {\r
452                 icvSortIntPtr( int_ptr, sample_count, 0 );\r
453                 // count the categories\r
454                 for( i = 1; i < num_valid; i++ )\r
455                     c_count += *int_ptr[i] != *int_ptr[i-1];\r
456             }\r
457 \r
458             if( vi > 0 )\r
459                 max_c_count = MAX( max_c_count, c_count );\r
460             cat_count->data.i[ci] = c_count;\r
461             cat_ofs->data.i[ci] = total_c_count;\r
462 \r
463             // resize cat_map, if need\r
464             if( cat_map->cols < total_c_count + c_count )\r
465             {\r
466                 tmp_map = cat_map;\r
467                 CV_CALL( cat_map = cvCreateMat( 1,\r
468                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));\r
469                 for( i = 0; i < total_c_count; i++ )\r
470                     cat_map->data.i[i] = tmp_map->data.i[i];\r
471                 cvReleaseMat( &tmp_map );\r
472             }\r
473 \r
474             c_map = cat_map->data.i + total_c_count;\r
475             total_c_count += c_count;\r
476 \r
477             c_count = -1;\r
478             if (is_buf_16u)\r
479             {\r
480                 // compact the class indices and build the map\r
481                 prev_label = ~*pair16u32s_ptr[0].i;\r
482                 for( i = 0; i < num_valid; i++ )\r
483                 {\r
484                     int cur_label = *pair16u32s_ptr[i].i;\r
485                     if( cur_label != prev_label )\r
486                         c_map[++c_count] = prev_label = cur_label;\r
487                     *pair16u32s_ptr[i].u = (unsigned short)c_count;\r
488                 }\r
489                 // replace labels for missing values with -1\r
490                 for( ; i < sample_count; i++ )\r
491                     *pair16u32s_ptr[i].u = 65535;\r
492             }\r
493             else\r
494             {\r
495                 // compact the class indices and build the map\r
496                 prev_label = ~*int_ptr[0];\r
497                 for( i = 0; i < num_valid; i++ )\r
498                 {\r
499                     int cur_label = *int_ptr[i];\r
500                     if( cur_label != prev_label )\r
501                         c_map[++c_count] = prev_label = cur_label;\r
502                     *int_ptr[i] = c_count;\r
503                 }\r
504                 // replace labels for missing values with -1\r
505                 for( ; i < sample_count; i++ )\r
506                     *int_ptr[i] = -1;\r
507             }           \r
508         }\r
509         else if( ci < 0 ) // process ordered variable\r
510         {\r
511             if (is_buf_16u)\r
512                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
513             else\r
514                 idst = buf->data.i + vi*sample_count;\r
515 \r
516             for( i = 0; i < sample_count; i++ )\r
517             {\r
518                 float val = ord_nan;\r
519                 int si = sidx ? sidx[i] : i;\r
520                 if( !mask || !mask[si*m_step] )\r
521                 {\r
522                     if( idata )\r
523                         val = (float)idata[si*step];\r
524                     else\r
525                         val = fdata[si*step];\r
526 \r
527                     if( fabs(val) >= ord_nan )\r
528                     {\r
529                         sprintf( err, "%d-th value of %d-th (ordered) "\r
530                             "variable (=%g) is too large", i, vi, val );\r
531                         CV_ERROR( CV_StsBadArg, err );\r
532                     }\r
533                 }\r
534                 num_valid++;\r
535                 if (is_buf_16u)\r
536                     udst[i] = (unsigned short)i;\r
537                 else\r
538                     idst[i] = i; // Ã¯Ã¥Ã°Ã¥Ã­Ã¥Ã±Ã²Ã¨ Ã¢Ã»Ã¸Ã¥ Ã¢ if( idata )\r
539                 _fdst[i] = val;\r
540                 \r
541             }\r
542             if (is_buf_16u)\r
543                 icvSortUShAux( udst, num_valid, _fdst);\r
544             else\r
545                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );\r
546         }\r
547        \r
548         if( vi < var_count )\r
549             data_root->set_num_valid(vi, num_valid);\r
550     }\r
551 \r
552     // set sample labels\r
553     if (is_buf_16u)\r
554         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);\r
555     else\r
556         idst = buf->data.i + work_var_count*sample_count;\r
557 \r
558     for (i = 0; i < sample_count; i++)\r
559     {\r
560         if (udst)\r
561             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;\r
562         else\r
563             idst[i] = sidx ? sidx[i] : i;\r
564     }\r
565 \r
566     if( cv_n )\r
567     {\r
568         unsigned short* udst = 0;\r
569         int* idst = 0;\r
570         CvRNG* r = &rng;\r
571 \r
572         if (is_buf_16u)\r
573         {\r
574             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);\r
575             for( i = vi = 0; i < sample_count; i++ )\r
576             {\r
577                 udst[i] = (unsigned short)vi++;\r
578                 vi &= vi < cv_n ? -1 : 0;\r
579             }\r
580 \r
581             for( i = 0; i < sample_count; i++ )\r
582             {\r
583                 int a = cvRandInt(r) % sample_count;\r
584                 int b = cvRandInt(r) % sample_count;\r
585                 unsigned short unsh = (unsigned short)vi;\r
586                 CV_SWAP( udst[a], udst[b], unsh );\r
587             }\r
588         }\r
589         else\r
590         {\r
591             idst = buf->data.i + (get_work_var_count()-1)*sample_count;\r
592             for( i = vi = 0; i < sample_count; i++ )\r
593             {\r
594                 idst[i] = vi++;\r
595                 vi &= vi < cv_n ? -1 : 0;\r
596             }\r
597 \r
598             for( i = 0; i < sample_count; i++ )\r
599             {\r
600                 int a = cvRandInt(r) % sample_count;\r
601                 int b = cvRandInt(r) % sample_count;\r
602                 CV_SWAP( idst[a], idst[b], vi );\r
603             }\r
604         }\r
605     }\r
606 \r
607     if ( cat_map ) \r
608         cat_map->cols = MAX( total_c_count, 1 );\r
609 \r
610     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
611         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
612     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));\r
613 \r
614     have_priors = is_classifier && params.priors;\r
615     if( is_classifier )\r
616     {\r
617         int m = get_num_classes();\r
618         double sum = 0;\r
619         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));\r
620         for( i = 0; i < m; i++ )\r
621         {\r
622             double val = have_priors ? params.priors[i] : 1.;\r
623             if( val <= 0 )\r
624                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );\r
625             priors->data.db[i] = val;\r
626             sum += val;\r
627         }\r
628 \r
629         // normalize weights\r
630         if( have_priors )\r
631             cvScale( priors, priors, 1./sum );\r
632 \r
633         CV_CALL( priors_mult = cvCloneMat( priors ));\r
634         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));\r
635     }\r
636 \r
637 \r
638     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));\r
639     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));\r
640 \r
641     CV_CALL( pred_float_buf = (float*)cvAlloc(sample_count*sizeof(pred_float_buf[0])) );\r
642     CV_CALL( pred_int_buf = (int*)cvAlloc(sample_count*sizeof(pred_int_buf[0])) );\r
643     CV_CALL( resp_float_buf = (float*)cvAlloc(sample_count*sizeof(resp_float_buf[0])) );\r
644     CV_CALL( resp_int_buf = (int*)cvAlloc(sample_count*sizeof(resp_int_buf[0])) );\r
645     CV_CALL( cv_lables_buf = (int*)cvAlloc(sample_count*sizeof(cv_lables_buf[0])) );\r
646     CV_CALL( sample_idx_buf = (int*)cvAlloc(sample_count*sizeof(sample_idx_buf[0])) );\r
647 \r
648     __END__;\r
649 \r
650     if( data )\r
651         delete data;\r
652 \r
653     if (_fdst)\r
654         cvFree( &_fdst );\r
655     if (_idst)\r
656         cvFree( &_idst );\r
657     cvFree( &int_ptr );\r
658     cvReleaseMat( &var_type0 );\r
659     cvReleaseMat( &sample_indices );\r
660     cvReleaseMat( &tmp_map );\r
661 }\r
662 \r
663 \r
664 \r
665 void CvDTreeTrainData::do_responses_copy()\r
666 {\r
667     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );\r
668     cvCopy( responses, responses_copy);\r
669     responses = responses_copy;\r
670 }\r
671 \r
672 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )\r
673 {\r
674     CvDTreeNode* root = 0;\r
675     CvMat* isubsample_idx = 0;\r
676     CvMat* subsample_co = 0;\r
677 \r
678     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );\r
679 \r
680     __BEGIN__;\r
681 \r
682     if( !data_root )\r
683         CV_ERROR( CV_StsError, "No training data has been set" );\r
684 \r
685     if( _subsample_idx )\r
686         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
687 \r
688     if( !isubsample_idx )\r
689     {\r
690         // make a copy of the root node\r
691         CvDTreeNode temp;\r
692         int i;\r
693         root = new_node( 0, 1, 0, 0 );\r
694         temp = *root;\r
695         *root = *data_root;\r
696         root->num_valid = temp.num_valid;\r
697         if( root->num_valid )\r
698         {\r
699             for( i = 0; i < var_count; i++ )\r
700                 root->num_valid[i] = data_root->num_valid[i];\r
701         }\r
702         root->cv_Tn = temp.cv_Tn;\r
703         root->cv_node_risk = temp.cv_node_risk;\r
704         root->cv_node_error = temp.cv_node_error;\r
705     }\r
706     else\r
707     {\r
708         int* sidx = isubsample_idx->data.i;\r
709         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)\r
710         int* co, cur_ofs = 0;\r
711         int vi, i;\r
712         int work_var_count = get_work_var_count();\r
713         int count = isubsample_idx->rows + isubsample_idx->cols - 1;\r
714 \r
715         root = new_node( 0, count, 1, 0 );\r
716 \r
717         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
718         cvZero( subsample_co );\r
719         co = subsample_co->data.i;\r
720         for( i = 0; i < count; i++ )\r
721             co[sidx[i]*2]++;\r
722         for( i = 0; i < sample_count; i++ )\r
723         {\r
724             if( co[i*2] )\r
725             {\r
726                 co[i*2+1] = cur_ofs;\r
727                 cur_ofs += co[i*2];\r
728             }\r
729             else\r
730                 co[i*2+1] = -1;\r
731         }\r
732 \r
733         for( vi = 0; vi < work_var_count; vi++ )\r
734         {\r
735             int ci = get_var_type(vi);\r
736 \r
737             if( ci >= 0 || vi >= var_count )\r
738             {\r
739                 int* src_buf = pred_int_buf;\r
740                 const int* src = 0;\r
741                 int num_valid = 0;\r
742                 \r
743                 get_cat_var_data( data_root, vi, src_buf, &src );\r
744 \r
745                 if (is_buf_16u)\r
746                 {\r
747                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
748                         vi*sample_count + root->offset);\r
749                     for( i = 0; i < count; i++ )\r
750                     {\r
751                         int val = src[sidx[i]];\r
752                         udst[i] = (unsigned short)val;\r
753                         num_valid += val >= 0;\r
754                     }\r
755                 }\r
756                 else\r
757                 {\r
758                     int* idst = buf->data.i + root->buf_idx*buf->cols + \r
759                         vi*sample_count + root->offset;\r
760                     for( i = 0; i < count; i++ )\r
761                     {\r
762                         int val = src[sidx[i]];\r
763                         idst[i] = val;\r
764                         num_valid += val >= 0;\r
765                     }\r
766                 }\r
767 \r
768                 if( vi < var_count )\r
769                     root->set_num_valid(vi, num_valid);\r
770             }\r
771             else\r
772             {\r
773                 int *src_idx_buf = pred_int_buf;\r
774                 const int* src_idx = 0;\r
775                 float *src_val_buf = pred_float_buf;\r
776                 const float* src_val = 0;\r
777                 int j = 0, idx, count_i;\r
778                 int num_valid = data_root->get_num_valid(vi);\r
779 \r
780                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );\r
781                 if (is_buf_16u)\r
782                 {\r
783                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
784                         vi*sample_count + data_root->offset);\r
785                     for( i = 0; i < num_valid; i++ )\r
786                     {\r
787                         idx = src_idx[i];\r
788                         count_i = co[idx*2];\r
789                         if( count_i )\r
790                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
791                                 udst_idx[j] = (unsigned short)cur_ofs;\r
792                     }\r
793 \r
794                     root->set_num_valid(vi, j);\r
795 \r
796                     for( ; i < sample_count; i++ )\r
797                     {\r
798                         idx = src_idx[i];\r
799                         count_i = co[idx*2];\r
800                         if( count_i )\r
801                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
802                                 udst_idx[j] = (unsigned short)cur_ofs;\r
803                     }\r
804                 }\r
805                 else\r
806                 {\r
807                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + \r
808                         vi*sample_count + root->offset;\r
809                     for( i = 0; i < num_valid; i++ )\r
810                     {\r
811                         idx = src_idx[i];\r
812                         count_i = co[idx*2];\r
813                         if( count_i )\r
814                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
815                                 idst_idx[j] = cur_ofs;\r
816                     }\r
817 \r
818                     root->set_num_valid(vi, j);\r
819 \r
820                     for( ; i < sample_count; i++ )\r
821                     {\r
822                         idx = src_idx[i];\r
823                         count_i = co[idx*2];\r
824                         if( count_i )\r
825                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
826                                 idst_idx[j] = cur_ofs;\r
827                     }\r
828                 }\r
829             }\r
830         }\r
831         // sample indices subsampling\r
832         int* sample_idx_src_buf = sample_idx_buf;\r
833         const int* sample_idx_src = 0;\r
834         get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);\r
835         if (is_buf_16u)\r
836         {\r
837             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
838                 get_work_var_count()*sample_count + root->offset);            \r
839             for (i = 0; i < count; i++)\r
840                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];\r
841         }\r
842         else\r
843         {\r
844             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + \r
845                 get_work_var_count()*sample_count + root->offset;            \r
846             for (i = 0; i < count; i++)\r
847                 sample_idx_dst[i] = sample_idx_src[sidx[i]];\r
848         }\r
849     }\r
850 \r
851     __END__;\r
852 \r
853     cvReleaseMat( &isubsample_idx );\r
854     cvReleaseMat( &subsample_co );\r
855 \r
856     return root;\r
857 }\r
858 \r
859 \r
860 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,\r
861                                     float* values, uchar* missing,\r
862                                     float* responses, bool get_class_idx )\r
863 {\r
864     CvMat* subsample_idx = 0;\r
865     CvMat* subsample_co = 0;\r
866 \r
867     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );\r
868 \r
869     __BEGIN__;\r
870 \r
871     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
872     int* sidx = 0;\r
873     int* co = 0;\r
874 \r
875     if( _subsample_idx )\r
876     {\r
877         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
878         sidx = subsample_idx->data.i;\r
879         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
880         co = subsample_co->data.i;\r
881         cvZero( subsample_co );\r
882         count = subsample_idx->cols + subsample_idx->rows - 1;\r
883         for( i = 0; i < count; i++ )\r
884             co[sidx[i]*2]++;\r
885         for( i = 0; i < total; i++ )\r
886         {\r
887             int count_i = co[i*2];\r
888             if( count_i )\r
889             {\r
890                 co[i*2+1] = cur_ofs*var_count;\r
891                 cur_ofs += count_i;\r
892             }\r
893         }\r
894     }\r
895 \r
896     if( missing )\r
897         memset( missing, 1, count*var_count );\r
898 \r
899     for( vi = 0; vi < var_count; vi++ )\r
900     {\r
901         int ci = get_var_type(vi);\r
902         if( ci >= 0 ) // categorical\r
903         {\r
904             float* dst = values + vi;\r
905             uchar* m = missing ? missing + vi : 0;\r
906             int* src_buf = pred_int_buf;\r
907             const int* src = 0; \r
908             get_cat_var_data(data_root, vi, src_buf, &src);\r
909 \r
910             for( i = 0; i < count; i++, dst += var_count )\r
911             {\r
912                 int idx = sidx ? sidx[i] : i;\r
913                 int val = src[idx];\r
914                 *dst = (float)val;\r
915                 if( m )\r
916                 {\r
917                     *m = val < 0;\r
918                     m += var_count;\r
919                 }\r
920             }\r
921         }\r
922         else // ordered\r
923         {\r
924             float* dst = values + vi;\r
925             uchar* m = missing ? missing + vi : 0;\r
926             int count1 = data_root->get_num_valid(vi);\r
927             float *src_val_buf = pred_float_buf;\r
928             const float *src_val = 0;\r
929             int* src_idx_buf = pred_int_buf;\r
930             const int* src_idx = 0;\r
931             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
932 \r
933             for( i = 0; i < count1; i++ )\r
934             {\r
935                 int idx = src_idx[i];\r
936                 int count_i = 1;\r
937                 if( co )\r
938                 {\r
939                     count_i = co[idx*2];\r
940                     cur_ofs = co[idx*2+1];\r
941                 }\r
942                 else\r
943                     cur_ofs = idx*var_count;\r
944                 if( count_i )\r
945                 {\r
946                     float val = src_val[i];\r
947                     for( ; count_i > 0; count_i--, cur_ofs += var_count )\r
948                     {\r
949                         dst[cur_ofs] = val;\r
950                         if( m )\r
951                             m[cur_ofs] = 0;\r
952                     }\r
953                 }\r
954             }\r
955         }\r
956     }\r
957 \r
958     // copy responses\r
959     if( responses )\r
960     {\r
961         if( is_classifier )\r
962         {\r
963             int* src_buf = resp_int_buf;\r
964             const int* src = 0;\r
965             get_class_labels(data_root, src_buf, &src);\r
966             for( i = 0; i < count; i++ )\r
967             {\r
968                 int idx = sidx ? sidx[i] : i;\r
969                 int val = get_class_idx ? src[idx] :\r
970                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];\r
971                 responses[i] = (float)val;\r
972             }\r
973         }\r
974         else\r
975         {\r
976             float *_values_buf = resp_float_buf;\r
977             const float* _values = 0;\r
978             get_ord_responses(data_root, _values_buf, &_values);\r
979             for( i = 0; i < count; i++ )\r
980             {\r
981                 int idx = sidx ? sidx[i] : i;\r
982                 responses[i] = _values[idx];\r
983             }\r
984         }\r
985     }\r
986 \r
987     __END__;\r
988 \r
989     cvReleaseMat( &subsample_idx );\r
990     cvReleaseMat( &subsample_co );\r
991 }\r
992 \r
993 \r
994 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,\r
995                                          int storage_idx, int offset )\r
996 {\r
997     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );\r
998 \r
999     node->sample_count = count;\r
1000     node->depth = parent ? parent->depth + 1 : 0;\r
1001     node->parent = parent;\r
1002     node->left = node->right = 0;\r
1003     node->split = 0;\r
1004     node->value = 0;\r
1005     node->class_idx = 0;\r
1006     node->maxlr = 0.;\r
1007 \r
1008     node->buf_idx = storage_idx;\r
1009     node->offset = offset;\r
1010     if( nv_heap )\r
1011         node->num_valid = (int*)cvSetNew( nv_heap );\r
1012     else\r
1013         node->num_valid = 0;\r
1014     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;\r
1015     node->complexity = 0;\r
1016 \r
1017     if( params.cv_folds > 0 && cv_heap )\r
1018     {\r
1019         int cv_n = params.cv_folds;\r
1020         node->Tn = INT_MAX;\r
1021         node->cv_Tn = (int*)cvSetNew( cv_heap );\r
1022         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));\r
1023         node->cv_node_error = node->cv_node_risk + cv_n;\r
1024     }\r
1025     else\r
1026     {\r
1027         node->Tn = 0;\r
1028         node->cv_Tn = 0;\r
1029         node->cv_node_risk = 0;\r
1030         node->cv_node_error = 0;\r
1031     }\r
1032 \r
1033     return node;\r
1034 }\r
1035 \r
1036 \r
1037 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,\r
1038                 int split_point, int inversed, float quality )\r
1039 {\r
1040     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1041     split->var_idx = vi;\r
1042     split->condensed_idx = INT_MIN;\r
1043     split->ord.c = cmp_val;\r
1044     split->ord.split_point = split_point;\r
1045     split->inversed = inversed;\r
1046     split->quality = quality;\r
1047     split->next = 0;\r
1048 \r
1049     return split;\r
1050 }\r
1051 \r
1052 \r
1053 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )\r
1054 {\r
1055     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1056     int i, n = (max_c_count + 31)/32;\r
1057 \r
1058     split->var_idx = vi;\r
1059     split->condensed_idx = INT_MIN;\r
1060     split->inversed = 0;\r
1061     split->quality = quality;\r
1062     for( i = 0; i < n; i++ )\r
1063         split->subset[i] = 0;\r
1064     split->next = 0;\r
1065 \r
1066     return split;\r
1067 }\r
1068 \r
1069 \r
1070 void CvDTreeTrainData::free_node( CvDTreeNode* node )\r
1071 {\r
1072     CvDTreeSplit* split = node->split;\r
1073     free_node_data( node );\r
1074     while( split )\r
1075     {\r
1076         CvDTreeSplit* next = split->next;\r
1077         cvSetRemoveByPtr( split_heap, split );\r
1078         split = next;\r
1079     }\r
1080     node->split = 0;\r
1081     cvSetRemoveByPtr( node_heap, node );\r
1082 }\r
1083 \r
1084 \r
1085 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )\r
1086 {\r
1087     if( node->num_valid )\r
1088     {\r
1089         cvSetRemoveByPtr( nv_heap, node->num_valid );\r
1090         node->num_valid = 0;\r
1091     }\r
1092     // do not free cv_* fields, as all the cross-validation related data is released at once.\r
1093 }\r
1094 \r
1095 \r
1096 void CvDTreeTrainData::free_train_data()\r
1097 {\r
1098     cvReleaseMat( &counts );\r
1099     cvReleaseMat( &buf );\r
1100     cvReleaseMat( &direction );\r
1101     cvReleaseMat( &split_buf );\r
1102     cvReleaseMemStorage( &temp_storage );\r
1103     cvReleaseMat( &responses_copy );\r
1104     cvFree( &pred_float_buf );\r
1105     cvFree( &pred_int_buf );\r
1106     cvFree( &resp_float_buf );\r
1107     cvFree( &resp_int_buf );\r
1108     cvFree( &cv_lables_buf );\r
1109     cvFree( &sample_idx_buf );\r
1110 \r
1111     cv_heap = nv_heap = 0;\r
1112 }\r
1113 \r
1114 \r
1115 void CvDTreeTrainData::clear()\r
1116 {\r
1117     free_train_data();\r
1118 \r
1119     cvReleaseMemStorage( &tree_storage );\r
1120 \r
1121     cvReleaseMat( &var_idx );\r
1122     cvReleaseMat( &var_type );\r
1123     cvReleaseMat( &cat_count );\r
1124     cvReleaseMat( &cat_ofs );\r
1125     cvReleaseMat( &cat_map );\r
1126     cvReleaseMat( &priors );\r
1127     cvReleaseMat( &priors_mult );\r
1128     \r
1129     node_heap = split_heap = 0;\r
1130 \r
1131     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;\r
1132     have_labels = have_priors = is_classifier = false;\r
1133 \r
1134     buf_count = buf_size = 0;\r
1135     shared = false;\r
1136     \r
1137     data_root = 0;\r
1138 \r
1139     rng = cvRNG(-1);\r
1140 }\r
1141 \r
1142 \r
1143 int CvDTreeTrainData::get_num_classes() const\r
1144 {\r
1145     return is_classifier ? cat_count->data.i[cat_var_count] : 0;\r
1146 }\r
1147 \r
1148 \r
1149 int CvDTreeTrainData::get_var_type(int vi) const\r
1150 {\r
1151     return var_type->data.i[vi];\r
1152 }\r
1153 \r
1154 int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )\r
1155 {\r
1156     int vidx = var_idx->data.i[vi];\r
1157     int node_sample_count = n->sample_count; \r
1158     int* sample_indices_buf = sample_idx_buf;\r
1159     const int* sample_indices = 0;\r
1160     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);\r
1161 \r
1162     get_sample_indices(n, sample_indices_buf, &sample_indices);\r
1163 \r
1164     if( !is_buf_16u )\r
1165         *indices = buf->data.i + n->buf_idx*buf->cols + \r
1166         vi*sample_count + n->offset;\r
1167     else {\r
1168         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1169             vi*sample_count + n->offset );\r
1170         for( int i = 0; i < node_sample_count; i++ )\r
1171             indices_buf[i] = short_indices[i];\r
1172         *indices = indices_buf;\r
1173     }\r
1174     \r
1175     if( tflag == CV_ROW_SAMPLE )\r
1176     {\r
1177         for( int i = 0; i < node_sample_count && \r
1178             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1179         {\r
1180             int idx = (*indices)[i];\r
1181             idx = sample_indices[idx];\r
1182             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);\r
1183         }\r
1184     }\r
1185     else\r
1186         for( int i = 0; i < node_sample_count && \r
1187             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1188         {\r
1189             int idx = (*indices)[i];\r
1190             idx = sample_indices[idx];\r
1191             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);\r
1192         }\r
1193     \r
1194     *ord_values = ord_values_buf;\r
1195     return 0; //TODO: return the number of non-missing values\r
1196 }\r
1197 \r
1198 \r
1199 void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1200 {\r
1201     if (is_classifier)\r
1202         get_cat_var_data( n, var_count, labels_buf, labels );\r
1203 }\r
1204 \r
1205 void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )\r
1206 {\r
1207     get_cat_var_data( n, get_work_var_count(), indices_buf, indices );\r
1208 }\r
1209 \r
1210 void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)\r
1211 {\r
1212     int sample_count = n->sample_count;\r
1213     int* indices_buf = sample_idx_buf;\r
1214     const int* indices = 0;\r
1215 \r
1216     int r_step = responses->step/CV_ELEM_SIZE(responses->type);\r
1217 \r
1218     get_sample_indices(n, indices_buf, &indices);\r
1219 \r
1220     \r
1221     for( int i = 0; i < sample_count && \r
1222         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )\r
1223     {\r
1224         int idx = indices[i];\r
1225         values_buf[i] = *(responses->data.fl + idx * r_step);\r
1226     }\r
1227     \r
1228     *values = values_buf;    \r
1229 }\r
1230 \r
1231 \r
1232 void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1233 {\r
1234     if (have_labels)\r
1235         get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );\r
1236 }\r
1237 \r
1238 \r
1239 int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )\r
1240 {\r
1241     if( !is_buf_16u )\r
1242         *cat_values = buf->data.i + n->buf_idx*buf->cols + \r
1243         vi*sample_count + n->offset;\r
1244     else {\r
1245         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1246             vi*sample_count + n->offset);\r
1247         for( int i = 0; i < n->sample_count; i++ )\r
1248             cat_values_buf[i] = short_values[i];\r
1249         *cat_values = cat_values_buf;\r
1250     }\r
1251 \r
1252     return 0; //TODO: return the number of non-missing values\r
1253 }\r
1254 \r
1255 \r
1256 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )\r
1257 {\r
1258     int idx = n->buf_idx + 1;\r
1259     if( idx >= buf_count )\r
1260         idx = shared ? 1 : 0;\r
1261     return idx;\r
1262 }\r
1263 \r
1264 \r
1265 void CvDTreeTrainData::write_params( CvFileStorage* fs )\r
1266 {\r
1267     CV_FUNCNAME( "CvDTreeTrainData::write_params" );\r
1268 \r
1269     __BEGIN__;\r
1270 \r
1271     int vi, vcount = var_count;\r
1272 \r
1273     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );\r
1274     cvWriteInt( fs, "var_all", var_all );\r
1275     cvWriteInt( fs, "var_count", var_count );\r
1276     cvWriteInt( fs, "ord_var_count", ord_var_count );\r
1277     cvWriteInt( fs, "cat_var_count", cat_var_count );\r
1278 \r
1279     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );\r
1280     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );\r
1281 \r
1282     if( is_classifier )\r
1283     {\r
1284         cvWriteInt( fs, "max_categories", params.max_categories );\r
1285     }\r
1286     else\r
1287     {\r
1288         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );\r
1289     }\r
1290 \r
1291     cvWriteInt( fs, "max_depth", params.max_depth );\r
1292     cvWriteInt( fs, "min_sample_count", params.min_sample_count );\r
1293     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );\r
1294 \r
1295     if( params.cv_folds > 1 )\r
1296     {\r
1297         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );\r
1298         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );\r
1299     }\r
1300 \r
1301     if( priors )\r
1302         cvWrite( fs, "priors", priors );\r
1303 \r
1304     cvEndWriteStruct( fs );\r
1305 \r
1306     if( var_idx )\r
1307         cvWrite( fs, "var_idx", var_idx );\r
1308 \r
1309     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );\r
1310 \r
1311     for( vi = 0; vi < vcount; vi++ )\r
1312         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );\r
1313 \r
1314     cvEndWriteStruct( fs );\r
1315 \r
1316     if( cat_count && (cat_var_count > 0 || is_classifier) )\r
1317     {\r
1318         CV_ASSERT( cat_count != 0 );\r
1319         cvWrite( fs, "cat_count", cat_count );\r
1320         cvWrite( fs, "cat_map", cat_map );\r
1321     }\r
1322 \r
1323     __END__;\r
1324 }\r
1325 \r
1326 \r
1327 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )\r
1328 {\r
1329     CV_FUNCNAME( "CvDTreeTrainData::read_params" );\r
1330 \r
1331     __BEGIN__;\r
1332 \r
1333     CvFileNode *tparams_node, *vartype_node;\r
1334     CvSeqReader reader;\r
1335     int vi, max_split_size, tree_block_size;\r
1336 \r
1337     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);\r
1338     var_all = cvReadIntByName( fs, node, "var_all" );\r
1339     var_count = cvReadIntByName( fs, node, "var_count", var_all );\r
1340     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );\r
1341     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );\r
1342 \r
1343     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );\r
1344 \r
1345     if( tparams_node ) // training parameters are not necessary\r
1346     {\r
1347         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;\r
1348 \r
1349         if( is_classifier )\r
1350         {\r
1351             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );\r
1352         }\r
1353         else\r
1354         {\r
1355             params.regression_accuracy =\r
1356                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );\r
1357         }\r
1358 \r
1359         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );\r
1360         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );\r
1361         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );\r
1362 \r
1363         if( params.cv_folds > 1 )\r
1364         {\r
1365             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;\r
1366             params.truncate_pruned_tree =\r
1367                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;\r
1368         }\r
1369 \r
1370         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );\r
1371         if( priors )\r
1372         {\r
1373             if( !CV_IS_MAT(priors) )\r
1374                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );\r
1375             priors_mult = cvCloneMat( priors );\r
1376         }\r
1377     }\r
1378 \r
1379     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));\r
1380     if( var_idx )\r
1381     {\r
1382         if( !CV_IS_MAT(var_idx) ||\r
1383             (var_idx->cols != 1 && var_idx->rows != 1) ||\r
1384             var_idx->cols + var_idx->rows - 1 != var_count ||\r
1385             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )\r
1386             CV_ERROR( CV_StsParseError,\r
1387                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );\r
1388 \r
1389         for( vi = 0; vi < var_count; vi++ )\r
1390             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )\r
1391                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );\r
1392     }\r
1393 \r
1394     ////// read var type\r
1395     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));\r
1396 \r
1397     cat_var_count = 0;\r
1398     ord_var_count = -1;\r
1399     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );\r
1400 \r
1401     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )\r
1402         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;\r
1403     else\r
1404     {\r
1405         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||\r
1406             vartype_node->data.seq->total != var_count )\r
1407             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1408 \r
1409         cvStartReadSeq( vartype_node->data.seq, &reader );\r
1410 \r
1411         for( vi = 0; vi < var_count; vi++ )\r
1412         {\r
1413             CvFileNode* n = (CvFileNode*)reader.ptr;\r
1414             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )\r
1415                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1416             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;\r
1417             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
1418         }\r
1419     }\r
1420     var_type->data.i[var_count] = cat_var_count;\r
1421 \r
1422     ord_var_count = ~ord_var_count;\r
1423     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )\r
1424         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );\r
1425     //////\r
1426 \r
1427     if( cat_var_count > 0 || is_classifier )\r
1428     {\r
1429         int ccount, total_c_count = 0;\r
1430         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));\r
1431         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));\r
1432 \r
1433         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||\r
1434             (cat_count->cols != 1 && cat_count->rows != 1) ||\r
1435             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||\r
1436             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||\r
1437             (cat_map->cols != 1 && cat_map->rows != 1) ||\r
1438             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )\r
1439             CV_ERROR( CV_StsParseError,\r
1440             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );\r
1441 \r
1442         ccount = cat_var_count + is_classifier;\r
1443 \r
1444         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));\r
1445         cat_ofs->data.i[0] = 0;\r
1446         max_c_count = 1;\r
1447 \r
1448         for( vi = 0; vi < ccount; vi++ )\r
1449         {\r
1450             int val = cat_count->data.i[vi];\r
1451             if( val <= 0 )\r
1452                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );\r
1453             max_c_count = MAX( max_c_count, val );\r
1454             cat_ofs->data.i[vi+1] = total_c_count += val;\r
1455         }\r
1456 \r
1457         if( cat_map->cols + cat_map->rows - 1 != total_c_count )\r
1458             CV_ERROR( CV_StsBadSize,\r
1459             "cat_map vector length is not equal to the total number of categories in all categorical vars" );\r
1460     }\r
1461 \r
1462     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
1463         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
1464 \r
1465     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
1466     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
1467     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
1468     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),\r
1469             sizeof(CvDTreeNode), tree_storage ));\r
1470     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),\r
1471             max_split_size, tree_storage ));\r
1472 \r
1473     __END__;\r
1474 }\r
1475 /////////////////////// Decision Tree /////////////////////////\r
1476 \r
1477 CvDTree::CvDTree()\r
1478 {\r
1479     data = 0;\r
1480     var_importance = 0;\r
1481     default_model_name = "my_tree";\r
1482 \r
1483     clear();\r
1484 }\r
1485 \r
1486 \r
1487 void CvDTree::clear()\r
1488 {\r
1489     cvReleaseMat( &var_importance );\r
1490     if( data )\r
1491     {\r
1492         if( !data->shared )\r
1493             delete data;\r
1494         else\r
1495             free_tree();\r
1496         data = 0;\r
1497     }\r
1498     root = 0;\r
1499     pruned_tree_idx = -1;\r
1500 }\r
1501 \r
1502 \r
1503 CvDTree::~CvDTree()\r
1504 {\r
1505     clear();\r
1506 }\r
1507 \r
1508 \r
1509 const CvDTreeNode* CvDTree::get_root() const\r
1510 {\r
1511     return root;\r
1512 }\r
1513 \r
1514 \r
1515 int CvDTree::get_pruned_tree_idx() const\r
1516 {\r
1517     return pruned_tree_idx;\r
1518 }\r
1519 \r
1520 \r
1521 CvDTreeTrainData* CvDTree::get_data()\r
1522 {\r
1523     return data;\r
1524 }\r
1525 \r
1526 \r
1527 bool CvDTree::train( const CvMat* _train_data, int _tflag,\r
1528                      const CvMat* _responses, const CvMat* _var_idx,\r
1529                      const CvMat* _sample_idx, const CvMat* _var_type,\r
1530                      const CvMat* _missing_mask, CvDTreeParams _params )\r
1531 {\r
1532     bool result = false;\r
1533 \r
1534     CV_FUNCNAME( "CvDTree::train" );\r
1535 \r
1536     __BEGIN__;\r
1537 \r
1538     clear();\r
1539     data = new CvDTreeTrainData( _train_data, _tflag, _responses,\r
1540                                  _var_idx, _sample_idx, _var_type,\r
1541                                  _missing_mask, _params, false );\r
1542     CV_CALL( result = do_train(0) );\r
1543 \r
1544     __END__;\r
1545 \r
1546     return result;\r
1547 }\r
1548 \r
1549 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )\r
1550 {\r
1551    bool result = false;\r
1552 \r
1553     CV_FUNCNAME( "CvDTree::train" );\r
1554 \r
1555     __BEGIN__;\r
1556 \r
1557     const CvMat* values = _data->get_values();
1558     const CvMat* response = _data->get_response();
1559     const CvMat* missing = _data->get_missing();
1560     const CvMat* var_types = _data->get_var_types();
1561     const CvMat* train_sidx = _data->get_train_sample_idx();
1562     const CvMat* var_idx = _data->get_var_idx();\r
1563 \r
1564     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,\r
1565         train_sidx, var_types, missing, _params ) );\r
1566 \r
1567     __END__;\r
1568 \r
1569     return result;\r
1570 }\r
1571 \r
1572 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )\r
1573 {\r
1574     bool result = false;\r
1575 \r
1576     CV_FUNCNAME( "CvDTree::train" );\r
1577 \r
1578     __BEGIN__;\r
1579 \r
1580     clear();\r
1581     data = _data;\r
1582     data->shared = true;\r
1583     CV_CALL( result = do_train(_subsample_idx));\r
1584 \r
1585     __END__;\r
1586 \r
1587     return result;\r
1588 }\r
1589 \r
1590 \r
1591 bool CvDTree::do_train( const CvMat* _subsample_idx )\r
1592 {\r
1593     bool result = false;\r
1594 \r
1595     CV_FUNCNAME( "CvDTree::do_train" );\r
1596 \r
1597     __BEGIN__;\r
1598 \r
1599     root = data->subsample_data( _subsample_idx );\r
1600 \r
1601     CV_CALL( try_split_node(root));\r
1602 \r
1603     if( data->params.cv_folds > 0 )\r
1604         CV_CALL( prune_cv());\r
1605 \r
1606     if( !data->shared )\r
1607         data->free_train_data();\r
1608 \r
1609     result = true;\r
1610 \r
1611     __END__;\r
1612 \r
1613     return result;\r
1614 }\r
1615 \r
1616 \r
1617 void CvDTree::try_split_node( CvDTreeNode* node )\r
1618 {\r
1619     CvDTreeSplit* best_split = 0;\r
1620     int i, n = node->sample_count, vi;\r
1621     bool can_split = true;\r
1622     double quality_scale;\r
1623 \r
1624     calc_node_value( node );\r
1625 \r
1626     if( node->sample_count <= data->params.min_sample_count ||\r
1627         node->depth >= data->params.max_depth )\r
1628         can_split = false;\r
1629 \r
1630     if( can_split && data->is_classifier )\r
1631     {\r
1632         // check if we have a "pure" node,\r
1633         // we assume that cls_count is filled by calc_node_value()\r
1634         int* cls_count = data->counts->data.i;\r
1635         int nz = 0, m = data->get_num_classes();\r
1636         for( i = 0; i < m; i++ )\r
1637             nz += cls_count[i] != 0;\r
1638         if( nz == 1 ) // there is only one class\r
1639             can_split = false;\r
1640     }\r
1641     else if( can_split )\r
1642     {\r
1643         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )\r
1644             can_split = false;\r
1645     }\r
1646 \r
1647     if( can_split )\r
1648     {\r
1649         best_split = find_best_split(node);\r
1650         // TODO: check the split quality ...\r
1651         node->split = best_split;\r
1652     }\r
1653 \r
1654     if( !can_split || !best_split )\r
1655     {\r
1656         data->free_node_data(node);\r
1657         return;\r
1658     }\r
1659 \r
1660     quality_scale = calc_node_dir( node );\r
1661 \r
1662     if( data->params.use_surrogates )\r
1663     {\r
1664         // find all the surrogate splits\r
1665         // and sort them by their similarity to the primary one\r
1666         for( vi = 0; vi < data->var_count; vi++ )\r
1667         {\r
1668             CvDTreeSplit* split;\r
1669             int ci = data->get_var_type(vi);\r
1670 \r
1671             if( vi == best_split->var_idx )\r
1672                 continue;\r
1673 \r
1674             if( ci >= 0 )\r
1675                 split = find_surrogate_split_cat( node, vi );\r
1676             else\r
1677                 split = find_surrogate_split_ord( node, vi );\r
1678 \r
1679             if( split )\r
1680             {\r
1681                 // insert the split\r
1682                 CvDTreeSplit* prev_split = node->split;\r
1683                 split->quality = (float)(split->quality*quality_scale);\r
1684 \r
1685                 while( prev_split->next &&\r
1686                        prev_split->next->quality > split->quality )\r
1687                     prev_split = prev_split->next;\r
1688                 split->next = prev_split->next;\r
1689                 prev_split->next = split;\r
1690             }\r
1691         }\r
1692     }\r
1693 \r
1694     split_node_data( node );\r
1695     try_split_node( node->left );\r
1696     try_split_node( node->right );\r
1697 }\r
1698 \r
1699 \r
1700 // calculate direction (left(-1),right(1),missing(0))\r
1701 // for each sample using the best split\r
1702 // the function returns scale coefficients for surrogate split quality factors.\r
1703 // the scale is applied to normalize surrogate split quality relatively to the\r
1704 // best (primary) split quality. That is, if a surrogate split is absolutely\r
1705 // identical to the primary split, its quality will be set to the maximum value =\r
1706 // quality of the primary split; otherwise, it will be lower.\r
1707 // besides, the function compute node->maxlr,\r
1708 // minimum possible quality (w/o considering the above mentioned scale)\r
1709 // for a surrogate split. Surrogate splits with quality less than node->maxlr\r
1710 // are not discarded.\r
1711 double CvDTree::calc_node_dir( CvDTreeNode* node )\r
1712 {\r
1713     char* dir = (char*)data->direction->data.ptr;\r
1714     int i, n = node->sample_count, vi = node->split->var_idx;\r
1715     double L, R;\r
1716 \r
1717     assert( !node->split->inversed );\r
1718 \r
1719     if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
1720     {\r
1721         int* labels_buf = data->pred_int_buf;\r
1722         const int* labels = 0;\r
1723         const int* subset = node->split->subset;\r
1724         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1725         if( !data->have_priors )\r
1726         {\r
1727             int sum = 0, sum_abs = 0;\r
1728 \r
1729             for( i = 0; i < n; i++ )\r
1730             {\r
1731                 int idx = labels[i];\r
1732                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?\r
1733                     CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1734                 sum += d; sum_abs += d & 1;\r
1735                 dir[i] = (char)d;\r
1736             }\r
1737 \r
1738             R = (sum_abs + sum) >> 1;\r
1739             L = (sum_abs - sum) >> 1;\r
1740         }\r
1741         else\r
1742         {\r
1743             const double* priors = data->priors_mult->data.db;\r
1744             double sum = 0, sum_abs = 0;\r
1745             int *responses_buf = data->resp_int_buf;\r
1746             const int* responses;\r
1747             data->get_class_labels(node, responses_buf, &responses);\r
1748 \r
1749             for( i = 0; i < n; i++ )\r
1750             {\r
1751                 int idx = labels[i];\r
1752                 double w = priors[responses[i]];\r
1753                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1754                 sum += d*w; sum_abs += (d & 1)*w;\r
1755                 dir[i] = (char)d;\r
1756             }\r
1757 \r
1758             R = (sum_abs + sum) * 0.5;\r
1759             L = (sum_abs - sum) * 0.5;\r
1760         }\r
1761     }\r
1762     else // split on ordered var\r
1763     {\r
1764         int split_point = node->split->ord.split_point;\r
1765         int n1 = node->get_num_valid(vi);\r
1766         \r
1767         float* val_buf = data->pred_float_buf;\r
1768         const float* val = 0;\r
1769         int* sorted_buf = data->pred_int_buf;\r
1770         const int* sorted = 0;\r
1771         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);\r
1772 \r
1773         assert( 0 <= split_point && split_point < n1-1 );\r
1774 \r
1775         if( !data->have_priors )\r
1776         {\r
1777             for( i = 0; i <= split_point; i++ )\r
1778                 dir[sorted[i]] = (char)-1;\r
1779             for( ; i < n1; i++ )\r
1780                 dir[sorted[i]] = (char)1;\r
1781             for( ; i < n; i++ )\r
1782                 dir[sorted[i]] = (char)0;\r
1783 \r
1784             L = split_point-1;\r
1785             R = n1 - split_point + 1;\r
1786         }\r
1787         else\r
1788         {\r
1789             const double* priors = data->priors_mult->data.db;\r
1790             int* responses_buf = data->resp_int_buf;\r
1791             const int* responses = 0;\r
1792             data->get_class_labels(node, responses_buf, &responses);\r
1793             L = R = 0;\r
1794 \r
1795             for( i = 0; i <= split_point; i++ )\r
1796             {\r
1797                 int idx = sorted[i];\r
1798                 double w = priors[responses[idx]];\r
1799                 dir[idx] = (char)-1;\r
1800                 L += w;\r
1801             }\r
1802 \r
1803             for( ; i < n1; i++ )\r
1804             {\r
1805                 int idx = sorted[i];\r
1806                 double w = priors[responses[idx]];\r
1807                 dir[idx] = (char)1;\r
1808                 R += w;\r
1809             }\r
1810 \r
1811             for( ; i < n; i++ )\r
1812                 dir[sorted[i]] = (char)0;\r
1813         }\r
1814     }\r
1815 \r
1816     node->maxlr = MAX( L, R );\r
1817     return node->split->quality/(L + R);\r
1818 }\r
1819 \r
1820 \r
1821 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )\r
1822 {\r
1823     int vi;\r
1824     CvDTreeSplit *best_split = 0, *split = 0, *t;\r
1825 \r
1826     for( vi = 0; vi < data->var_count; vi++ )\r
1827     {\r
1828         int ci = data->get_var_type(vi);\r
1829         if( node->get_num_valid(vi) <= 1 )\r
1830             continue;\r
1831 \r
1832         if( data->is_classifier )\r
1833         {\r
1834             if( ci >= 0 )\r
1835                 split = find_split_cat_class( node, vi );\r
1836             else\r
1837                 split = find_split_ord_class( node, vi );\r
1838         }\r
1839         else\r
1840         {\r
1841             if( ci >= 0 )\r
1842                 split = find_split_cat_reg( node, vi );\r
1843             else\r
1844                 split = find_split_ord_reg( node, vi );\r
1845         }\r
1846 \r
1847         if( split )\r
1848         {\r
1849             if( !best_split || best_split->quality < split->quality )\r
1850                 CV_SWAP( best_split, split, t );\r
1851             if( split )\r
1852                 cvSetRemoveByPtr( data->split_heap, split );\r
1853         }\r
1854     }\r
1855 \r
1856     return best_split;\r
1857 }\r
1858 \r
1859 \r
1860 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )\r
1861 {\r
1862     const float epsilon = FLT_EPSILON*2;\r
1863     int n = node->sample_count;\r
1864     int n1 = node->get_num_valid(vi);\r
1865     int m = data->get_num_classes();\r
1866 \r
1867     float* values_buf = data->pred_float_buf;\r
1868     const float* values = 0;\r
1869     int* indices_buf = data->pred_int_buf;\r
1870     const int* indices = 0;\r
1871     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
1872     int* responses_buf =  data->resp_int_buf;\r
1873     const int* responses = 0;\r
1874     data->get_class_labels( node, responses_buf, &responses );\r
1875 \r
1876     const int* rc0 = data->counts->data.i;\r
1877     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
1878     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
1879     int i, best_i = -1;\r
1880     double lsum2 = 0, rsum2 = 0, best_val = 0;\r
1881     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
1882 \r
1883     // init arrays of class instance counters on both sides of the split\r
1884     for( i = 0; i < m; i++ )\r
1885     {\r
1886         lc[i] = 0;\r
1887         rc[i] = rc0[i];\r
1888     }\r
1889 \r
1890     // compensate for missing values\r
1891     for( i = n1; i < n; i++ )\r
1892     {\r
1893         rc[responses[indices[i]]]--;\r
1894     }\r
1895 \r
1896     if( !priors )\r
1897     {\r
1898         int L = 0, R = n1;\r
1899 \r
1900         for( i = 0; i < m; i++ )\r
1901             rsum2 += (double)rc[i]*rc[i];\r
1902 \r
1903         for( i = 0; i < n1 - 1; i++ )\r
1904         {\r
1905             int idx = responses[indices[i]];\r
1906             int lv, rv;\r
1907             L++; R--;\r
1908             lv = lc[idx]; rv = rc[idx];\r
1909             lsum2 += lv*2 + 1;\r
1910             rsum2 -= rv*2 - 1;\r
1911             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1912 \r
1913             if( values[i] + epsilon < values[i+1] )\r
1914             {\r
1915                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1916                 if( best_val < val )\r
1917                 {\r
1918                     best_val = val;\r
1919                     best_i = i;\r
1920                 }\r
1921             }\r
1922         }\r
1923     }\r
1924     else\r
1925     {\r
1926         double L = 0, R = 0;\r
1927         for( i = 0; i < m; i++ )\r
1928         {\r
1929             double wv = rc[i]*priors[i];\r
1930             R += wv;\r
1931             rsum2 += wv*wv;\r
1932         }\r
1933 \r
1934         for( i = 0; i < n1 - 1; i++ )\r
1935         {\r
1936             int idx = responses[indices[i]];\r
1937             int lv, rv;\r
1938             double p = priors[idx], p2 = p*p;\r
1939             L += p; R -= p;\r
1940             lv = lc[idx]; rv = rc[idx];\r
1941             lsum2 += p2*(lv*2 + 1);\r
1942             rsum2 -= p2*(rv*2 - 1);\r
1943             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1944 \r
1945             if( values[i] + epsilon < values[i+1] )\r
1946             {\r
1947                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1948                 if( best_val < val )\r
1949                 {\r
1950                     best_val = val;\r
1951                     best_i = i;\r
1952                 }\r
1953             }\r
1954         }\r
1955     }\r
1956 \r
1957     return best_i >= 0 ? data->new_split_ord( vi,\r
1958         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
1959         0, (float)best_val ) : 0;\r
1960 }\r
1961 \r
1962 \r
1963 void CvDTree::cluster_categories( const int* vectors, int n, int m,\r
1964                                 int* csums, int k, int* labels )\r
1965 {\r
1966     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm\r
1967     int iters = 0, max_iters = 100;\r
1968     int i, j, idx;\r
1969     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );\r
1970     double *v_weights = buf, *c_weights = buf + k;\r
1971     bool modified = true;\r
1972     CvRNG* r = &data->rng;\r
1973 \r
1974     // assign labels randomly\r
1975     for( i = idx = 0; i < n; i++ )\r
1976     {\r
1977         int sum = 0;\r
1978         const int* v = vectors + i*m;\r
1979         labels[i] = idx++;\r
1980         idx &= idx < k ? -1 : 0;\r
1981 \r
1982         // compute weight of each vector\r
1983         for( j = 0; j < m; j++ )\r
1984             sum += v[j];\r
1985         v_weights[i] = sum ? 1./sum : 0.;\r
1986     }\r
1987 \r
1988     for( i = 0; i < n; i++ )\r
1989     {\r
1990         int i1 = cvRandInt(r) % n;\r
1991         int i2 = cvRandInt(r) % n;\r
1992         CV_SWAP( labels[i1], labels[i2], j );\r
1993     }\r
1994 \r
1995     for( iters = 0; iters <= max_iters; iters++ )\r
1996     {\r
1997         // calculate csums\r
1998         for( i = 0; i < k; i++ )\r
1999         {\r
2000             for( j = 0; j < m; j++ )\r
2001                 csums[i*m + j] = 0;\r
2002         }\r
2003 \r
2004         for( i = 0; i < n; i++ )\r
2005         {\r
2006             const int* v = vectors + i*m;\r
2007             int* s = csums + labels[i]*m;\r
2008             for( j = 0; j < m; j++ )\r
2009                 s[j] += v[j];\r
2010         }\r
2011 \r
2012         // exit the loop here, when we have up-to-date csums\r
2013         if( iters == max_iters || !modified )\r
2014             break;\r
2015 \r
2016         modified = false;\r
2017 \r
2018         // calculate weight of each cluster\r
2019         for( i = 0; i < k; i++ )\r
2020         {\r
2021             const int* s = csums + i*m;\r
2022             int sum = 0;\r
2023             for( j = 0; j < m; j++ )\r
2024                 sum += s[j];\r
2025             c_weights[i] = sum ? 1./sum : 0;\r
2026         }\r
2027 \r
2028         // now for each vector determine the closest cluster\r
2029         for( i = 0; i < n; i++ )\r
2030         {\r
2031             const int* v = vectors + i*m;\r
2032             double alpha = v_weights[i];\r
2033             double min_dist2 = DBL_MAX;\r
2034             int min_idx = -1;\r
2035 \r
2036             for( idx = 0; idx < k; idx++ )\r
2037             {\r
2038                 const int* s = csums + idx*m;\r
2039                 double dist2 = 0., beta = c_weights[idx];\r
2040                 for( j = 0; j < m; j++ )\r
2041                 {\r
2042                     double t = v[j]*alpha - s[j]*beta;\r
2043                     dist2 += t*t;\r
2044                 }\r
2045                 if( min_dist2 > dist2 )\r
2046                 {\r
2047                     min_dist2 = dist2;\r
2048                     min_idx = idx;\r
2049                 }\r
2050             }\r
2051 \r
2052             if( min_idx != labels[i] )\r
2053                 modified = true;\r
2054             labels[i] = min_idx;\r
2055         }\r
2056     }\r
2057 }\r
2058 \r
2059 \r
2060 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )\r
2061 {\r
2062     CvDTreeSplit* split = 0;\r
2063     int ci = data->get_var_type(vi);\r
2064     int n = node->sample_count;\r
2065     int m = data->get_num_classes();\r
2066     int _mi = data->cat_count->data.i[ci], mi = _mi;\r
2067 \r
2068     int* labels_buf = data->pred_int_buf;\r
2069     const int* labels = 0;\r
2070     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2071     int *responses_buf = data->resp_int_buf;\r
2072     const int* responses = 0;\r
2073     data->get_class_labels(node, responses_buf, &responses);\r
2074 \r
2075     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
2076     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
2077     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;\r
2078     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );\r
2079     int* cluster_labels = 0;\r
2080     int** int_ptr = 0;\r
2081     int i, j, k, idx;\r
2082     double L = 0, R = 0;\r
2083     double best_val = 0;\r
2084     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;\r
2085     const double* priors = data->priors_mult->data.db;\r
2086 \r
2087     // init array of counters:\r
2088     // c_{jk} - number of samples that have vi-th input variable = j and response = k.\r
2089     for( j = -1; j < mi; j++ )\r
2090         for( k = 0; k < m; k++ )\r
2091             cjk[j*m + k] = 0;\r
2092 \r
2093     for( i = 0; i < n; i++ )\r
2094     {\r
2095        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];\r
2096        k = responses[i];\r
2097        cjk[j*m + k]++;\r
2098     }\r
2099 \r
2100     if( m > 2 )\r
2101     {\r
2102         if( mi > data->params.max_categories )\r
2103         {\r
2104             mi = MIN(data->params.max_categories, n);\r
2105             cjk += _mi*m;\r
2106             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));\r
2107             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );\r
2108         }\r
2109         subset_i = 1;\r
2110         subset_n = 1 << mi;\r
2111     }\r
2112     else\r
2113     {\r
2114         assert( m == 2 );\r
2115         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );\r
2116         for( j = 0; j < mi; j++ )\r
2117             int_ptr[j] = cjk + j*2 + 1;\r
2118         icvSortIntPtr( int_ptr, mi, 0 );\r
2119         subset_i = 0;\r
2120         subset_n = mi;\r
2121     }\r
2122 \r
2123     for( k = 0; k < m; k++ )\r
2124     {\r
2125         int sum = 0;\r
2126         for( j = 0; j < mi; j++ )\r
2127             sum += cjk[j*m + k];\r
2128         rc[k] = sum;\r
2129         lc[k] = 0;\r
2130     }\r
2131 \r
2132     for( j = 0; j < mi; j++ )\r
2133     {\r
2134         double sum = 0;\r
2135         for( k = 0; k < m; k++ )\r
2136             sum += cjk[j*m + k]*priors[k];\r
2137         c_weights[j] = sum;\r
2138         R += c_weights[j];\r
2139     }\r
2140 \r
2141     for( ; subset_i < subset_n; subset_i++ )\r
2142     {\r
2143         double weight;\r
2144         int* crow;\r
2145         double lsum2 = 0, rsum2 = 0;\r
2146 \r
2147         if( m == 2 )\r
2148             idx = (int)(int_ptr[subset_i] - cjk)/2;\r
2149         else\r
2150         {\r
2151             int graycode = (subset_i>>1)^subset_i;\r
2152             int diff = graycode ^ prevcode;\r
2153 \r
2154             // determine index of the changed bit.\r
2155             Cv32suf u;\r
2156             idx = diff >= (1 << 16) ? 16 : 0;\r
2157             u.f = (float)(((diff >> 16) | diff) & 65535);\r
2158             idx += (u.i >> 23) - 127;\r
2159             subtract = graycode < prevcode;\r
2160             prevcode = graycode;\r
2161         }\r
2162 \r
2163         crow = cjk + idx*m;\r
2164         weight = c_weights[idx];\r
2165         if( weight < FLT_EPSILON )\r
2166             continue;\r
2167 \r
2168         if( !subtract )\r
2169         {\r
2170             for( k = 0; k < m; k++ )\r
2171             {\r
2172                 int t = crow[k];\r
2173                 int lval = lc[k] + t;\r
2174                 int rval = rc[k] - t;\r
2175                 double p = priors[k], p2 = p*p;\r
2176                 lsum2 += p2*lval*lval;\r
2177                 rsum2 += p2*rval*rval;\r
2178                 lc[k] = lval; rc[k] = rval;\r
2179             }\r
2180             L += weight;\r
2181             R -= weight;\r
2182         }\r
2183         else\r
2184         {\r
2185             for( k = 0; k < m; k++ )\r
2186             {\r
2187                 int t = crow[k];\r
2188                 int lval = lc[k] - t;\r
2189                 int rval = rc[k] + t;\r
2190                 double p = priors[k], p2 = p*p;\r
2191                 lsum2 += p2*lval*lval;\r
2192                 rsum2 += p2*rval*rval;\r
2193                 lc[k] = lval; rc[k] = rval;\r
2194             }\r
2195             L -= weight;\r
2196             R += weight;\r
2197         }\r
2198 \r
2199         if( L > FLT_EPSILON && R > FLT_EPSILON )\r
2200         {\r
2201             double val = (lsum2*R + rsum2*L)/((double)L*R);\r
2202             if( best_val < val )\r
2203             {\r
2204                 best_val = val;\r
2205                 best_subset = subset_i;\r
2206             }\r
2207         }\r
2208     }\r
2209 \r
2210     if( best_subset < 0 )\r
2211         return 0;\r
2212 \r
2213     split = data->new_split_cat( vi, (float)best_val );\r
2214 \r
2215     if( m == 2 )\r
2216     {\r
2217         for( i = 0; i <= best_subset; i++ )\r
2218         {\r
2219             idx = (int)(int_ptr[i] - cjk) >> 1;\r
2220             split->subset[idx >> 5] |= 1 << (idx & 31);\r
2221         }\r
2222     }\r
2223     else\r
2224     {\r
2225         for( i = 0; i < _mi; i++ )\r
2226         {\r
2227             idx = cluster_labels ? cluster_labels[i] : i;\r
2228             if( best_subset & (1 << idx) )\r
2229                 split->subset[i >> 5] |= 1 << (i & 31);\r
2230         }\r
2231     }\r
2232 \r
2233     return split;\r
2234 }\r
2235 \r
2236 \r
2237 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )\r
2238 {\r
2239     const float epsilon = FLT_EPSILON*2;\r
2240     int n = node->sample_count;\r
2241     int n1 = node->get_num_valid(vi);\r
2242 \r
2243     float* values_buf = data->pred_float_buf;\r
2244     const float* values = 0;\r
2245     int* indices_buf = data->pred_int_buf;\r
2246     const int* indices = 0;\r
2247     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2248     float* responses_buf =  data->resp_float_buf;\r
2249     const float* responses = 0;\r
2250     data->get_ord_responses( node, responses_buf, &responses );\r
2251 \r
2252     int i, best_i = -1;\r
2253     double best_val = 0, lsum = 0, rsum = node->value*n;\r
2254     int L = 0, R = n1;\r
2255 \r
2256     // compensate for missing values\r
2257     for( i = n1; i < n; i++ )\r
2258         rsum -= responses[indices[i]];\r
2259 \r
2260     // find the optimal split\r
2261     for( i = 0; i < n1 - 1; i++ )\r
2262     {\r
2263         float t = responses[indices[i]];\r
2264         L++; R--;\r
2265         lsum += t;\r
2266         rsum -= t;\r
2267 \r
2268         if( values[i] + epsilon < values[i+1] )\r
2269         {\r
2270             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2271             if( best_val < val )\r
2272             {\r
2273                 best_val = val;\r
2274                 best_i = i;\r
2275             }\r
2276         }\r
2277     }\r
2278 \r
2279     return best_i >= 0 ? data->new_split_ord( vi,\r
2280         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2281         0, (float)best_val ) : 0;\r
2282 }\r
2283 \r
2284 \r
2285 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )\r
2286 {\r
2287     CvDTreeSplit* split;\r
2288     int ci = data->get_var_type(vi);\r
2289     int n = node->sample_count;\r
2290     int mi = data->cat_count->data.i[ci];\r
2291     int* labels_buf = data->pred_int_buf;\r
2292     const int* labels = 0;\r
2293     float* responses_buf = data->resp_float_buf;\r
2294     const float* responses = 0;\r
2295     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2296     data->get_ord_responses(node, responses_buf, &responses);\r
2297 \r
2298     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;\r
2299     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;\r
2300     double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );\r
2301     int i, L = 0, R = 0;\r
2302     double best_val = 0, lsum = 0, rsum = 0;\r
2303     int best_subset = -1, subset_i;\r
2304 \r
2305     for( i = -1; i < mi; i++ )\r
2306         sum[i] = counts[i] = 0;\r
2307 \r
2308     // calculate sum response and weight of each category of the input var\r
2309     for( i = 0; i < n; i++ )\r
2310     {\r
2311         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];\r
2312         double s = sum[idx] + responses[i];\r
2313         int nc = counts[idx] + 1;\r
2314         sum[idx] = s;\r
2315         counts[idx] = nc;\r
2316     }\r
2317 \r
2318     // calculate average response in each category\r
2319     for( i = 0; i < mi; i++ )\r
2320     {\r
2321         R += counts[i];\r
2322         rsum += sum[i];\r
2323         sum[i] /= MAX(counts[i],1);\r
2324         sum_ptr[i] = sum + i;\r
2325     }\r
2326 \r
2327     icvSortDblPtr( sum_ptr, mi, 0 );\r
2328 \r
2329     // revert back to unnormalized sums\r
2330     // (there should be a very little loss of accuracy)\r
2331     for( i = 0; i < mi; i++ )\r
2332         sum[i] *= counts[i];\r
2333 \r
2334     for( subset_i = 0; subset_i < mi-1; subset_i++ )\r
2335     {\r
2336         int idx = (int)(sum_ptr[subset_i] - sum);\r
2337         int ni = counts[idx];\r
2338 \r
2339         if( ni )\r
2340         {\r
2341             double s = sum[idx];\r
2342             lsum += s; L += ni;\r
2343             rsum -= s; R -= ni;\r
2344 \r
2345             if( L && R )\r
2346             {\r
2347                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2348                 if( best_val < val )\r
2349                 {\r
2350                     best_val = val;\r
2351                     best_subset = subset_i;\r
2352                 }\r
2353             }\r
2354         }\r
2355     }\r
2356 \r
2357     if( best_subset < 0 )\r
2358         return 0;\r
2359 \r
2360     split = data->new_split_cat( vi, (float)best_val );\r
2361     for( i = 0; i <= best_subset; i++ )\r
2362     {\r
2363         int idx = (int)(sum_ptr[i] - sum);\r
2364         split->subset[idx >> 5] |= 1 << (idx & 31);\r
2365     }\r
2366 \r
2367     return split;\r
2368 }\r
2369 \r
2370 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )\r
2371 {\r
2372     const float epsilon = FLT_EPSILON*2;\r
2373     const char* dir = (char*)data->direction->data.ptr;\r
2374     int n1 = node->get_num_valid(vi);\r
2375     float* values_buf = data->pred_float_buf;\r
2376     const float* values = 0;\r
2377     int* indices_buf = data->pred_int_buf;\r
2378     const int* indices = 0;\r
2379     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2380     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2381     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2382     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2383     // RR - ... both send to the right\r
2384     int i, best_i = -1, best_inversed = 0;\r
2385     double best_val;\r
2386 \r
2387     if( !data->have_priors )\r
2388     {\r
2389         int LL = 0, RL = 0, LR, RR;\r
2390         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;\r
2391         int sum = 0, sum_abs = 0;\r
2392 \r
2393         for( i = 0; i < n1; i++ )\r
2394         {\r
2395             int d = dir[indices[i]];\r
2396             sum += d; sum_abs += d & 1;\r
2397         }\r
2398 \r
2399         // sum_abs = R + L; sum = R - L\r
2400         RR = (sum_abs + sum) >> 1;\r
2401         LR = (sum_abs - sum) >> 1;\r
2402 \r
2403         // initially all the samples are sent to the right by the surrogate split,\r
2404         // LR of them are sent to the left by primary split, and RR - to the right.\r
2405         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2406         for( i = 0; i < n1 - 1; i++ )\r
2407         {\r
2408             int d = dir[indices[i]];\r
2409 \r
2410             if( d < 0 )\r
2411             {\r
2412                 LL++; LR--;\r
2413                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )\r
2414                 {\r
2415                     best_val = LL + RR;\r
2416                     best_i = i; best_inversed = 0;\r
2417                 }\r
2418             }\r
2419             else if( d > 0 )\r
2420             {\r
2421                 RL++; RR--;\r
2422                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )\r
2423                 {\r
2424                     best_val = RL + LR;\r
2425                     best_i = i; best_inversed = 1;\r
2426                 }\r
2427             }\r
2428         }\r
2429         best_val = _best_val;\r
2430     }\r
2431     else\r
2432     {\r
2433         double LL = 0, RL = 0, LR, RR;\r
2434         double worst_val = node->maxlr;\r
2435         double sum = 0, sum_abs = 0;\r
2436         const double* priors = data->priors_mult->data.db;\r
2437         int* responses_buf = data->resp_int_buf;\r
2438         const int* responses = 0;\r
2439         data->get_class_labels(node, responses_buf, &responses);\r
2440         best_val = worst_val;\r
2441 \r
2442         for( i = 0; i < n1; i++ )\r
2443         {\r
2444             int idx = indices[i];\r
2445             double w = priors[responses[idx]];\r
2446             int d = dir[idx];\r
2447             sum += d*w; sum_abs += (d & 1)*w;\r
2448         }\r
2449 \r
2450         // sum_abs = R + L; sum = R - L\r
2451         RR = (sum_abs + sum)*0.5;\r
2452         LR = (sum_abs - sum)*0.5;\r
2453 \r
2454         // initially all the samples are sent to the right by the surrogate split,\r
2455         // LR of them are sent to the left by primary split, and RR - to the right.\r
2456         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2457         for( i = 0; i < n1 - 1; i++ )\r
2458         {\r
2459             int idx = indices[i];\r
2460             double w = priors[responses[idx]];\r
2461             int d = dir[idx];\r
2462 \r
2463             if( d < 0 )\r
2464             {\r
2465                 LL += w; LR -= w;\r
2466                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )\r
2467                 {\r
2468                     best_val = LL + RR;\r
2469                     best_i = i; best_inversed = 0;\r
2470                 }\r
2471             }\r
2472             else if( d > 0 )\r
2473             {\r
2474                 RL += w; RR -= w;\r
2475                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )\r
2476                 {\r
2477                     best_val = RL + LR;\r
2478                     best_i = i; best_inversed = 1;\r
2479                 }\r
2480             }\r
2481         }\r
2482     }\r
2483 \r
2484     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,\r
2485         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2486         best_inversed, (float)best_val ) : 0;\r
2487 }\r
2488 \r
2489 \r
2490 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )\r
2491 {\r
2492     const char* dir = (char*)data->direction->data.ptr;\r
2493     int n = node->sample_count;\r
2494     int* labels_buf = data->pred_int_buf;\r
2495     const int* labels = 0;\r
2496     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2497     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2498     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2499     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2500     // RR - ... both send to the right\r
2501     CvDTreeSplit* split = data->new_split_cat( vi, 0 );\r
2502     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;\r
2503     double best_val = 0;\r
2504     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;\r
2505     double* rc = lc + mi + 1;\r
2506 \r
2507     for( i = -1; i < mi; i++ )\r
2508         lc[i] = rc[i] = 0;\r
2509 \r
2510     // for each category calculate the weight of samples\r
2511     // sent to the left (lc) and to the right (rc) by the primary split\r
2512     if( !data->have_priors )\r
2513     {\r
2514         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;\r
2515         int* _rc = _lc + mi + 1;\r
2516 \r
2517         for( i = -1; i < mi; i++ )\r
2518             _lc[i] = _rc[i] = 0;\r
2519 \r
2520         for( i = 0; i < n; i++ )\r
2521         {\r
2522             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2523             int d = dir[i];\r
2524             int sum = _lc[idx] + d;\r
2525             int sum_abs = _rc[idx] + (d & 1);\r
2526             _lc[idx] = sum; _rc[idx] = sum_abs;\r
2527         }\r
2528 \r
2529         for( i = 0; i < mi; i++ )\r
2530         {\r
2531             int sum = _lc[i];\r
2532             int sum_abs = _rc[i];\r
2533             lc[i] = (sum_abs - sum) >> 1;\r
2534             rc[i] = (sum_abs + sum) >> 1;\r
2535         }\r
2536     }\r
2537     else\r
2538     {\r
2539         const double* priors = data->priors_mult->data.db;\r
2540         int* responses_buf = data->resp_int_buf;\r
2541         const int* responses = 0;\r
2542         data->get_class_labels(node, responses_buf, &responses);\r
2543 \r
2544         for( i = 0; i < n; i++ )\r
2545         {\r
2546             int idx = labels[i];\r
2547             double w = priors[responses[i]];\r
2548             int d = dir[i];\r
2549             double sum = lc[idx] + d*w;\r
2550             double sum_abs = rc[idx] + (d & 1)*w;\r
2551             lc[idx] = sum; rc[idx] = sum_abs;\r
2552         }\r
2553 \r
2554         for( i = 0; i < mi; i++ )\r
2555         {\r
2556             double sum = lc[i];\r
2557             double sum_abs = rc[i];\r
2558             lc[i] = (sum_abs - sum) * 0.5;\r
2559             rc[i] = (sum_abs + sum) * 0.5;\r
2560         }\r
2561     }\r
2562 \r
2563     // 2. now form the split.\r
2564     // in each category send all the samples to the same direction as majority\r
2565     for( i = 0; i < mi; i++ )\r
2566     {\r
2567         double lval = lc[i], rval = rc[i];\r
2568         if( lval > rval )\r
2569         {\r
2570             split->subset[i >> 5] |= 1 << (i & 31);\r
2571             best_val += lval;\r
2572             l_win++;\r
2573         }\r
2574         else\r
2575             best_val += rval;\r
2576     }\r
2577 \r
2578     split->quality = (float)best_val;\r
2579     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )\r
2580         cvSetRemoveByPtr( data->split_heap, split ), split = 0;\r
2581 \r
2582     return split;\r
2583 }\r
2584 \r
2585 \r
2586 void CvDTree::calc_node_value( CvDTreeNode* node )\r
2587 {\r
2588     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;\r
2589     int* cv_labels_buf = data->cv_lables_buf;\r
2590     const int* cv_labels = 0;\r
2591     data->get_cv_labels(node, cv_labels_buf, &cv_labels);\r
2592 \r
2593     if( data->is_classifier )\r
2594     {\r
2595         // in case of classification tree:\r
2596         //  * node value is the label of the class that has the largest weight in the node.\r
2597         //  * node risk is the weighted number of misclassified samples,\r
2598         //  * j-th cross-validation fold value and risk are calculated as above,\r
2599         //    but using the samples with cv_labels(*)!=j.\r
2600         //  * j-th cross-validation fold error is calculated as the weighted number of\r
2601         //    misclassified samples with cv_labels(*)==j.\r
2602 \r
2603         // compute the number of instances of each class\r
2604         int* cls_count = data->counts->data.i;\r
2605         int* responses_buf = data->resp_int_buf;\r
2606         const int* responses = 0;\r
2607         data->get_class_labels(node, responses_buf, &responses);\r
2608         int m = data->get_num_classes();\r
2609         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));\r
2610         double max_val = -1, total_weight = 0;\r
2611         int max_k = -1;\r
2612         double* priors = data->priors_mult->data.db;\r
2613 \r
2614         for( k = 0; k < m; k++ )\r
2615             cls_count[k] = 0;\r
2616 \r
2617         if( cv_n == 0 )\r
2618         {\r
2619             for( i = 0; i < n; i++ )\r
2620                 cls_count[responses[i]]++;\r
2621         }\r
2622         else\r
2623         {\r
2624             for( j = 0; j < cv_n; j++ )\r
2625                 for( k = 0; k < m; k++ )\r
2626                     cv_cls_count[j*m + k] = 0;\r
2627 \r
2628             for( i = 0; i < n; i++ )\r
2629             {\r
2630                 j = cv_labels[i]; k = responses[i];\r
2631                 cv_cls_count[j*m + k]++;\r
2632             }\r
2633 \r
2634             for( j = 0; j < cv_n; j++ )\r
2635                 for( k = 0; k < m; k++ )\r
2636                     cls_count[k] += cv_cls_count[j*m + k];\r
2637         }\r
2638 \r
2639         if( data->have_priors && node->parent == 0 )\r
2640         {\r
2641             // compute priors_mult from priors, take the sample ratio into account.\r
2642             double sum = 0;\r
2643             for( k = 0; k < m; k++ )\r
2644             {\r
2645                 int n_k = cls_count[k];\r
2646                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);\r
2647                 sum += priors[k];\r
2648             }\r
2649             sum = 1./sum;\r
2650             for( k = 0; k < m; k++ )\r
2651                 priors[k] *= sum;\r
2652         }\r
2653 \r
2654         for( k = 0; k < m; k++ )\r
2655         {\r
2656             double val = cls_count[k]*priors[k];\r
2657             total_weight += val;\r
2658             if( max_val < val )\r
2659             {\r
2660                 max_val = val;\r
2661                 max_k = k;\r
2662             }\r
2663         }\r
2664 \r
2665         node->class_idx = max_k;\r
2666         node->value = data->cat_map->data.i[\r
2667             data->cat_ofs->data.i[data->cat_var_count] + max_k];\r
2668         node->node_risk = total_weight - max_val;\r
2669 \r
2670         for( j = 0; j < cv_n; j++ )\r
2671         {\r
2672             double sum_k = 0, sum = 0, max_val_k = 0;\r
2673             max_val = -1; max_k = -1;\r
2674 \r
2675             for( k = 0; k < m; k++ )\r
2676             {\r
2677                 double w = priors[k];\r
2678                 double val_k = cv_cls_count[j*m + k]*w;\r
2679                 double val = cls_count[k]*w - val_k;\r
2680                 sum_k += val_k;\r
2681                 sum += val;\r
2682                 if( max_val < val )\r
2683                 {\r
2684                     max_val = val;\r
2685                     max_val_k = val_k;\r
2686                     max_k = k;\r
2687                 }\r
2688             }\r
2689 \r
2690             node->cv_Tn[j] = INT_MAX;\r
2691             node->cv_node_risk[j] = sum - max_val;\r
2692             node->cv_node_error[j] = sum_k - max_val_k;\r
2693         }\r
2694     }\r
2695     else\r
2696     {\r
2697         // in case of regression tree:\r
2698         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,\r
2699         //    n is the number of samples in the node.\r
2700         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)\r
2701         //  * j-th cross-validation fold value and risk are calculated as above,\r
2702         //    but using the samples with cv_labels(*)!=j.\r
2703         //  * j-th cross-validation fold error is calculated\r
2704         //    using samples with cv_labels(*)==j as the test subset:\r
2705         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),\r
2706         //    where node_value_j is the node value calculated\r
2707         //    as described in the previous bullet, and summation is done\r
2708         //    over the samples with cv_labels(*)==j.\r
2709 \r
2710         double sum = 0, sum2 = 0;\r
2711         float* values_buf = data->resp_float_buf;\r
2712         const float* values = 0;\r
2713         data->get_ord_responses(node, values_buf, &values);\r
2714         double *cv_sum = 0, *cv_sum2 = 0;\r
2715         int* cv_count = 0;\r
2716 \r
2717         if( cv_n == 0 )\r
2718         {\r
2719             for( i = 0; i < n; i++ )\r
2720             {\r
2721                 double t = values[i];\r
2722                 sum += t;\r
2723                 sum2 += t*t;\r
2724             }\r
2725         }\r
2726         else\r
2727         {\r
2728             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );\r
2729             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );\r
2730             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );\r
2731 \r
2732             for( j = 0; j < cv_n; j++ )\r
2733             {\r
2734                 cv_sum[j] = cv_sum2[j] = 0.;\r
2735                 cv_count[j] = 0;\r
2736             }\r
2737 \r
2738             for( i = 0; i < n; i++ )\r
2739             {\r
2740                 j = cv_labels[i];\r
2741                 double t = values[i];\r
2742                 double s = cv_sum[j] + t;\r
2743                 double s2 = cv_sum2[j] + t*t;\r
2744                 int nc = cv_count[j] + 1;\r
2745                 cv_sum[j] = s;\r
2746                 cv_sum2[j] = s2;\r
2747                 cv_count[j] = nc;\r
2748             }\r
2749 \r
2750             for( j = 0; j < cv_n; j++ )\r
2751             {\r
2752                 sum += cv_sum[j];\r
2753                 sum2 += cv_sum2[j];\r
2754             }\r
2755         }\r
2756 \r
2757         node->node_risk = sum2 - (sum/n)*sum;\r
2758         node->value = sum/n;\r
2759 \r
2760         for( j = 0; j < cv_n; j++ )\r
2761         {\r
2762             double s = cv_sum[j], si = sum - s;\r
2763             double s2 = cv_sum2[j], s2i = sum2 - s2;\r
2764             int c = cv_count[j], ci = n - c;\r
2765             double r = si/MAX(ci,1);\r
2766             node->cv_node_risk[j] = s2i - r*r*ci;\r
2767             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;\r
2768             node->cv_Tn[j] = INT_MAX;\r
2769         }\r
2770     }\r
2771 }\r
2772 \r
2773 \r
2774 void CvDTree::complete_node_dir( CvDTreeNode* node )\r
2775 {\r
2776     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;\r
2777     int nz = n - node->get_num_valid(node->split->var_idx);\r
2778     char* dir = (char*)data->direction->data.ptr;\r
2779 \r
2780     // try to complete direction using surrogate splits\r
2781     if( nz && data->params.use_surrogates )\r
2782     {\r
2783         CvDTreeSplit* split = node->split->next;\r
2784         for( ; split != 0 && nz; split = split->next )\r
2785         {\r
2786             int inversed_mask = split->inversed ? -1 : 0;\r
2787             vi = split->var_idx;\r
2788 \r
2789             if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
2790             {\r
2791                 int* labels_buf = data->pred_int_buf;\r
2792                 const int* labels = 0;\r
2793                 data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2794                 const int* subset = split->subset;\r
2795 \r
2796                 for( i = 0; i < n; i++ )\r
2797                 {\r
2798                     int idx = labels[i];\r
2799                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))\r
2800                         \r
2801                     {\r
2802                         int d = CV_DTREE_CAT_DIR(idx,subset);\r
2803                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
2804                         if( --nz )\r
2805                             break;\r
2806                     }\r
2807                 }\r
2808             }\r
2809             else // split on ordered var\r
2810             {\r
2811                 float* values_buf = data->pred_float_buf;\r
2812                 const float* values = 0;\r
2813                 int* indices_buf = data->pred_int_buf;\r
2814                 const int* indices = 0;\r
2815                 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2816                 int split_point = split->ord.split_point;\r
2817                 int n1 = node->get_num_valid(vi);\r
2818 \r
2819                 assert( 0 <= split_point && split_point < n-1 );\r
2820 \r
2821                 for( i = 0; i < n1; i++ )\r
2822                 {\r
2823                     int idx = indices[i];\r
2824                     if( !dir[idx] )\r
2825                     {\r
2826                         int d = i <= split_point ? -1 : 1;\r
2827                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);\r
2828                         if( --nz )\r
2829                             break;\r
2830                     }\r
2831                 }\r
2832             }\r
2833         }\r
2834     }\r
2835 \r
2836     // find the default direction for the rest\r
2837     if( nz )\r
2838     {\r
2839         for( i = nr = 0; i < n; i++ )\r
2840             nr += dir[i] > 0;\r
2841         nl = n - nr - nz;\r
2842         d0 = nl > nr ? -1 : nr > nl;\r
2843     }\r
2844 \r
2845     // make sure that every sample is directed either to the left or to the right\r
2846     for( i = 0; i < n; i++ )\r
2847     {\r
2848         int d = dir[i];\r
2849         if( !d )\r
2850         {\r
2851             d = d0;\r
2852             if( !d )\r
2853                 d = d1, d1 = -d1;\r
2854         }\r
2855         d = d > 0;\r
2856         dir[i] = (char)d; // remap (-1,1) to (0,1)\r
2857     }\r
2858 }\r
2859 \r
2860 \r
2861 void CvDTree::split_node_data( CvDTreeNode* node )\r
2862 {\r
2863     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;\r
2864     char* dir = (char*)data->direction->data.ptr;\r
2865     CvDTreeNode *left = 0, *right = 0;\r
2866     int* new_idx = data->split_buf->data.i;\r
2867     int new_buf_idx = data->get_child_buf_idx( node );\r
2868     int work_var_count = data->get_work_var_count();\r
2869     CvMat* buf = data->buf;\r
2870     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));\r
2871 \r
2872     complete_node_dir(node);\r
2873 \r
2874     for( i = nl = nr = 0; i < n; i++ )\r
2875     {\r
2876         int d = dir[i];\r
2877         // initialize new indices for splitting ordered variables\r
2878         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li\r
2879         nr += d;\r
2880         nl += d^1;\r
2881     }\r
2882 \r
2883 \r
2884     bool split_input_data;\r
2885     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );\r
2886     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );\r
2887 \r
2888     split_input_data = node->depth + 1 < data->params.max_depth &&\r
2889         (node->left->sample_count > data->params.min_sample_count ||\r
2890         node->right->sample_count > data->params.min_sample_count);\r
2891 \r
2892     // split ordered variables, keep both halves sorted.\r
2893     for( vi = 0; vi < data->var_count; vi++ )\r
2894     {\r
2895         int ci = data->get_var_type(vi);\r
2896         int n1 = node->get_num_valid(vi);\r
2897         int *src_idx_buf = data->pred_int_buf;\r
2898         const int* src_idx = 0;\r
2899         float *src_val_buf = data->pred_float_buf;\r
2900         const float* src_val = 0;\r
2901         \r
2902         if( ci >= 0 || !split_input_data )\r
2903             continue;\r
2904 \r
2905         data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
2906 \r
2907         for(i = 0; i < n; i++)\r
2908             temp_buf[i] = src_idx[i];\r
2909 \r
2910         if (data->is_buf_16u)\r
2911         {\r
2912             unsigned short *ldst, *rdst, *ldst0, *rdst0;\r
2913             //unsigned short tl, tr;\r
2914             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
2915                 vi*scount + left->offset);\r
2916             rdst0 = rdst = (unsigned short*)(ldst + nl);\r
2917 \r
2918             // split sorted\r
2919             for( i = 0; i < n1; i++ )\r
2920             {\r
2921                 int idx = temp_buf[i];\r
2922                 int d = dir[idx];\r
2923                 idx = new_idx[idx];\r
2924                 if (d)\r
2925                 {\r
2926                     *rdst = (unsigned short)idx;\r
2927                     rdst++;\r
2928                 }\r
2929                 else\r
2930                 {\r
2931                     *ldst = (unsigned short)idx;\r
2932                     ldst++;\r
2933                 }\r
2934             }\r
2935 \r
2936             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2937             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2938 \r
2939             // split missing\r
2940             for( ; i < n; i++ )\r
2941             {\r
2942                 int idx = temp_buf[i];\r
2943                 int d = dir[idx];\r
2944                 idx = new_idx[idx];\r
2945                 if (d)\r
2946                 {\r
2947                     *rdst = (unsigned short)idx;\r
2948                     rdst++;\r
2949                 }\r
2950                 else\r
2951                 {\r
2952                     *ldst = (unsigned short)idx;\r
2953                     ldst++;\r
2954                 }\r
2955             }\r
2956         }\r
2957         else\r
2958         {\r
2959             int *ldst0, *ldst, *rdst0, *rdst;\r
2960             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + \r
2961                 vi*scount + left->offset;\r
2962             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + \r
2963                 vi*scount + right->offset;\r
2964 \r
2965             // split sorted\r
2966             for( i = 0; i < n1; i++ )\r
2967             {\r
2968                 int idx = temp_buf[i];\r
2969                 int d = dir[idx];\r
2970                 idx = new_idx[idx];\r
2971                 if (d)\r
2972                 {\r
2973                     *rdst = idx;\r
2974                     rdst++;\r
2975                 }\r
2976                 else\r
2977                 {\r
2978                     *ldst = idx;\r
2979                     ldst++;\r
2980                 }\r
2981             }\r
2982 \r
2983             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2984             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2985 \r
2986             // split missing\r
2987             for( ; i < n; i++ )\r
2988             {\r
2989                 int idx = temp_buf[i];\r
2990                 int d = dir[idx];\r
2991                 idx = new_idx[idx];\r
2992                 if (d)\r
2993                 {\r
2994                     *rdst = idx;\r
2995                     rdst++;\r
2996                 }\r
2997                 else\r
2998                 {\r
2999                     *ldst = idx;\r
3000                     ldst++;\r
3001                 }\r
3002             }\r
3003         }\r
3004     }\r
3005 \r
3006     // split categorical vars, responses and cv_labels using new_idx relocation table\r
3007     for( vi = 0; vi < work_var_count; vi++ )\r
3008     {\r
3009         int ci = data->get_var_type(vi);\r
3010         int n1 = node->get_num_valid(vi), nr1 = 0;\r
3011         \r
3012         if( ci < 0 || (vi < data->var_count && !split_input_data) )\r
3013             continue;\r
3014 \r
3015         int *src_lbls_buf = data->pred_int_buf;\r
3016         const int* src_lbls = 0;\r
3017         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);\r
3018 \r
3019         for(i = 0; i < n; i++)\r
3020             temp_buf[i] = src_lbls[i];\r
3021 \r
3022         if (data->is_buf_16u)\r
3023         {\r
3024             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + \r
3025                 vi*scount + left->offset);\r
3026             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + \r
3027                 vi*scount + right->offset);\r
3028             \r
3029             for( i = 0; i < n; i++ )\r
3030             {\r
3031                 int d = dir[i];\r
3032                 int idx = temp_buf[i];\r
3033                 if (d)\r
3034                 {\r
3035                     *rdst = (unsigned short)idx;\r
3036                     rdst++;\r
3037                     nr1 += (idx != 65535 )&d;\r
3038                 }\r
3039                 else\r
3040                 {\r
3041                     *ldst = (unsigned short)idx;\r
3042                     ldst++;\r
3043                 }\r
3044             }\r
3045 \r
3046             if( vi < data->var_count )\r
3047             {\r
3048                 left->set_num_valid(vi, n1 - nr1);\r
3049                 right->set_num_valid(vi, nr1);\r
3050             }\r
3051         }\r
3052         else\r
3053         {\r
3054             int *ldst = buf->data.i + left->buf_idx*buf->cols + \r
3055                 vi*scount + left->offset;\r
3056             int *rdst = buf->data.i + right->buf_idx*buf->cols + \r
3057                 vi*scount + right->offset;\r
3058             \r
3059             for( i = 0; i < n; i++ )\r
3060             {\r
3061                 int d = dir[i];\r
3062                 int idx = temp_buf[i];\r
3063                 if (d)\r
3064                 {\r
3065                     *rdst = idx;\r
3066                     rdst++;\r
3067                     nr1 += (idx >= 0)&d;\r
3068                 }\r
3069                 else\r
3070                 {\r
3071                     *ldst = idx;\r
3072                     ldst++;\r
3073                 }\r
3074                 \r
3075             }\r
3076 \r
3077             if( vi < data->var_count )\r
3078             {\r
3079                 left->set_num_valid(vi, n1 - nr1);\r
3080                 right->set_num_valid(vi, nr1);\r
3081             }\r
3082         }        \r
3083     }\r
3084 \r
3085 \r
3086     // split sample indices\r
3087     int *sample_idx_src_buf = data->sample_idx_buf;\r
3088     const int* sample_idx_src = 0;\r
3089     data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);\r
3090 \r
3091     for(i = 0; i < n; i++)\r
3092         temp_buf[i] = sample_idx_src[i];\r
3093 \r
3094     int pos = data->get_work_var_count();\r
3095     if (data->is_buf_16u)\r
3096     {\r
3097         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
3098             pos*scount + left->offset);\r
3099         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + \r
3100             pos*scount + right->offset);\r
3101         for (i = 0; i < n; i++)\r
3102         {\r
3103             int d = dir[i];\r
3104             unsigned short idx = (unsigned short)temp_buf[i];\r
3105             if (d)\r
3106             {\r
3107                 *rdst = idx;\r
3108                 rdst++;\r
3109             }\r
3110             else\r
3111             {\r
3112                 *ldst = idx;\r
3113                 ldst++;\r
3114             }\r
3115         }\r
3116     }\r
3117     else\r
3118     {\r
3119         int* ldst = buf->data.i + left->buf_idx*buf->cols + \r
3120             pos*scount + left->offset;\r
3121         int* rdst = buf->data.i + right->buf_idx*buf->cols + \r
3122             pos*scount + right->offset;\r
3123         for (i = 0; i < n; i++)\r
3124         {\r
3125             int d = dir[i];\r
3126             int idx = temp_buf[i];\r
3127             if (d)\r
3128             {\r
3129                 *rdst = idx;\r
3130                 rdst++;\r
3131             }\r
3132             else\r
3133             {\r
3134                 *ldst = idx;\r
3135                 ldst++;\r
3136             }\r
3137         }\r
3138     }\r
3139     \r
3140     // deallocate the parent node data that is not needed anymore\r
3141     data->free_node_data(node);    \r
3142 }\r
3143 \r
3144 float CvDTree::calc_error( CvMLData* _data, int type )\r
3145 {\r
3146     CV_FUNCNAME( "CvDTree::calc_error" );\r
3147 \r
3148     __BEGIN__;\r
3149     float err = 0;\r
3150     const CvMat* values = _data->get_values();
3151     const CvMat* response = _data->get_response();
3152     const CvMat* missing = _data->get_missing();
3153     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
3154     const CvMat* var_types = _data->get_var_types();\r
3155     int* sidx = sample_idx->data.i;\r
3156     int r_step = CV_IS_MAT_CONT(response->type) ?\r
3157                 1 : response->step / CV_ELEM_SIZE(response->type);\r
3158     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;\r
3159     if ( is_classifier )\r
3160     {\r
3161         for( int i = 0; i < sample_idx->cols; i++ )\r
3162         {\r
3163             CvMat sample, miss;\r
3164             int si = sidx[i];\r
3165             cvGetRow( values, &sample, si ); \r
3166             if( missing ) \r
3167                 cvGetRow( missing, &miss, si );             \r
3168             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3169             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
3170             err += d;\r
3171         }\r
3172         err = err / (float)sample_idx->cols * 100;\r
3173     }\r
3174     else\r
3175     {\r
3176         for( int i = 0; i < sample_idx->cols; i++ )\r
3177         {\r
3178             CvMat sample, miss;\r
3179             int si = sidx[i];\r
3180             cvGetRow( data, &sample, sidx[i] ); \r
3181             if( missing ) \r
3182                 cvGetRow( missing, &miss, si );             \r
3183             float r = (float)predict( &sample, missing ? &miss : 0 )->value;\r
3184             float d = r - response->data.fl[si];\r
3185             err += d*d;\r
3186         }\r
3187         err = err / (float)sample_idx->cols;    \r
3188     }\r
3189     return err;\r
3190 \r
3191     __END__;\r
3192 }\r
3193 \r
3194 void CvDTree::prune_cv()\r
3195 {\r
3196     CvMat* ab = 0;\r
3197     CvMat* temp = 0;\r
3198     CvMat* err_jk = 0;\r
3199 \r
3200     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.\r
3201     // 2. choose the best tree index (if need, apply 1SE rule).\r
3202     // 3. store the best index and cut the branches.\r
3203 \r
3204     CV_FUNCNAME( "CvDTree::prune_cv" );\r
3205 \r
3206     __BEGIN__;\r
3207 \r
3208     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;\r
3209     // currently, 1SE for regression is not implemented\r
3210     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;\r
3211     double* err;\r
3212     double min_err = 0, min_err_se = 0;\r
3213     int min_idx = -1;\r
3214 \r
3215     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));\r
3216 \r
3217     // build the main tree sequence, calculate alpha's\r
3218     for(;;tree_count++)\r
3219     {\r
3220         double min_alpha = update_tree_rnc(tree_count, -1);\r
3221         if( cut_tree(tree_count, -1, min_alpha) )\r
3222             break;\r
3223 \r
3224         if( ab->cols <= tree_count )\r
3225         {\r
3226             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));\r
3227             for( ti = 0; ti < ab->cols; ti++ )\r
3228                 temp->data.db[ti] = ab->data.db[ti];\r
3229             cvReleaseMat( &ab );\r
3230             ab = temp;\r
3231             temp = 0;\r
3232         }\r
3233 \r
3234         ab->data.db[tree_count] = min_alpha;\r
3235     }\r
3236 \r
3237     ab->data.db[0] = 0.;\r
3238 \r
3239     if( tree_count > 0 )\r
3240     {\r
3241         for( ti = 1; ti < tree_count-1; ti++ )\r
3242             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);\r
3243         ab->data.db[tree_count-1] = DBL_MAX*0.5;\r
3244 \r
3245         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));\r
3246         err = err_jk->data.db;\r
3247 \r
3248         for( j = 0; j < cv_n; j++ )\r
3249         {\r
3250             int tj = 0, tk = 0;\r
3251             for( ; tk < tree_count; tj++ )\r
3252             {\r
3253                 double min_alpha = update_tree_rnc(tj, j);\r
3254                 if( cut_tree(tj, j, min_alpha) )\r
3255                     min_alpha = DBL_MAX;\r
3256 \r
3257                 for( ; tk < tree_count; tk++ )\r
3258                 {\r
3259                     if( ab->data.db[tk] > min_alpha )\r
3260                         break;\r
3261                     err[j*tree_count + tk] = root->tree_error;\r
3262                 }\r
3263             }\r
3264         }\r
3265 \r
3266         for( ti = 0; ti < tree_count; ti++ )\r
3267         {\r
3268             double sum_err = 0;\r
3269             for( j = 0; j < cv_n; j++ )\r
3270                 sum_err += err[j*tree_count + ti];\r
3271             if( ti == 0 || sum_err < min_err )\r
3272             {\r
3273                 min_err = sum_err;\r
3274                 min_idx = ti;\r
3275                 if( use_1se )\r
3276                     min_err_se = sqrt( sum_err*(n - sum_err) );\r
3277             }\r
3278             else if( sum_err < min_err + min_err_se )\r
3279                 min_idx = ti;\r
3280         }\r
3281     }\r
3282 \r
3283     pruned_tree_idx = min_idx;\r
3284     free_prune_data(data->params.truncate_pruned_tree != 0);\r
3285 \r
3286     __END__;\r
3287 \r
3288     cvReleaseMat( &err_jk );\r
3289     cvReleaseMat( &ab );\r
3290     cvReleaseMat( &temp );\r
3291 }\r
3292 \r
3293 \r
3294 double CvDTree::update_tree_rnc( int T, int fold )\r
3295 {\r
3296     CvDTreeNode* node = root;\r
3297     double min_alpha = DBL_MAX;\r
3298 \r
3299     for(;;)\r
3300     {\r
3301         CvDTreeNode* parent;\r
3302         for(;;)\r
3303         {\r
3304             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3305             if( t <= T || !node->left )\r
3306             {\r
3307                 node->complexity = 1;\r
3308                 node->tree_risk = node->node_risk;\r
3309                 node->tree_error = 0.;\r
3310                 if( fold >= 0 )\r
3311                 {\r
3312                     node->tree_risk = node->cv_node_risk[fold];\r
3313                     node->tree_error = node->cv_node_error[fold];\r
3314                 }\r
3315                 break;\r
3316             }\r
3317             node = node->left;\r
3318         }\r
3319 \r
3320         for( parent = node->parent; parent && parent->right == node;\r
3321             node = parent, parent = parent->parent )\r
3322         {\r
3323             parent->complexity += node->complexity;\r
3324             parent->tree_risk += node->tree_risk;\r
3325             parent->tree_error += node->tree_error;\r
3326 \r
3327             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)\r
3328                 - parent->tree_risk)/(parent->complexity - 1);\r
3329             min_alpha = MIN( min_alpha, parent->alpha );\r
3330         }\r
3331 \r
3332         if( !parent )\r
3333             break;\r
3334 \r
3335         parent->complexity = node->complexity;\r
3336         parent->tree_risk = node->tree_risk;\r
3337         parent->tree_error = node->tree_error;\r
3338         node = parent->right;\r
3339     }\r
3340 \r
3341     return min_alpha;\r
3342 }\r
3343 \r
3344 \r
3345 int CvDTree::cut_tree( int T, int fold, double min_alpha )\r
3346 {\r
3347     CvDTreeNode* node = root;\r
3348     if( !node->left )\r
3349         return 1;\r
3350 \r
3351     for(;;)\r
3352     {\r
3353         CvDTreeNode* parent;\r
3354         for(;;)\r
3355         {\r
3356             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3357             if( t <= T || !node->left )\r
3358                 break;\r
3359             if( node->alpha <= min_alpha + FLT_EPSILON )\r
3360             {\r
3361                 if( fold >= 0 )\r
3362                     node->cv_Tn[fold] = T;\r
3363                 else\r
3364                     node->Tn = T;\r
3365                 if( node == root )\r
3366                     return 1;\r
3367                 break;\r
3368             }\r
3369             node = node->left;\r
3370         }\r
3371 \r
3372         for( parent = node->parent; parent && parent->right == node;\r
3373             node = parent, parent = parent->parent )\r
3374             ;\r
3375 \r
3376         if( !parent )\r
3377             break;\r
3378 \r
3379         node = parent->right;\r
3380     }\r
3381 \r
3382     return 0;\r
3383 }\r
3384 \r
3385 \r
3386 void CvDTree::free_prune_data(bool cut_tree)\r
3387 {\r
3388     CvDTreeNode* node = root;\r
3389 \r
3390     for(;;)\r
3391     {\r
3392         CvDTreeNode* parent;\r
3393         for(;;)\r
3394         {\r
3395             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )\r
3396             // as we will clear the whole cross-validation heap at the end\r
3397             node->cv_Tn = 0;\r
3398             node->cv_node_error = node->cv_node_risk = 0;\r
3399             if( !node->left )\r
3400                 break;\r
3401             node = node->left;\r
3402         }\r
3403 \r
3404         for( parent = node->parent; parent && parent->right == node;\r
3405             node = parent, parent = parent->parent )\r
3406         {\r
3407             if( cut_tree && parent->Tn <= pruned_tree_idx )\r
3408             {\r
3409                 data->free_node( parent->left );\r
3410                 data->free_node( parent->right );\r
3411                 parent->left = parent->right = 0;\r
3412             }\r
3413         }\r
3414 \r
3415         if( !parent )\r
3416             break;\r
3417 \r
3418         node = parent->right;\r
3419     }\r
3420 \r
3421     if( data->cv_heap )\r
3422         cvClearSet( data->cv_heap );\r
3423 }\r
3424 \r
3425 \r
3426 void CvDTree::free_tree()\r
3427 {\r
3428     if( root && data && data->shared )\r
3429     {\r
3430         pruned_tree_idx = INT_MIN;\r
3431         free_prune_data(true);\r
3432         data->free_node(root);\r
3433         root = 0;\r
3434     }\r
3435 }\r
3436 \r
3437 CvDTreeNode* CvDTree::predict( const CvMat* _sample,\r
3438     const CvMat* _missing, bool preprocessed_input ) const\r
3439 {\r
3440     CvDTreeNode* result = 0;\r
3441     int* catbuf = 0;\r
3442 \r
3443     CV_FUNCNAME( "CvDTree::predict" );\r
3444 \r
3445     __BEGIN__;\r
3446 \r
3447     int i, step, mstep = 0;\r
3448     const float* sample;\r
3449     const uchar* m = 0;\r
3450     CvDTreeNode* node = root;\r
3451     const int* vtype;\r
3452     const int* vidx;\r
3453     const int* cmap;\r
3454     const int* cofs;\r
3455 \r
3456     if( !node )\r
3457         CV_ERROR( CV_StsError, "The tree has not been trained yet" );\r
3458 \r
3459     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||\r
3460         (_sample->cols != 1 && _sample->rows != 1) ||\r
3461         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||\r
3462         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )\r
3463             CV_ERROR( CV_StsBadArg,\r
3464         "the input sample must be 1d floating-point vector with the same "\r
3465         "number of elements as the total number of variables used for training" );\r
3466 \r
3467     sample = _sample->data.fl;\r
3468     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);\r
3469 \r
3470     if( data->cat_count && !preprocessed_input ) // cache for categorical variables\r
3471     {\r
3472         int n = data->cat_count->cols;\r
3473         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));\r
3474         for( i = 0; i < n; i++ )\r
3475             catbuf[i] = -1;\r
3476     }\r
3477 \r
3478     if( _missing )\r
3479     {\r
3480         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||\r
3481         !CV_ARE_SIZES_EQ(_missing, _sample) )\r
3482             CV_ERROR( CV_StsBadArg,\r
3483         "the missing data mask must be 8-bit vector of the same size as input sample" );\r
3484         m = _missing->data.ptr;\r
3485         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);\r
3486     }\r
3487 \r
3488     vtype = data->var_type->data.i;\r
3489     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;\r
3490     cmap = data->cat_map ? data->cat_map->data.i : 0;\r
3491     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;\r
3492 \r
3493     while( node->Tn > pruned_tree_idx && node->left )\r
3494     {\r
3495         CvDTreeSplit* split = node->split;\r
3496         int dir = 0;\r
3497         for( ; !dir && split != 0; split = split->next )\r
3498         {\r
3499             int vi = split->var_idx;\r
3500             int ci = vtype[vi];\r
3501             i = vidx ? vidx[vi] : vi;\r
3502             float val = sample[i*step];\r
3503             if( m && m[i*mstep] )\r
3504                 continue;\r
3505             if( ci < 0 ) // ordered\r
3506                 dir = val <= split->ord.c ? -1 : 1;\r
3507             else // categorical\r
3508             {\r
3509                 int c;\r
3510                 if( preprocessed_input )\r
3511                     c = cvRound(val);\r
3512                 else\r
3513                 {\r
3514                     c = catbuf[ci];\r
3515                     if( c < 0 )\r
3516                     {\r
3517                         int a = c = cofs[ci];\r
3518                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];\r
3519                         \r
3520                         int ival = cvRound(val);\r
3521                         if( ival != val )\r
3522                             CV_ERROR( CV_StsBadArg,\r
3523                             "one of input categorical variable is not an integer" );\r
3524                         \r
3525                         int sh = 0;\r
3526                         while( a < b )\r
3527                         {\r
3528                             sh++;\r
3529                             c = (a + b) >> 1;\r
3530                             if( ival < cmap[c] )\r
3531                                 b = c;\r
3532                             else if( ival > cmap[c] )\r
3533                                 a = c+1;\r
3534                             else\r
3535                                 break;\r
3536                         }\r
3537 \r
3538                         if( c < 0 || ival != cmap[c] )\r
3539                             continue;\r
3540 \r
3541                         catbuf[ci] = c -= cofs[ci];\r
3542                     }\r
3543                 }\r
3544                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;\r
3545                 dir = CV_DTREE_CAT_DIR(c, split->subset);\r
3546             }\r
3547 \r
3548             if( split->inversed )\r
3549                 dir = -dir;\r
3550         }\r
3551 \r
3552         if( !dir )\r
3553         {\r
3554             double diff = node->right->sample_count - node->left->sample_count;\r
3555             dir = diff < 0 ? -1 : 1;\r
3556         }\r
3557         node = dir < 0 ? node->left : node->right;\r
3558     }\r
3559 \r
3560     result = node;\r
3561 \r
3562     __END__;\r
3563 \r
3564     return result;\r
3565 }\r
3566 \r
3567 \r
3568 const CvMat* CvDTree::get_var_importance()\r
3569 {\r
3570     if( !var_importance )\r
3571     {\r
3572         CvDTreeNode* node = root;\r
3573         double* importance;\r
3574         if( !node )\r
3575             return 0;\r
3576         var_importance = cvCreateMat( 1, data->var_count, CV_64F );\r
3577         cvZero( var_importance );\r
3578         importance = var_importance->data.db;\r
3579 \r
3580         for(;;)\r
3581         {\r
3582             CvDTreeNode* parent;\r
3583             for( ;; node = node->left )\r
3584             {\r
3585                 CvDTreeSplit* split = node->split;\r
3586 \r
3587                 if( !node->left || node->Tn <= pruned_tree_idx )\r
3588                     break;\r
3589 \r
3590                 for( ; split != 0; split = split->next )\r
3591                     importance[split->var_idx] += split->quality;\r
3592             }\r
3593 \r
3594             for( parent = node->parent; parent && parent->right == node;\r
3595                 node = parent, parent = parent->parent )\r
3596                 ;\r
3597 \r
3598             if( !parent )\r
3599                 break;\r
3600 \r
3601             node = parent->right;\r
3602         }\r
3603 \r
3604         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );\r
3605     }\r
3606 \r
3607     return var_importance;\r
3608 }\r
3609 \r
3610 \r
3611 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )\r
3612 {\r
3613     int ci;\r
3614 \r
3615     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );\r
3616     cvWriteInt( fs, "var", split->var_idx );\r
3617     cvWriteReal( fs, "quality", split->quality );\r
3618 \r
3619     ci = data->get_var_type(split->var_idx);\r
3620     if( ci >= 0 ) // split on a categorical var\r
3621     {\r
3622         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;\r
3623         for( i = 0; i < n; i++ )\r
3624             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;\r
3625 \r
3626         // ad-hoc rule when to use inverse categorical split notation\r
3627         // to achieve more compact and clear representation\r
3628         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;\r
3629 \r
3630         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?\r
3631                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );\r
3632 \r
3633         for( i = 0; i < n; i++ )\r
3634         {\r
3635             int dir = CV_DTREE_CAT_DIR(i,split->subset);\r
3636             if( dir*default_dir < 0 )\r
3637                 cvWriteInt( fs, 0, i );\r
3638         }\r
3639         cvEndWriteStruct( fs );\r
3640     }\r
3641     else\r
3642         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );\r
3643 \r
3644     cvEndWriteStruct( fs );\r
3645 }\r
3646 \r
3647 \r
3648 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )\r
3649 {\r
3650     CvDTreeSplit* split;\r
3651 \r
3652     cvStartWriteStruct( fs, 0, CV_NODE_MAP );\r
3653 \r
3654     cvWriteInt( fs, "depth", node->depth );\r
3655     cvWriteInt( fs, "sample_count", node->sample_count );\r
3656     cvWriteReal( fs, "value", node->value );\r
3657 \r
3658     if( data->is_classifier )\r
3659         cvWriteInt( fs, "norm_class_idx", node->class_idx );\r
3660 \r
3661     cvWriteInt( fs, "Tn", node->Tn );\r
3662     cvWriteInt( fs, "complexity", node->complexity );\r
3663     cvWriteReal( fs, "alpha", node->alpha );\r
3664     cvWriteReal( fs, "node_risk", node->node_risk );\r
3665     cvWriteReal( fs, "tree_risk", node->tree_risk );\r
3666     cvWriteReal( fs, "tree_error", node->tree_error );\r
3667 \r
3668     if( node->left )\r
3669     {\r
3670         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );\r
3671 \r
3672         for( split = node->split; split != 0; split = split->next )\r
3673             write_split( fs, split );\r
3674 \r
3675         cvEndWriteStruct( fs );\r
3676     }\r
3677 \r
3678     cvEndWriteStruct( fs );\r
3679 }\r
3680 \r
3681 \r
3682 void CvDTree::write_tree_nodes( CvFileStorage* fs )\r
3683 {\r
3684     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );\r
3685 \r
3686     __BEGIN__;\r
3687 \r
3688     CvDTreeNode* node = root;\r
3689 \r
3690     // traverse the tree and save all the nodes in depth-first order\r
3691     for(;;)\r
3692     {\r
3693         CvDTreeNode* parent;\r
3694         for(;;)\r
3695         {\r
3696             write_node( fs, node );\r
3697             if( !node->left )\r
3698                 break;\r
3699             node = node->left;\r
3700         }\r
3701 \r
3702         for( parent = node->parent; parent && parent->right == node;\r
3703             node = parent, parent = parent->parent )\r
3704             ;\r
3705 \r
3706         if( !parent )\r
3707             break;\r
3708 \r
3709         node = parent->right;\r
3710     }\r
3711 \r
3712     __END__;\r
3713 }\r
3714 \r
3715 \r
3716 void CvDTree::write( CvFileStorage* fs, const char* name )\r
3717 {\r
3718     //CV_FUNCNAME( "CvDTree::write" );\r
3719 \r
3720     __BEGIN__;\r
3721 \r
3722     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );\r
3723 \r
3724     get_var_importance();\r
3725     data->write_params( fs );\r
3726     if( var_importance )\r
3727         cvWrite( fs, "var_importance", var_importance );\r
3728     write( fs );\r
3729 \r
3730     cvEndWriteStruct( fs );\r
3731 \r
3732     __END__;\r
3733 }\r
3734 \r
3735 \r
3736 void CvDTree::write( CvFileStorage* fs )\r
3737 {\r
3738     //CV_FUNCNAME( "CvDTree::write" );\r
3739 \r
3740     __BEGIN__;\r
3741 \r
3742     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );\r
3743 \r
3744     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );\r
3745     write_tree_nodes( fs );\r
3746     cvEndWriteStruct( fs );\r
3747 \r
3748     __END__;\r
3749 }\r
3750 \r
3751 \r
3752 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )\r
3753 {\r
3754     CvDTreeSplit* split = 0;\r
3755 \r
3756     CV_FUNCNAME( "CvDTree::read_split" );\r
3757 \r
3758     __BEGIN__;\r
3759 \r
3760     int vi, ci;\r
3761 \r
3762     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3763         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );\r
3764 \r
3765     vi = cvReadIntByName( fs, fnode, "var", -1 );\r
3766     if( (unsigned)vi >= (unsigned)data->var_count )\r
3767         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );\r
3768 \r
3769     ci = data->get_var_type(vi);\r
3770     if( ci >= 0 ) // split on categorical var\r
3771     {\r
3772         int i, n = data->cat_count->data.i[ci], inversed = 0, val;\r
3773         CvSeqReader reader;\r
3774         CvFileNode* inseq;\r
3775         split = data->new_split_cat( vi, 0 );\r
3776         inseq = cvGetFileNodeByName( fs, fnode, "in" );\r
3777         if( !inseq )\r
3778         {\r
3779             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );\r
3780             inversed = 1;\r
3781         }\r
3782         if( !inseq ||\r
3783             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))\r
3784             CV_ERROR( CV_StsParseError,\r
3785             "Either 'in' or 'not_in' tags should be inside a categorical split data" );\r
3786 \r
3787         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )\r
3788         {\r
3789             val = inseq->data.i;\r
3790             if( (unsigned)val >= (unsigned)n )\r
3791                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3792 \r
3793             split->subset[val >> 5] |= 1 << (val & 31);\r
3794         }\r
3795         else\r
3796         {\r
3797             cvStartReadSeq( inseq->data.seq, &reader );\r
3798 \r
3799             for( i = 0; i < reader.seq->total; i++ )\r
3800             {\r
3801                 CvFileNode* inode = (CvFileNode*)reader.ptr;\r
3802                 val = inode->data.i;\r
3803                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )\r
3804                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3805 \r
3806                 split->subset[val >> 5] |= 1 << (val & 31);\r
3807                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3808             }\r
3809         }\r
3810 \r
3811         // for categorical splits we do not use inversed splits,\r
3812         // instead we inverse the variable set in the split\r
3813         if( inversed )\r
3814             for( i = 0; i < (n + 31) >> 5; i++ )\r
3815                 split->subset[i] ^= -1;\r
3816     }\r
3817     else\r
3818     {\r
3819         CvFileNode* cmp_node;\r
3820         split = data->new_split_ord( vi, 0, 0, 0, 0 );\r
3821 \r
3822         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );\r
3823         if( !cmp_node )\r
3824         {\r
3825             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );\r
3826             split->inversed = 1;\r
3827         }\r
3828 \r
3829         split->ord.c = (float)cvReadReal( cmp_node );\r
3830     }\r
3831 \r
3832     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );\r
3833 \r
3834     __END__;\r
3835 \r
3836     return split;\r
3837 }\r
3838 \r
3839 \r
3840 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )\r
3841 {\r
3842     CvDTreeNode* node = 0;\r
3843 \r
3844     CV_FUNCNAME( "CvDTree::read_node" );\r
3845 \r
3846     __BEGIN__;\r
3847 \r
3848     CvFileNode* splits;\r
3849     int i, depth;\r
3850 \r
3851     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3852         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );\r
3853 \r
3854     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));\r
3855     depth = cvReadIntByName( fs, fnode, "depth", -1 );\r
3856     if( depth != node->depth )\r
3857         CV_ERROR( CV_StsParseError, "incorrect node depth" );\r
3858 \r
3859     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );\r
3860     node->value = cvReadRealByName( fs, fnode, "value" );\r
3861     if( data->is_classifier )\r
3862         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );\r
3863 \r
3864     node->Tn = cvReadIntByName( fs, fnode, "Tn" );\r
3865     node->complexity = cvReadIntByName( fs, fnode, "complexity" );\r
3866     node->alpha = cvReadRealByName( fs, fnode, "alpha" );\r
3867     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );\r
3868     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );\r
3869     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );\r
3870 \r
3871     splits = cvGetFileNodeByName( fs, fnode, "splits" );\r
3872     if( splits )\r
3873     {\r
3874         CvSeqReader reader;\r
3875         CvDTreeSplit* last_split = 0;\r
3876 \r
3877         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )\r
3878             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );\r
3879 \r
3880         cvStartReadSeq( splits->data.seq, &reader );\r
3881         for( i = 0; i < reader.seq->total; i++ )\r
3882         {\r
3883             CvDTreeSplit* split;\r
3884             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));\r
3885             if( !last_split )\r
3886                 node->split = last_split = split;\r
3887             else\r
3888                 last_split = last_split->next = split;\r
3889 \r
3890             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3891         }\r
3892     }\r
3893 \r
3894     __END__;\r
3895 \r
3896     return node;\r
3897 }\r
3898 \r
3899 \r
3900 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )\r
3901 {\r
3902     CV_FUNCNAME( "CvDTree::read_tree_nodes" );\r
3903 \r
3904     __BEGIN__;\r
3905 \r
3906     CvSeqReader reader;\r
3907     CvDTreeNode _root;\r
3908     CvDTreeNode* parent = &_root;\r
3909     int i;\r
3910     parent->left = parent->right = parent->parent = 0;\r
3911 \r
3912     cvStartReadSeq( fnode->data.seq, &reader );\r
3913 \r
3914     for( i = 0; i < reader.seq->total; i++ )\r
3915     {\r
3916         CvDTreeNode* node;\r
3917 \r
3918         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));\r
3919         if( !parent->left )\r
3920             parent->left = node;\r
3921         else\r
3922             parent->right = node;\r
3923         if( node->split )\r
3924             parent = node;\r
3925         else\r
3926         {\r
3927             while( parent && parent->right )\r
3928                 parent = parent->parent;\r
3929         }\r
3930 \r
3931         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3932     }\r
3933 \r
3934     root = _root.left;\r
3935 \r
3936     __END__;\r
3937 }\r
3938 \r
3939 \r
3940 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )\r
3941 {\r
3942     CvDTreeTrainData* _data = new CvDTreeTrainData();\r
3943     _data->read_params( fs, fnode );\r
3944 \r
3945     read( fs, fnode, _data );\r
3946     get_var_importance();\r
3947 }\r
3948 \r
3949 \r
3950 // a special entry point for reading weak decision trees from the tree ensembles\r
3951 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )\r
3952 {\r
3953     CV_FUNCNAME( "CvDTree::read" );\r
3954 \r
3955     __BEGIN__;\r
3956 \r
3957     CvFileNode* tree_nodes;\r
3958 \r
3959     clear();\r
3960     data = _data;\r
3961 \r
3962     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );\r
3963     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )\r
3964         CV_ERROR( CV_StsParseError, "nodes tag is missing" );\r
3965 \r
3966     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );\r
3967     read_tree_nodes( fs, tree_nodes );\r
3968 \r
3969     __END__;\r
3970 }\r
3971 \r
3972 /* End of file. */\r