]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mltree.cpp
Reducing memory footprint in decision tree engine. Remake of CvERTrees.
[opencv.git] / opencv / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////\r
2 //\r
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
4 //\r
5 //  By downloading, copying, installing or using the software you agree to this license.\r
6 //  If you do not agree to this license, do not download, install,\r
7 //  copy or use the software.\r
8 //\r
9 //\r
10 //                        Intel License Agreement\r
11 //\r
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.\r
13 // Third party copyrights are property of their respective owners.\r
14 //\r
15 // Redistribution and use in source and binary forms, with or without modification,\r
16 // are permitted provided that the following conditions are met:\r
17 //\r
18 //   * Redistribution's of source code must retain the above copyright notice,\r
19 //     this list of conditions and the following disclaimer.\r
20 //\r
21 //   * Redistribution's in binary form must reproduce the above copyright notice,\r
22 //     this list of conditions and the following disclaimer in the documentation\r
23 //     and/or other materials provided with the distribution.\r
24 //\r
25 //   * The name of Intel Corporation may not be used to endorse or promote products\r
26 //     derived from this software without specific prior written permission.\r
27 //\r
28 // This software is provided by the copyright holders and contributors "as is" and\r
29 // any express or implied warranties, including, but not limited to, the implied\r
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.\r
31 // In no event shall the Intel Corporation or contributors be liable for any direct,\r
32 // indirect, incidental, special, exemplary, or consequential damages\r
33 // (including, but not limited to, procurement of substitute goods or services;\r
34 // loss of use, data, or profits; or business interruption) however caused\r
35 // and on any theory of liability, whether in contract, strict liability,\r
36 // or tort (including negligence or otherwise) arising in any way out of\r
37 // the use of this software, even if advised of the possibility of such damage.\r
38 //\r
39 //M*/\r
40 \r
41 #include "_ml.h"\r
42 \r
43 static const float ord_nan = FLT_MAX*0.5f;\r
44 static const int min_block_size = 1 << 16;\r
45 static const int block_size_delta = 1 << 10;\r
46 \r
47 CvDTreeTrainData::CvDTreeTrainData()\r
48 {\r
49     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
50         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
51     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
52     pred_float_buf = resp_float_buf = 0;\r
53     tree_storage = temp_storage = 0;\r
54 \r
55     clear();\r
56 }\r
57 \r
58 \r
59 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,\r
60                       const CvMat* _responses, const CvMat* _var_idx,\r
61                       const CvMat* _sample_idx, const CvMat* _var_type,\r
62                       const CvMat* _missing_mask, const CvDTreeParams& _params,\r
63                       bool _shared, bool _add_labels )\r
64 {\r
65     var_idx = var_type = cat_count = cat_ofs = cat_map =\r
66         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;\r
67 \r
68     pred_int_buf = resp_int_buf = cv_lables_buf = sample_idx_buf = 0;\r
69     pred_float_buf = resp_float_buf = 0;\r
70 \r
71     tree_storage = temp_storage = 0;\r
72 \r
73     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,\r
74               _var_type, _missing_mask, _params, _shared, _add_labels );\r
75 }\r
76 \r
77 \r
78 CvDTreeTrainData::~CvDTreeTrainData()\r
79 {\r
80     clear();\r
81 }\r
82 \r
83 \r
84 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )\r
85 {\r
86     bool ok = false;\r
87 \r
88     CV_FUNCNAME( "CvDTreeTrainData::set_params" );\r
89 \r
90     __BEGIN__;\r
91 \r
92     // set parameters\r
93     params = _params;\r
94 \r
95     if( params.max_categories < 2 )\r
96         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );\r
97     params.max_categories = MIN( params.max_categories, 15 );\r
98 \r
99     if( params.max_depth < 0 )\r
100         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );\r
101     params.max_depth = MIN( params.max_depth, 25 );\r
102 \r
103     params.min_sample_count = MAX(params.min_sample_count,1);\r
104 \r
105     if( params.cv_folds < 0 )\r
106         CV_ERROR( CV_StsOutOfRange,\r
107         "params.cv_folds should be =0 (the tree is not pruned) "\r
108         "or n>0 (tree is pruned using n-fold cross-validation)" );\r
109 \r
110     if( params.cv_folds == 1 )\r
111         params.cv_folds = 0;\r
112 \r
113     if( params.regression_accuracy < 0 )\r
114         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );\r
115 \r
116     ok = true;\r
117 \r
118     __END__;\r
119 \r
120     return ok;\r
121 }\r
122 \r
123 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))\r
124 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )\r
125 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )\r
126 \r
127 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])\r
128 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )\r
129 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )\r
130 \r
131 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))\r
132 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )\r
133 \r
134 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
135     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
136     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
137     bool _shared, bool _add_labels, bool _update_data )\r
138 {\r
139     CvMat* sample_indices = 0;\r
140     CvMat* var_type0 = 0;\r
141     CvMat* tmp_map = 0;\r
142     int** int_ptr = 0;\r
143     CvPair16u32s* pair16u32s_ptr = 0;\r
144     CvDTreeTrainData* data = 0;\r
145     float *_fdst = 0;\r
146     int *_idst = 0;\r
147     unsigned short* udst = 0;\r
148     int* idst = 0;\r
149 \r
150     CV_FUNCNAME( "CvDTreeTrainData::set_data" );\r
151 \r
152     __BEGIN__;\r
153 \r
154     int sample_all = 0, r_type = 0, cv_n;\r
155     int total_c_count = 0;\r
156     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
157     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step\r
158     int vi, i, size;\r
159     char err[100];\r
160     const int *sidx = 0, *vidx = 0;\r
161     \r
162     \r
163     if( _update_data && data_root )\r
164     {\r
165         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,\r
166             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );\r
167 \r
168         // compare new and old train data\r
169         if( !(data->var_count == var_count &&\r
170             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
171             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&\r
172             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )\r
173             CV_ERROR( CV_StsBadArg,\r
174             "The new training data must have the same types and the input and output variables "\r
175             "and the same categories for categorical variables" );\r
176 \r
177         cvReleaseMat( &priors );\r
178         cvReleaseMat( &priors_mult );\r
179         cvReleaseMat( &buf );\r
180         cvReleaseMat( &direction );\r
181         cvReleaseMat( &split_buf );\r
182         cvReleaseMemStorage( &temp_storage );\r
183 \r
184         priors = data->priors; data->priors = 0;\r
185         priors_mult = data->priors_mult; data->priors_mult = 0;\r
186         buf = data->buf; data->buf = 0;\r
187         buf_count = data->buf_count; buf_size = data->buf_size;\r
188         sample_count = data->sample_count;\r
189 \r
190         direction = data->direction; data->direction = 0;\r
191         split_buf = data->split_buf; data->split_buf = 0;\r
192         temp_storage = data->temp_storage; data->temp_storage = 0;\r
193         nv_heap = data->nv_heap; cv_heap = data->cv_heap;\r
194 \r
195         data_root = new_node( 0, sample_count, 0, 0 );\r
196         EXIT;\r
197     }\r
198 \r
199     clear();\r
200 \r
201     var_all = 0;\r
202     rng = cvRNG(-1);\r
203 \r
204     CV_CALL( set_params( _params ));\r
205 \r
206     // check parameter types and sizes\r
207     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));\r
208 \r
209     train_data = _train_data;\r
210     responses = _responses;\r
211 \r
212     if( _tflag == CV_ROW_SAMPLE )\r
213     {\r
214         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
215         dv_step = 1;\r
216         if( _missing_mask )\r
217             ms_step = _missing_mask->step, mv_step = 1;\r
218     }\r
219     else\r
220     {\r
221         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);\r
222         ds_step = 1;\r
223         if( _missing_mask )\r
224             mv_step = _missing_mask->step, ms_step = 1;\r
225     }\r
226     tflag = _tflag;\r
227 \r
228     sample_count = sample_all;\r
229     var_count = var_all;\r
230     is_buf_16u = false;\r
231     if (_train_data->rows + _train_data->cols -1 < 65536) \r
232         is_buf_16u = true;                                \r
233     \r
234     if( _sample_idx )\r
235     {\r
236         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));\r
237         sidx = sample_indices->data.i;\r
238         sample_count = sample_indices->rows + sample_indices->cols - 1;\r
239     }\r
240 \r
241     if( _var_idx )\r
242     {\r
243         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));\r
244         vidx = var_idx->data.i;\r
245         var_count = var_idx->rows + var_idx->cols - 1;\r
246     }\r
247 \r
248     if( !CV_IS_MAT(_responses) ||\r
249         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
250          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
251         (_responses->rows != 1 && _responses->cols != 1) ||\r
252         _responses->rows + _responses->cols - 1 != sample_all )\r
253         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
254                   "floating-point vector containing as many elements as "\r
255                   "the total number of samples in the training data matrix" );\r
256    \r
257   \r
258     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));\r
259 \r
260     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));\r
261    \r
262     \r
263     cat_var_count = 0;\r
264     ord_var_count = -1;\r
265 \r
266     is_classifier = r_type == CV_VAR_CATEGORICAL;\r
267 \r
268     // step 0. calc the number of categorical vars\r
269     for( vi = 0; vi < var_count; vi++ )\r
270     {\r
271         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?\r
272             cat_var_count++ : ord_var_count--;\r
273     }\r
274 \r
275     ord_var_count = ~ord_var_count;\r
276     cv_n = params.cv_folds;\r
277     // set the two last elements of var_type array to be able\r
278     // to locate responses and cross-validation labels using\r
279     // the corresponding get_* functions.\r
280     var_type->data.i[var_count] = cat_var_count;\r
281     var_type->data.i[var_count+1] = cat_var_count+1;\r
282 \r
283     // in case of single ordered predictor we need dummy cv_labels\r
284     // for safe split_node_data() operation\r
285     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;\r
286 \r
287     work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);\r
288     buf_size = (work_var_count + 1)*sample_count;\r
289     shared = _shared;\r
290     buf_count = shared ? 2 : 1;\r
291     \r
292     if ( is_buf_16u )\r
293     {\r
294         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));\r
295         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));\r
296     }\r
297     else\r
298     {\r
299         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));\r
300         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));\r
301     }    \r
302 \r
303     size = is_classifier ? cat_var_count+1 : cat_var_count;\r
304     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));\r
305     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));\r
306         \r
307     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;\r
308     size = !size ? 1 : size;\r
309     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));\r
310 \r
311     // now calculate the maximum size of split,\r
312     // create memory storage that will keep nodes and splits of the decision tree\r
313     // allocate root node and the buffer for the whole training data\r
314     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
315         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));\r
316     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
317     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
318     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
319     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));\r
320 \r
321     nv_size = var_count*sizeof(int);\r
322     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));\r
323 \r
324     temp_block_size = nv_size;\r
325 \r
326     if( cv_n )\r
327     {\r
328         if( sample_count < cv_n*MAX(params.min_sample_count,10) )\r
329             CV_ERROR( CV_StsOutOfRange,\r
330                 "The many folds in cross-validation for such a small dataset" );\r
331 \r
332         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );\r
333         temp_block_size = MAX(temp_block_size, cv_size);\r
334     }\r
335 \r
336     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );\r
337     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));\r
338     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));\r
339     if( cv_size )\r
340         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));\r
341 \r
342     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));\r
343 \r
344     max_c_count = 1;\r
345 \r
346     _fdst = 0;\r
347     _idst = 0;\r
348     if (ord_var_count)\r
349         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));\r
350     if (is_buf_16u && (cat_var_count || is_classifier))\r
351         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));\r
352 \r
353     // transform the training data to convenient representation\r
354     for( vi = 0; vi <= var_count; vi++ )\r
355     {\r
356         int ci;\r
357         const uchar* mask = 0;\r
358         int m_step = 0, step;\r
359         const int* idata = 0;\r
360         const float* fdata = 0;\r
361         int num_valid = 0;\r
362 \r
363         if( vi < var_count ) // analyze i-th input variable\r
364         {\r
365             int vi0 = vidx ? vidx[vi] : vi;\r
366             ci = get_var_type(vi);\r
367             step = ds_step; m_step = ms_step;\r
368             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
369                 idata = _train_data->data.i + vi0*dv_step;\r
370             else\r
371                 fdata = _train_data->data.fl + vi0*dv_step;\r
372             if( _missing_mask )\r
373                 mask = _missing_mask->data.ptr + vi0*mv_step;\r
374         }\r
375         else // analyze _responses\r
376         {\r
377             ci = cat_var_count;\r
378             step = CV_IS_MAT_CONT(_responses->type) ?\r
379                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);\r
380             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )\r
381                 idata = _responses->data.i;\r
382             else\r
383                 fdata = _responses->data.fl;\r
384         }\r
385 \r
386         if( (vi < var_count && ci>=0) ||\r
387             (vi == var_count && is_classifier) ) // process categorical variable or response\r
388         {\r
389             int c_count, prev_label;\r
390             int* c_map;\r
391             \r
392             if (is_buf_16u)\r
393                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
394             else\r
395                 idst = buf->data.i + vi*sample_count;\r
396             \r
397             // copy data\r
398             for( i = 0; i < sample_count; i++ )\r
399             {\r
400                 int val = INT_MAX, si = sidx ? sidx[i] : i;\r
401                 if( !mask || !mask[si*m_step] )\r
402                 {\r
403                     if( idata )\r
404                         val = idata[si*step];\r
405                     else\r
406                     {\r
407                         float t = fdata[si*step];\r
408                         val = cvRound(t);\r
409                         if( val != t )\r
410                         {\r
411                             sprintf( err, "%d-th value of %d-th (categorical) "\r
412                                 "variable is not an integer", i, vi );\r
413                             CV_ERROR( CV_StsBadArg, err );\r
414                         }\r
415                     }\r
416 \r
417                     if( val == INT_MAX )\r
418                     {\r
419                         sprintf( err, "%d-th value of %d-th (categorical) "\r
420                             "variable is too large", i, vi );\r
421                         CV_ERROR( CV_StsBadArg, err );\r
422                     }\r
423                     num_valid++;\r
424                 }\r
425                 if (is_buf_16u)\r
426                 {\r
427                     _idst[i] = val;\r
428                     pair16u32s_ptr[i].u = udst + i;\r
429                     pair16u32s_ptr[i].i = _idst + i;\r
430                 }   \r
431                 else\r
432                 {\r
433                     idst[i] = val;\r
434                     int_ptr[i] = idst + i;\r
435                 }\r
436             }\r
437 \r
438             c_count = num_valid > 0;\r
439 \r
440             if (is_buf_16u)\r
441             {\r
442                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );\r
443                 // count the categories\r
444                 for( i = 1; i < num_valid; i++ )\r
445                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)\r
446                         c_count ++ ;\r
447             }\r
448             else\r
449             {\r
450                 icvSortIntPtr( int_ptr, sample_count, 0 );\r
451                 // count the categories\r
452                 for( i = 1; i < num_valid; i++ )\r
453                     c_count += *int_ptr[i] != *int_ptr[i-1];\r
454             }\r
455 \r
456             if( vi > 0 )\r
457                 max_c_count = MAX( max_c_count, c_count );\r
458             cat_count->data.i[ci] = c_count;\r
459             cat_ofs->data.i[ci] = total_c_count;\r
460 \r
461             // resize cat_map, if need\r
462             if( cat_map->cols < total_c_count + c_count )\r
463             {\r
464                 tmp_map = cat_map;\r
465                 CV_CALL( cat_map = cvCreateMat( 1,\r
466                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));\r
467                 for( i = 0; i < total_c_count; i++ )\r
468                     cat_map->data.i[i] = tmp_map->data.i[i];\r
469                 cvReleaseMat( &tmp_map );\r
470             }\r
471 \r
472             c_map = cat_map->data.i + total_c_count;\r
473             total_c_count += c_count;\r
474 \r
475             c_count = -1;\r
476             if (is_buf_16u)\r
477             {\r
478                 // compact the class indices and build the map\r
479                 prev_label = ~*pair16u32s_ptr[0].i;\r
480                 for( i = 0; i < num_valid; i++ )\r
481                 {\r
482                     int cur_label = *pair16u32s_ptr[i].i;\r
483                     if( cur_label != prev_label )\r
484                         c_map[++c_count] = prev_label = cur_label;\r
485                     *pair16u32s_ptr[i].u = (unsigned short)c_count;\r
486                 }\r
487                 // replace labels for missing values with -1\r
488                 for( ; i < sample_count; i++ )\r
489                     *pair16u32s_ptr[i].u = 65535;\r
490             }\r
491             else\r
492             {\r
493                 // compact the class indices and build the map\r
494                 prev_label = ~*int_ptr[0];\r
495                 for( i = 0; i < num_valid; i++ )\r
496                 {\r
497                     int cur_label = *int_ptr[i];\r
498                     if( cur_label != prev_label )\r
499                         c_map[++c_count] = prev_label = cur_label;\r
500                     *int_ptr[i] = c_count;\r
501                 }\r
502                 // replace labels for missing values with -1\r
503                 for( ; i < sample_count; i++ )\r
504                     *int_ptr[i] = -1;\r
505             }           \r
506         }\r
507         else if( ci < 0 ) // process ordered variable\r
508         {\r
509             if (is_buf_16u)\r
510                 udst = (unsigned short*)(buf->data.s + vi*sample_count);\r
511             else\r
512                 idst = buf->data.i + vi*sample_count;\r
513 \r
514             for( i = 0; i < sample_count; i++ )\r
515             {\r
516                 float val = ord_nan;\r
517                 int si = sidx ? sidx[i] : i;\r
518                 if( !mask || !mask[si*m_step] )\r
519                 {\r
520                     if( idata )\r
521                         val = (float)idata[si*step];\r
522                     else\r
523                         val = fdata[si*step];\r
524 \r
525                     if( fabs(val) >= ord_nan )\r
526                     {\r
527                         sprintf( err, "%d-th value of %d-th (ordered) "\r
528                             "variable (=%g) is too large", i, vi, val );\r
529                         CV_ERROR( CV_StsBadArg, err );\r
530                     }\r
531                 }\r
532                 num_valid++;\r
533                 if (is_buf_16u)\r
534                     udst[i] = (unsigned short)i;\r
535                 else\r
536                     idst[i] = i; // Ã¯Ã¥Ã°Ã¥Ã­Ã¥Ã±Ã²Ã¨ Ã¢Ã»Ã¸Ã¥ Ã¢ if( idata )\r
537                 _fdst[i] = val;\r
538                 \r
539             }\r
540             if (is_buf_16u)\r
541                 icvSortUShAux( udst, num_valid, _fdst);\r
542             else\r
543                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );\r
544         }\r
545        \r
546         if( vi < var_count )\r
547             data_root->set_num_valid(vi, num_valid);\r
548     }\r
549 \r
550     // set sample labels\r
551     if (is_buf_16u)\r
552         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);\r
553     else\r
554         idst = buf->data.i + work_var_count*sample_count;\r
555 \r
556     for (i = 0; i < sample_count; i++)\r
557     {\r
558         if (udst)\r
559             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;\r
560         else\r
561             idst[i] = sidx ? sidx[i] : i;\r
562     }\r
563 \r
564     if( cv_n )\r
565     {\r
566         unsigned short* udst = 0;\r
567         int* idst = 0;\r
568         CvRNG* r = &rng;\r
569 \r
570         if (is_buf_16u)\r
571         {\r
572             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);\r
573             for( i = vi = 0; i < sample_count; i++ )\r
574             {\r
575                 udst[i] = (unsigned short)vi++;\r
576                 vi &= vi < cv_n ? -1 : 0;\r
577             }\r
578 \r
579             for( i = 0; i < sample_count; i++ )\r
580             {\r
581                 int a = cvRandInt(r) % sample_count;\r
582                 int b = cvRandInt(r) % sample_count;\r
583                 unsigned short unsh = (unsigned short)vi;\r
584                 CV_SWAP( udst[a], udst[b], unsh );\r
585             }\r
586         }\r
587         else\r
588         {\r
589             idst = buf->data.i + (get_work_var_count()-1)*sample_count;\r
590             for( i = vi = 0; i < sample_count; i++ )\r
591             {\r
592                 idst[i] = vi++;\r
593                 vi &= vi < cv_n ? -1 : 0;\r
594             }\r
595 \r
596             for( i = 0; i < sample_count; i++ )\r
597             {\r
598                 int a = cvRandInt(r) % sample_count;\r
599                 int b = cvRandInt(r) % sample_count;\r
600                 CV_SWAP( idst[a], idst[b], vi );\r
601             }\r
602         }\r
603     }\r
604 \r
605     if ( cat_map ) \r
606         cat_map->cols = MAX( total_c_count, 1 );\r
607 \r
608     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
609         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
610     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));\r
611 \r
612     have_priors = is_classifier && params.priors;\r
613     if( is_classifier )\r
614     {\r
615         int m = get_num_classes();\r
616         double sum = 0;\r
617         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));\r
618         for( i = 0; i < m; i++ )\r
619         {\r
620             double val = have_priors ? params.priors[i] : 1.;\r
621             if( val <= 0 )\r
622                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );\r
623             priors->data.db[i] = val;\r
624             sum += val;\r
625         }\r
626 \r
627         // normalize weights\r
628         if( have_priors )\r
629             cvScale( priors, priors, 1./sum );\r
630 \r
631         CV_CALL( priors_mult = cvCloneMat( priors ));\r
632         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));\r
633     }\r
634 \r
635 \r
636     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));\r
637     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));\r
638 \r
639     CV_CALL( pred_float_buf = (float*)cvAlloc(sample_count*sizeof(pred_float_buf[0])) );\r
640     CV_CALL( pred_int_buf = (int*)cvAlloc(sample_count*sizeof(pred_int_buf[0])) );\r
641     CV_CALL( resp_float_buf = (float*)cvAlloc(sample_count*sizeof(resp_float_buf[0])) );\r
642     CV_CALL( resp_int_buf = (int*)cvAlloc(sample_count*sizeof(resp_int_buf[0])) );\r
643     CV_CALL( cv_lables_buf = (int*)cvAlloc(sample_count*sizeof(cv_lables_buf[0])) );\r
644     CV_CALL( sample_idx_buf = (int*)cvAlloc(sample_count*sizeof(sample_idx_buf[0])) );\r
645 \r
646     __END__;\r
647 \r
648     if( data )\r
649         delete data;\r
650 \r
651     if (_fdst)\r
652         cvFree( &_fdst );\r
653     if (_idst)\r
654         cvFree( &_idst );\r
655     cvFree( &int_ptr );\r
656     cvReleaseMat( &var_type0 );\r
657     cvReleaseMat( &sample_indices );\r
658     cvReleaseMat( &tmp_map );\r
659 }\r
660 \r
661 \r
662 \r
663 void CvDTreeTrainData::do_responses_copy()\r
664 {\r
665     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );\r
666     cvCopy( responses, responses_copy);\r
667     responses = responses_copy;\r
668 }\r
669 \r
670 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )\r
671 {\r
672     CvDTreeNode* root = 0;\r
673     CvMat* isubsample_idx = 0;\r
674     CvMat* subsample_co = 0;\r
675 \r
676     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );\r
677 \r
678     __BEGIN__;\r
679 \r
680     if( !data_root )\r
681         CV_ERROR( CV_StsError, "No training data has been set" );\r
682 \r
683     if( _subsample_idx )\r
684         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
685 \r
686     if( !isubsample_idx )\r
687     {\r
688         // make a copy of the root node\r
689         CvDTreeNode temp;\r
690         int i;\r
691         root = new_node( 0, 1, 0, 0 );\r
692         temp = *root;\r
693         *root = *data_root;\r
694         root->num_valid = temp.num_valid;\r
695         if( root->num_valid )\r
696         {\r
697             for( i = 0; i < var_count; i++ )\r
698                 root->num_valid[i] = data_root->num_valid[i];\r
699         }\r
700         root->cv_Tn = temp.cv_Tn;\r
701         root->cv_node_risk = temp.cv_node_risk;\r
702         root->cv_node_error = temp.cv_node_error;\r
703     }\r
704     else\r
705     {\r
706         int* sidx = isubsample_idx->data.i;\r
707         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)\r
708         int* co, cur_ofs = 0;\r
709         int vi, i;\r
710         int work_var_count = get_work_var_count();\r
711         int count = isubsample_idx->rows + isubsample_idx->cols - 1;\r
712 \r
713         root = new_node( 0, count, 1, 0 );\r
714 \r
715         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
716         cvZero( subsample_co );\r
717         co = subsample_co->data.i;\r
718         for( i = 0; i < count; i++ )\r
719             co[sidx[i]*2]++;\r
720         for( i = 0; i < sample_count; i++ )\r
721         {\r
722             if( co[i*2] )\r
723             {\r
724                 co[i*2+1] = cur_ofs;\r
725                 cur_ofs += co[i*2];\r
726             }\r
727             else\r
728                 co[i*2+1] = -1;\r
729         }\r
730 \r
731         for( vi = 0; vi < work_var_count; vi++ )\r
732         {\r
733             int ci = get_var_type(vi);\r
734 \r
735             if( ci >= 0 || vi >= var_count )\r
736             {\r
737                 int* src_buf = pred_int_buf;\r
738                 const int* src = 0;\r
739                 int num_valid = 0;\r
740                 \r
741                 get_cat_var_data( data_root, vi, src_buf, &src );\r
742 \r
743                 if (is_buf_16u)\r
744                 {\r
745                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
746                         vi*sample_count + root->offset);\r
747                     for( i = 0; i < count; i++ )\r
748                     {\r
749                         int val = src[sidx[i]];\r
750                         udst[i] = (unsigned short)val;\r
751                         num_valid += val >= 0;\r
752                     }\r
753                 }\r
754                 else\r
755                 {\r
756                     int* idst = buf->data.i + root->buf_idx*buf->cols + \r
757                         vi*sample_count + root->offset;\r
758                     for( i = 0; i < count; i++ )\r
759                     {\r
760                         int val = src[sidx[i]];\r
761                         idst[i] = val;\r
762                         num_valid += val >= 0;\r
763                     }\r
764                 }\r
765 \r
766                 if( vi < var_count )\r
767                     root->set_num_valid(vi, num_valid);\r
768             }\r
769             else\r
770             {\r
771                 int *src_idx_buf = pred_int_buf;\r
772                 const int* src_idx = 0;\r
773                 float *src_val_buf = pred_float_buf;\r
774                 const float* src_val = 0;\r
775                 int j = 0, idx, count_i;\r
776                 int num_valid = data_root->get_num_valid(vi);\r
777 \r
778                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );\r
779                 if (is_buf_16u)\r
780                 {\r
781                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
782                         vi*sample_count + data_root->offset);\r
783                     for( i = 0; i < num_valid; i++ )\r
784                     {\r
785                         idx = src_idx[i];\r
786                         count_i = co[idx*2];\r
787                         if( count_i )\r
788                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
789                                 udst_idx[j] = (unsigned short)cur_ofs;\r
790                     }\r
791 \r
792                     root->set_num_valid(vi, j);\r
793 \r
794                     for( ; i < sample_count; i++ )\r
795                     {\r
796                         idx = src_idx[i];\r
797                         count_i = co[idx*2];\r
798                         if( count_i )\r
799                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
800                                 udst_idx[j] = (unsigned short)cur_ofs;\r
801                     }\r
802                 }\r
803                 else\r
804                 {\r
805                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + \r
806                         vi*sample_count + root->offset;\r
807                     for( i = 0; i < num_valid; i++ )\r
808                     {\r
809                         idx = src_idx[i];\r
810                         count_i = co[idx*2];\r
811                         if( count_i )\r
812                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
813                                 idst_idx[j] = cur_ofs;\r
814                     }\r
815 \r
816                     root->set_num_valid(vi, j);\r
817 \r
818                     for( ; i < sample_count; i++ )\r
819                     {\r
820                         idx = src_idx[i];\r
821                         count_i = co[idx*2];\r
822                         if( count_i )\r
823                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )\r
824                                 idst_idx[j] = cur_ofs;\r
825                     }\r
826                 }\r
827             }\r
828         }\r
829         // sample indices subsampling\r
830         int* sample_idx_src_buf = sample_idx_buf;\r
831         const int* sample_idx_src = 0;\r
832         get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);\r
833         if (is_buf_16u)\r
834         {\r
835             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + \r
836                 get_work_var_count()*sample_count + root->offset);            \r
837             for (i = 0; i < count; i++)\r
838                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];\r
839         }\r
840         else\r
841         {\r
842             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + \r
843                 get_work_var_count()*sample_count + root->offset;            \r
844             for (i = 0; i < count; i++)\r
845                 sample_idx_dst[i] = sample_idx_src[sidx[i]];\r
846         }\r
847     }\r
848 \r
849     __END__;\r
850 \r
851     cvReleaseMat( &isubsample_idx );\r
852     cvReleaseMat( &subsample_co );\r
853 \r
854     return root;\r
855 }\r
856 \r
857 \r
858 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,\r
859                                     float* values, uchar* missing,\r
860                                     float* responses, bool get_class_idx )\r
861 {\r
862     CvMat* subsample_idx = 0;\r
863     CvMat* subsample_co = 0;\r
864 \r
865     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );\r
866 \r
867     __BEGIN__;\r
868 \r
869     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
870     int* sidx = 0;\r
871     int* co = 0;\r
872 \r
873     if( _subsample_idx )\r
874     {\r
875         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
876         sidx = subsample_idx->data.i;\r
877         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));\r
878         co = subsample_co->data.i;\r
879         cvZero( subsample_co );\r
880         count = subsample_idx->cols + subsample_idx->rows - 1;\r
881         for( i = 0; i < count; i++ )\r
882             co[sidx[i]*2]++;\r
883         for( i = 0; i < total; i++ )\r
884         {\r
885             int count_i = co[i*2];\r
886             if( count_i )\r
887             {\r
888                 co[i*2+1] = cur_ofs*var_count;\r
889                 cur_ofs += count_i;\r
890             }\r
891         }\r
892     }\r
893 \r
894     if( missing )\r
895         memset( missing, 1, count*var_count );\r
896 \r
897     for( vi = 0; vi < var_count; vi++ )\r
898     {\r
899         int ci = get_var_type(vi);\r
900         if( ci >= 0 ) // categorical\r
901         {\r
902             float* dst = values + vi;\r
903             uchar* m = missing ? missing + vi : 0;\r
904             int* src_buf = pred_int_buf;\r
905             const int* src = 0; \r
906             get_cat_var_data(data_root, vi, src_buf, &src);\r
907 \r
908             for( i = 0; i < count; i++, dst += var_count )\r
909             {\r
910                 int idx = sidx ? sidx[i] : i;\r
911                 int val = src[idx];\r
912                 *dst = (float)val;\r
913                 if( m )\r
914                 {\r
915                     *m = val < 0;\r
916                     m += var_count;\r
917                 }\r
918             }\r
919         }\r
920         else // ordered\r
921         {\r
922             float* dst = values + vi;\r
923             uchar* m = missing ? missing + vi : 0;\r
924             int count1 = data_root->get_num_valid(vi);\r
925             float *src_val_buf = pred_float_buf;\r
926             const float *src_val = 0;\r
927             int* src_idx_buf = pred_int_buf;\r
928             const int* src_idx = 0;\r
929             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
930 \r
931             for( i = 0; i < count1; i++ )\r
932             {\r
933                 int idx = src_idx[i];\r
934                 int count_i = 1;\r
935                 if( co )\r
936                 {\r
937                     count_i = co[idx*2];\r
938                     cur_ofs = co[idx*2+1];\r
939                 }\r
940                 else\r
941                     cur_ofs = idx*var_count;\r
942                 if( count_i )\r
943                 {\r
944                     float val = src_val[i];\r
945                     for( ; count_i > 0; count_i--, cur_ofs += var_count )\r
946                     {\r
947                         dst[cur_ofs] = val;\r
948                         if( m )\r
949                             m[cur_ofs] = 0;\r
950                     }\r
951                 }\r
952             }\r
953         }\r
954     }\r
955 \r
956     // copy responses\r
957     if( responses )\r
958     {\r
959         if( is_classifier )\r
960         {\r
961             int* src_buf = resp_int_buf;\r
962             const int* src = 0;\r
963             get_class_labels(data_root, src_buf, &src);\r
964             for( i = 0; i < count; i++ )\r
965             {\r
966                 int idx = sidx ? sidx[i] : i;\r
967                 int val = get_class_idx ? src[idx] :\r
968                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];\r
969                 responses[i] = (float)val;\r
970             }\r
971         }\r
972         else\r
973         {\r
974             float *_values_buf = resp_float_buf;\r
975             const float* _values = 0;\r
976             get_ord_responses(data_root, _values_buf, &_values);\r
977             for( i = 0; i < count; i++ )\r
978             {\r
979                 int idx = sidx ? sidx[i] : i;\r
980                 responses[i] = _values[idx];\r
981             }\r
982         }\r
983     }\r
984 \r
985     __END__;\r
986 \r
987     cvReleaseMat( &subsample_idx );\r
988     cvReleaseMat( &subsample_co );\r
989 }\r
990 \r
991 \r
992 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,\r
993                                          int storage_idx, int offset )\r
994 {\r
995     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );\r
996 \r
997     node->sample_count = count;\r
998     node->depth = parent ? parent->depth + 1 : 0;\r
999     node->parent = parent;\r
1000     node->left = node->right = 0;\r
1001     node->split = 0;\r
1002     node->value = 0;\r
1003     node->class_idx = 0;\r
1004     node->maxlr = 0.;\r
1005 \r
1006     node->buf_idx = storage_idx;\r
1007     node->offset = offset;\r
1008     if( nv_heap )\r
1009         node->num_valid = (int*)cvSetNew( nv_heap );\r
1010     else\r
1011         node->num_valid = 0;\r
1012     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;\r
1013     node->complexity = 0;\r
1014 \r
1015     if( params.cv_folds > 0 && cv_heap )\r
1016     {\r
1017         int cv_n = params.cv_folds;\r
1018         node->Tn = INT_MAX;\r
1019         node->cv_Tn = (int*)cvSetNew( cv_heap );\r
1020         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));\r
1021         node->cv_node_error = node->cv_node_risk + cv_n;\r
1022     }\r
1023     else\r
1024     {\r
1025         node->Tn = 0;\r
1026         node->cv_Tn = 0;\r
1027         node->cv_node_risk = 0;\r
1028         node->cv_node_error = 0;\r
1029     }\r
1030 \r
1031     return node;\r
1032 }\r
1033 \r
1034 \r
1035 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,\r
1036                 int split_point, int inversed, float quality )\r
1037 {\r
1038     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1039     split->var_idx = vi;\r
1040     split->condensed_idx = INT_MIN;\r
1041     split->ord.c = cmp_val;\r
1042     split->ord.split_point = split_point;\r
1043     split->inversed = inversed;\r
1044     split->quality = quality;\r
1045     split->next = 0;\r
1046 \r
1047     return split;\r
1048 }\r
1049 \r
1050 \r
1051 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )\r
1052 {\r
1053     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );\r
1054     int i, n = (max_c_count + 31)/32;\r
1055 \r
1056     split->var_idx = vi;\r
1057     split->condensed_idx = INT_MIN;\r
1058     split->inversed = 0;\r
1059     split->quality = quality;\r
1060     for( i = 0; i < n; i++ )\r
1061         split->subset[i] = 0;\r
1062     split->next = 0;\r
1063 \r
1064     return split;\r
1065 }\r
1066 \r
1067 \r
1068 void CvDTreeTrainData::free_node( CvDTreeNode* node )\r
1069 {\r
1070     CvDTreeSplit* split = node->split;\r
1071     free_node_data( node );\r
1072     while( split )\r
1073     {\r
1074         CvDTreeSplit* next = split->next;\r
1075         cvSetRemoveByPtr( split_heap, split );\r
1076         split = next;\r
1077     }\r
1078     node->split = 0;\r
1079     cvSetRemoveByPtr( node_heap, node );\r
1080 }\r
1081 \r
1082 \r
1083 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )\r
1084 {\r
1085     if( node->num_valid )\r
1086     {\r
1087         cvSetRemoveByPtr( nv_heap, node->num_valid );\r
1088         node->num_valid = 0;\r
1089     }\r
1090     // do not free cv_* fields, as all the cross-validation related data is released at once.\r
1091 }\r
1092 \r
1093 \r
1094 void CvDTreeTrainData::free_train_data()\r
1095 {\r
1096     cvReleaseMat( &counts );\r
1097     cvReleaseMat( &buf );\r
1098     cvReleaseMat( &direction );\r
1099     cvReleaseMat( &split_buf );\r
1100     cvReleaseMemStorage( &temp_storage );\r
1101     cvReleaseMat( &responses_copy );\r
1102     cvFree( &pred_float_buf );
1103     cvFree( &pred_int_buf );
1104     cvFree( &resp_float_buf );
1105     cvFree( &resp_int_buf );
1106     cvFree( &cv_lables_buf );
1107     cvFree( &sample_idx_buf );\r
1108 \r
1109     cv_heap = nv_heap = 0;\r
1110 }\r
1111 \r
1112 \r
1113 void CvDTreeTrainData::clear()\r
1114 {\r
1115     free_train_data();\r
1116 \r
1117     cvReleaseMemStorage( &tree_storage );\r
1118 \r
1119     cvReleaseMat( &var_idx );\r
1120     cvReleaseMat( &var_type );\r
1121     cvReleaseMat( &cat_count );\r
1122     cvReleaseMat( &cat_ofs );\r
1123     cvReleaseMat( &cat_map );\r
1124     cvReleaseMat( &priors );\r
1125     cvReleaseMat( &priors_mult );\r
1126     \r
1127     node_heap = split_heap = 0;\r
1128 \r
1129     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;\r
1130     have_labels = have_priors = is_classifier = false;\r
1131 \r
1132     buf_count = buf_size = 0;\r
1133     shared = false;\r
1134     \r
1135     data_root = 0;\r
1136 \r
1137     rng = cvRNG(-1);\r
1138 }\r
1139 \r
1140 \r
1141 int CvDTreeTrainData::get_num_classes() const\r
1142 {\r
1143     return is_classifier ? cat_count->data.i[cat_var_count] : 0;\r
1144 }\r
1145 \r
1146 \r
1147 int CvDTreeTrainData::get_var_type(int vi) const\r
1148 {\r
1149     return var_type->data.i[vi];\r
1150 }\r
1151 \r
1152 int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )\r
1153 {\r
1154     int node_sample_count = n->sample_count; \r
1155     int* sample_indices_buf = sample_idx_buf;\r
1156     const int* sample_indices = 0;\r
1157 \r
1158     get_sample_indices(n, sample_indices_buf, &sample_indices);\r
1159 \r
1160     if( !is_buf_16u )\r
1161         *indices = buf->data.i + n->buf_idx*buf->cols + \r
1162         vi*sample_count + n->offset;\r
1163     else {\r
1164         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1165             vi*sample_count + n->offset );\r
1166         for( int i = 0; i < node_sample_count; i++ )\r
1167             indices_buf[i] = short_indices[i];\r
1168         *indices = indices_buf;\r
1169     }\r
1170     \r
1171     int td_cols = train_data->cols;\r
1172     if( tflag == CV_ROW_SAMPLE )\r
1173     {\r
1174         for( int i = 0; i < node_sample_count && \r
1175             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1176         {\r
1177             int idx = (*indices)[i];\r
1178             idx = sample_indices[idx];\r
1179             ord_values_buf[i] = *(train_data->data.fl + idx * td_cols + vi);\r
1180         }\r
1181     }\r
1182     else\r
1183         for( int i = 0; i < node_sample_count && \r
1184             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )\r
1185         {\r
1186             int idx = (*indices)[i];\r
1187             idx = sample_indices[idx];\r
1188             ord_values_buf[i] = *(train_data->data.fl + vi* td_cols + idx);\r
1189         }\r
1190     \r
1191     *ord_values = ord_values_buf;\r
1192     return 0; //TODO: return the number of non-missing values\r
1193 }\r
1194 \r
1195 \r
1196 void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1197 {\r
1198     if (is_classifier)\r
1199         get_cat_var_data( n, var_count, labels_buf, labels );\r
1200 }\r
1201 \r
1202 void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )\r
1203 {\r
1204     get_cat_var_data( n, get_work_var_count(), indices_buf, indices );\r
1205 }\r
1206 \r
1207 void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)\r
1208 {\r
1209     int sample_count = n->sample_count;\r
1210     int* indices_buf = sample_idx_buf;\r
1211     const int* indices = 0;\r
1212 \r
1213     int r_cols = responses->cols;\r
1214 \r
1215     get_sample_indices(n, indices_buf, &indices);\r
1216 \r
1217     \r
1218     for( int i = 0; i < sample_count && \r
1219         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )\r
1220     {\r
1221         int idx = indices[i];\r
1222         values_buf[i] = *(responses->data.fl + idx * r_cols);\r
1223     }\r
1224     \r
1225     *values = values_buf;    \r
1226 }\r
1227 \r
1228 \r
1229 void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )\r
1230 {\r
1231     if (have_labels)\r
1232         get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );\r
1233 }\r
1234 \r
1235 \r
1236 int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )\r
1237 {\r
1238     if( !is_buf_16u )\r
1239         *cat_values = buf->data.i + n->buf_idx*buf->cols + \r
1240         vi*sample_count + n->offset;\r
1241     else {\r
1242         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + \r
1243             vi*sample_count + n->offset);\r
1244         for( int i = 0; i < n->sample_count; i++ )\r
1245             cat_values_buf[i] = short_values[i];\r
1246         *cat_values = cat_values_buf;\r
1247     }\r
1248 \r
1249     return 0; //TODO: return the number of non-missing values\r
1250 }\r
1251 \r
1252 \r
1253 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )\r
1254 {\r
1255     int idx = n->buf_idx + 1;\r
1256     if( idx >= buf_count )\r
1257         idx = shared ? 1 : 0;\r
1258     return idx;\r
1259 }\r
1260 \r
1261 \r
1262 void CvDTreeTrainData::write_params( CvFileStorage* fs )\r
1263 {\r
1264     CV_FUNCNAME( "CvDTreeTrainData::write_params" );\r
1265 \r
1266     __BEGIN__;\r
1267 \r
1268     int vi, vcount = var_count;\r
1269 \r
1270     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );\r
1271     cvWriteInt( fs, "var_all", var_all );\r
1272     cvWriteInt( fs, "var_count", var_count );\r
1273     cvWriteInt( fs, "ord_var_count", ord_var_count );\r
1274     cvWriteInt( fs, "cat_var_count", cat_var_count );\r
1275 \r
1276     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );\r
1277     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );\r
1278 \r
1279     if( is_classifier )\r
1280     {\r
1281         cvWriteInt( fs, "max_categories", params.max_categories );\r
1282     }\r
1283     else\r
1284     {\r
1285         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );\r
1286     }\r
1287 \r
1288     cvWriteInt( fs, "max_depth", params.max_depth );\r
1289     cvWriteInt( fs, "min_sample_count", params.min_sample_count );\r
1290     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );\r
1291 \r
1292     if( params.cv_folds > 1 )\r
1293     {\r
1294         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );\r
1295         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );\r
1296     }\r
1297 \r
1298     if( priors )\r
1299         cvWrite( fs, "priors", priors );\r
1300 \r
1301     cvEndWriteStruct( fs );\r
1302 \r
1303     if( var_idx )\r
1304         cvWrite( fs, "var_idx", var_idx );\r
1305 \r
1306     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );\r
1307 \r
1308     for( vi = 0; vi < vcount; vi++ )\r
1309         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );\r
1310 \r
1311     cvEndWriteStruct( fs );\r
1312 \r
1313     if( cat_count && (cat_var_count > 0 || is_classifier) )\r
1314     {\r
1315         CV_ASSERT( cat_count != 0 );\r
1316         cvWrite( fs, "cat_count", cat_count );\r
1317         cvWrite( fs, "cat_map", cat_map );\r
1318     }\r
1319 \r
1320     __END__;\r
1321 }\r
1322 \r
1323 \r
1324 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )\r
1325 {\r
1326     CV_FUNCNAME( "CvDTreeTrainData::read_params" );\r
1327 \r
1328     __BEGIN__;\r
1329 \r
1330     CvFileNode *tparams_node, *vartype_node;\r
1331     CvSeqReader reader;\r
1332     int vi, max_split_size, tree_block_size;\r
1333 \r
1334     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);\r
1335     var_all = cvReadIntByName( fs, node, "var_all" );\r
1336     var_count = cvReadIntByName( fs, node, "var_count", var_all );\r
1337     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );\r
1338     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );\r
1339 \r
1340     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );\r
1341 \r
1342     if( tparams_node ) // training parameters are not necessary\r
1343     {\r
1344         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;\r
1345 \r
1346         if( is_classifier )\r
1347         {\r
1348             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );\r
1349         }\r
1350         else\r
1351         {\r
1352             params.regression_accuracy =\r
1353                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );\r
1354         }\r
1355 \r
1356         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );\r
1357         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );\r
1358         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );\r
1359 \r
1360         if( params.cv_folds > 1 )\r
1361         {\r
1362             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;\r
1363             params.truncate_pruned_tree =\r
1364                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;\r
1365         }\r
1366 \r
1367         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );\r
1368         if( priors )\r
1369         {\r
1370             if( !CV_IS_MAT(priors) )\r
1371                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );\r
1372             priors_mult = cvCloneMat( priors );\r
1373         }\r
1374     }\r
1375 \r
1376     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));\r
1377     if( var_idx )\r
1378     {\r
1379         if( !CV_IS_MAT(var_idx) ||\r
1380             (var_idx->cols != 1 && var_idx->rows != 1) ||\r
1381             var_idx->cols + var_idx->rows - 1 != var_count ||\r
1382             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )\r
1383             CV_ERROR( CV_StsParseError,\r
1384                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );\r
1385 \r
1386         for( vi = 0; vi < var_count; vi++ )\r
1387             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )\r
1388                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );\r
1389     }\r
1390 \r
1391     ////// read var type\r
1392     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));\r
1393 \r
1394     cat_var_count = 0;\r
1395     ord_var_count = -1;\r
1396     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );\r
1397 \r
1398     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )\r
1399         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;\r
1400     else\r
1401     {\r
1402         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||\r
1403             vartype_node->data.seq->total != var_count )\r
1404             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1405 \r
1406         cvStartReadSeq( vartype_node->data.seq, &reader );\r
1407 \r
1408         for( vi = 0; vi < var_count; vi++ )\r
1409         {\r
1410             CvFileNode* n = (CvFileNode*)reader.ptr;\r
1411             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )\r
1412                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );\r
1413             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;\r
1414             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
1415         }\r
1416     }\r
1417     var_type->data.i[var_count] = cat_var_count;\r
1418 \r
1419     ord_var_count = ~ord_var_count;\r
1420     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )\r
1421         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );\r
1422     //////\r
1423 \r
1424     if( cat_var_count > 0 || is_classifier )\r
1425     {\r
1426         int ccount, total_c_count = 0;\r
1427         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));\r
1428         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));\r
1429 \r
1430         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||\r
1431             (cat_count->cols != 1 && cat_count->rows != 1) ||\r
1432             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||\r
1433             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||\r
1434             (cat_map->cols != 1 && cat_map->rows != 1) ||\r
1435             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )\r
1436             CV_ERROR( CV_StsParseError,\r
1437             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );\r
1438 \r
1439         ccount = cat_var_count + is_classifier;\r
1440 \r
1441         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));\r
1442         cat_ofs->data.i[0] = 0;\r
1443         max_c_count = 1;\r
1444 \r
1445         for( vi = 0; vi < ccount; vi++ )\r
1446         {\r
1447             int val = cat_count->data.i[vi];\r
1448             if( val <= 0 )\r
1449                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );\r
1450             max_c_count = MAX( max_c_count, val );\r
1451             cat_ofs->data.i[vi+1] = total_c_count += val;\r
1452         }\r
1453 \r
1454         if( cat_map->cols + cat_map->rows - 1 != total_c_count )\r
1455             CV_ERROR( CV_StsBadSize,\r
1456             "cat_map vector length is not equal to the total number of categories in all categorical vars" );\r
1457     }\r
1458 \r
1459     max_split_size = cvAlign(sizeof(CvDTreeSplit) +\r
1460         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));\r
1461 \r
1462     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);\r
1463     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);\r
1464     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));\r
1465     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),\r
1466             sizeof(CvDTreeNode), tree_storage ));\r
1467     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),\r
1468             max_split_size, tree_storage ));\r
1469 \r
1470     __END__;\r
1471 }\r
1472 /////////////////////// Decision Tree /////////////////////////\r
1473 \r
1474 CvDTree::CvDTree()\r
1475 {\r
1476     data = 0;\r
1477     var_importance = 0;\r
1478     default_model_name = "my_tree";\r
1479 \r
1480     clear();\r
1481 }\r
1482 \r
1483 \r
1484 void CvDTree::clear()\r
1485 {\r
1486     cvReleaseMat( &var_importance );\r
1487     if( data )\r
1488     {\r
1489         if( !data->shared )\r
1490             delete data;\r
1491         else\r
1492             free_tree();\r
1493         data = 0;\r
1494     }\r
1495     root = 0;\r
1496     pruned_tree_idx = -1;\r
1497 }\r
1498 \r
1499 \r
1500 CvDTree::~CvDTree()\r
1501 {\r
1502     clear();\r
1503 }\r
1504 \r
1505 \r
1506 const CvDTreeNode* CvDTree::get_root() const\r
1507 {\r
1508     return root;\r
1509 }\r
1510 \r
1511 \r
1512 int CvDTree::get_pruned_tree_idx() const\r
1513 {\r
1514     return pruned_tree_idx;\r
1515 }\r
1516 \r
1517 \r
1518 CvDTreeTrainData* CvDTree::get_data()\r
1519 {\r
1520     return data;\r
1521 }\r
1522 \r
1523 \r
1524 bool CvDTree::train( const CvMat* _train_data, int _tflag,\r
1525                      const CvMat* _responses, const CvMat* _var_idx,\r
1526                      const CvMat* _sample_idx, const CvMat* _var_type,\r
1527                      const CvMat* _missing_mask, CvDTreeParams _params )\r
1528 {\r
1529     bool result = false;\r
1530 \r
1531     CV_FUNCNAME( "CvDTree::train" );\r
1532 \r
1533     __BEGIN__;\r
1534 \r
1535     clear();\r
1536     data = new CvDTreeTrainData( _train_data, _tflag, _responses,\r
1537                                  _var_idx, _sample_idx, _var_type,\r
1538                                  _missing_mask, _params, false );\r
1539     CV_CALL( result = do_train(0));\r
1540 \r
1541     __END__;\r
1542 \r
1543     return result;\r
1544 }\r
1545 \r
1546 \r
1547 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )\r
1548 {\r
1549     bool result = false;\r
1550 \r
1551     CV_FUNCNAME( "CvDTree::train" );\r
1552 \r
1553     __BEGIN__;\r
1554 \r
1555     clear();\r
1556     data = _data;\r
1557     data->shared = true;\r
1558     CV_CALL( result = do_train(_subsample_idx));\r
1559 \r
1560     __END__;\r
1561 \r
1562     return result;\r
1563 }\r
1564 \r
1565 \r
1566 bool CvDTree::do_train( const CvMat* _subsample_idx )\r
1567 {\r
1568     bool result = false;\r
1569 \r
1570     CV_FUNCNAME( "CvDTree::do_train" );\r
1571 \r
1572     __BEGIN__;\r
1573 \r
1574     root = data->subsample_data( _subsample_idx );\r
1575 \r
1576     CV_CALL( try_split_node(root));\r
1577 \r
1578     if( data->params.cv_folds > 0 )\r
1579         CV_CALL( prune_cv());\r
1580 \r
1581     if( !data->shared )\r
1582         data->free_train_data();\r
1583 \r
1584     result = true;\r
1585 \r
1586     __END__;\r
1587 \r
1588     return result;\r
1589 }\r
1590 \r
1591 \r
1592 void CvDTree::try_split_node( CvDTreeNode* node )\r
1593 {\r
1594     CvDTreeSplit* best_split = 0;\r
1595     int i, n = node->sample_count, vi;\r
1596     bool can_split = true;\r
1597     double quality_scale;\r
1598 \r
1599     calc_node_value( node );\r
1600 \r
1601     if( node->sample_count <= data->params.min_sample_count ||\r
1602         node->depth >= data->params.max_depth )\r
1603         can_split = false;\r
1604 \r
1605     if( can_split && data->is_classifier )\r
1606     {\r
1607         // check if we have a "pure" node,\r
1608         // we assume that cls_count is filled by calc_node_value()\r
1609         int* cls_count = data->counts->data.i;\r
1610         int nz = 0, m = data->get_num_classes();\r
1611         for( i = 0; i < m; i++ )\r
1612             nz += cls_count[i] != 0;\r
1613         if( nz == 1 ) // there is only one class\r
1614             can_split = false;\r
1615     }\r
1616     else if( can_split )\r
1617     {\r
1618         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )\r
1619             can_split = false;\r
1620     }\r
1621 \r
1622     if( can_split )\r
1623     {\r
1624         best_split = find_best_split(node);\r
1625         // TODO: check the split quality ...\r
1626         node->split = best_split;\r
1627     }\r
1628 \r
1629     if( !can_split || !best_split )\r
1630     {\r
1631         data->free_node_data(node);\r
1632         return;\r
1633     }\r
1634 \r
1635     quality_scale = calc_node_dir( node );\r
1636 \r
1637     if( data->params.use_surrogates )\r
1638     {\r
1639         // find all the surrogate splits\r
1640         // and sort them by their similarity to the primary one\r
1641         for( vi = 0; vi < data->var_count; vi++ )\r
1642         {\r
1643             CvDTreeSplit* split;\r
1644             int ci = data->get_var_type(vi);\r
1645 \r
1646             if( vi == best_split->var_idx )\r
1647                 continue;\r
1648 \r
1649             if( ci >= 0 )\r
1650                 split = find_surrogate_split_cat( node, vi );\r
1651             else\r
1652                 split = find_surrogate_split_ord( node, vi );\r
1653 \r
1654             if( split )\r
1655             {\r
1656                 // insert the split\r
1657                 CvDTreeSplit* prev_split = node->split;\r
1658                 split->quality = (float)(split->quality*quality_scale);\r
1659 \r
1660                 while( prev_split->next &&\r
1661                        prev_split->next->quality > split->quality )\r
1662                     prev_split = prev_split->next;\r
1663                 split->next = prev_split->next;\r
1664                 prev_split->next = split;\r
1665             }\r
1666         }\r
1667     }\r
1668 \r
1669     split_node_data( node );\r
1670     try_split_node( node->left );\r
1671     try_split_node( node->right );\r
1672 }\r
1673 \r
1674 \r
1675 // calculate direction (left(-1),right(1),missing(0))\r
1676 // for each sample using the best split\r
1677 // the function returns scale coefficients for surrogate split quality factors.\r
1678 // the scale is applied to normalize surrogate split quality relatively to the\r
1679 // best (primary) split quality. That is, if a surrogate split is absolutely\r
1680 // identical to the primary split, its quality will be set to the maximum value =\r
1681 // quality of the primary split; otherwise, it will be lower.\r
1682 // besides, the function compute node->maxlr,\r
1683 // minimum possible quality (w/o considering the above mentioned scale)\r
1684 // for a surrogate split. Surrogate splits with quality less than node->maxlr\r
1685 // are not discarded.\r
1686 double CvDTree::calc_node_dir( CvDTreeNode* node )\r
1687 {\r
1688     char* dir = (char*)data->direction->data.ptr;\r
1689     int i, n = node->sample_count, vi = node->split->var_idx;\r
1690     double L, R;\r
1691 \r
1692     assert( !node->split->inversed );\r
1693 \r
1694     if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
1695     {\r
1696         int* labels_buf = data->pred_int_buf;\r
1697         const int* labels = 0;\r
1698         const int* subset = node->split->subset;\r
1699         data->get_cat_var_data( node, vi, labels_buf, &labels );\r
1700         if( !data->have_priors )\r
1701         {\r
1702             int sum = 0, sum_abs = 0;\r
1703 \r
1704             for( i = 0; i < n; i++ )\r
1705             {\r
1706                 int idx = labels[i];\r
1707                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?\r
1708                     CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1709                 sum += d; sum_abs += d & 1;\r
1710                 dir[i] = (char)d;\r
1711             }\r
1712 \r
1713             R = (sum_abs + sum) >> 1;\r
1714             L = (sum_abs - sum) >> 1;\r
1715         }\r
1716         else\r
1717         {\r
1718             const double* priors = data->priors_mult->data.db;\r
1719             double sum = 0, sum_abs = 0;\r
1720             int *responses_buf = data->resp_int_buf;\r
1721             const int* responses;\r
1722             data->get_class_labels(node, responses_buf, &responses);\r
1723 \r
1724             for( i = 0; i < n; i++ )\r
1725             {\r
1726                 int idx = labels[i];\r
1727                 double w = priors[responses[i]];\r
1728                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;\r
1729                 sum += d*w; sum_abs += (d & 1)*w;\r
1730                 dir[i] = (char)d;\r
1731             }\r
1732 \r
1733             R = (sum_abs + sum) * 0.5;\r
1734             L = (sum_abs - sum) * 0.5;\r
1735         }\r
1736     }\r
1737     else // split on ordered var\r
1738     {\r
1739         int split_point = node->split->ord.split_point;\r
1740         int n1 = node->get_num_valid(vi);\r
1741         \r
1742         float* val_buf = data->pred_float_buf;\r
1743         const float* val = 0;\r
1744         int* sorted_buf = data->pred_int_buf;\r
1745         const int* sorted = 0;\r
1746         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);\r
1747 \r
1748         assert( 0 <= split_point && split_point < n1-1 );\r
1749 \r
1750         if( !data->have_priors )\r
1751         {\r
1752             for( i = 0; i <= split_point; i++ )\r
1753                 dir[sorted[i]] = (char)-1;\r
1754             for( ; i < n1; i++ )\r
1755                 dir[sorted[i]] = (char)1;\r
1756             for( ; i < n; i++ )\r
1757                 dir[sorted[i]] = (char)0;\r
1758 \r
1759             L = split_point-1;\r
1760             R = n1 - split_point + 1;\r
1761         }\r
1762         else\r
1763         {\r
1764             const double* priors = data->priors_mult->data.db;\r
1765             int* responses_buf = data->resp_int_buf;\r
1766             const int* responses = 0;\r
1767             data->get_class_labels(node, responses_buf, &responses);\r
1768             L = R = 0;\r
1769 \r
1770             for( i = 0; i <= split_point; i++ )\r
1771             {\r
1772                 int idx = sorted[i];\r
1773                 double w = priors[responses[idx]];\r
1774                 dir[idx] = (char)-1;\r
1775                 L += w;\r
1776             }\r
1777 \r
1778             for( ; i < n1; i++ )\r
1779             {\r
1780                 int idx = sorted[i];\r
1781                 double w = priors[responses[idx]];\r
1782                 dir[idx] = (char)1;\r
1783                 R += w;\r
1784             }\r
1785 \r
1786             for( ; i < n; i++ )\r
1787                 dir[sorted[i]] = (char)0;\r
1788         }\r
1789     }\r
1790 \r
1791     node->maxlr = MAX( L, R );\r
1792     return node->split->quality/(L + R);\r
1793 }\r
1794 \r
1795 \r
1796 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )\r
1797 {\r
1798     int vi;\r
1799     CvDTreeSplit *best_split = 0, *split = 0, *t;\r
1800 \r
1801     for( vi = 0; vi < data->var_count; vi++ )\r
1802     {\r
1803         int ci = data->get_var_type(vi);\r
1804         if( node->get_num_valid(vi) <= 1 )\r
1805             continue;\r
1806 \r
1807         if( data->is_classifier )\r
1808         {\r
1809             if( ci >= 0 )\r
1810                 split = find_split_cat_class( node, vi );\r
1811             else\r
1812                 split = find_split_ord_class( node, vi );\r
1813         }\r
1814         else\r
1815         {\r
1816             if( ci >= 0 )\r
1817                 split = find_split_cat_reg( node, vi );\r
1818             else\r
1819                 split = find_split_ord_reg( node, vi );\r
1820         }\r
1821 \r
1822         if( split )\r
1823         {\r
1824             if( !best_split || best_split->quality < split->quality )\r
1825                 CV_SWAP( best_split, split, t );\r
1826             if( split )\r
1827                 cvSetRemoveByPtr( data->split_heap, split );\r
1828         }\r
1829     }\r
1830 \r
1831     return best_split;\r
1832 }\r
1833 \r
1834 \r
1835 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )\r
1836 {\r
1837     const float epsilon = FLT_EPSILON*2;\r
1838     int n = node->sample_count;\r
1839     int n1 = node->get_num_valid(vi);\r
1840     int m = data->get_num_classes();\r
1841 \r
1842     float* values_buf = data->pred_float_buf;\r
1843     const float* values = 0;\r
1844     int* indices_buf = data->pred_int_buf;\r
1845     const int* indices = 0;\r
1846     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
1847     int* responses_buf =  data->resp_int_buf;\r
1848     const int* responses = 0;\r
1849     data->get_class_labels( node, responses_buf, &responses );\r
1850 \r
1851     const int* rc0 = data->counts->data.i;\r
1852     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
1853     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
1854     int i, best_i = -1;\r
1855     double lsum2 = 0, rsum2 = 0, best_val = 0;\r
1856     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
1857 \r
1858     // init arrays of class instance counters on both sides of the split\r
1859     for( i = 0; i < m; i++ )\r
1860     {\r
1861         lc[i] = 0;\r
1862         rc[i] = rc0[i];\r
1863     }\r
1864 \r
1865     // compensate for missing values\r
1866     for( i = n1; i < n; i++ )\r
1867     {\r
1868         rc[responses[indices[i]]]--;\r
1869     }\r
1870 \r
1871     if( !priors )\r
1872     {\r
1873         int L = 0, R = n1;\r
1874 \r
1875         for( i = 0; i < m; i++ )\r
1876             rsum2 += (double)rc[i]*rc[i];\r
1877 \r
1878         for( i = 0; i < n1 - 1; i++ )\r
1879         {\r
1880             int idx = responses[indices[i]];\r
1881             int lv, rv;\r
1882             L++; R--;\r
1883             lv = lc[idx]; rv = rc[idx];\r
1884             lsum2 += lv*2 + 1;\r
1885             rsum2 -= rv*2 - 1;\r
1886             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1887 \r
1888             if( values[i] + epsilon < values[i+1] )\r
1889             {\r
1890                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1891                 if( best_val < val )\r
1892                 {\r
1893                     best_val = val;\r
1894                     best_i = i;\r
1895                 }\r
1896             }\r
1897         }\r
1898     }\r
1899     else\r
1900     {\r
1901         double L = 0, R = 0;\r
1902         for( i = 0; i < m; i++ )\r
1903         {\r
1904             double wv = rc[i]*priors[i];\r
1905             R += wv;\r
1906             rsum2 += wv*wv;\r
1907         }\r
1908 \r
1909         for( i = 0; i < n1 - 1; i++ )\r
1910         {\r
1911             int idx = responses[indices[i]];\r
1912             int lv, rv;\r
1913             double p = priors[idx], p2 = p*p;\r
1914             L += p; R -= p;\r
1915             lv = lc[idx]; rv = rc[idx];\r
1916             lsum2 += p2*(lv*2 + 1);\r
1917             rsum2 -= p2*(rv*2 - 1);\r
1918             lc[idx] = lv + 1; rc[idx] = rv - 1;\r
1919 \r
1920             if( values[i] + epsilon < values[i+1] )\r
1921             {\r
1922                 double val = (lsum2*R + rsum2*L)/((double)L*R);\r
1923                 if( best_val < val )\r
1924                 {\r
1925                     best_val = val;\r
1926                     best_i = i;\r
1927                 }\r
1928             }\r
1929         }\r
1930     }\r
1931 \r
1932     return best_i >= 0 ? data->new_split_ord( vi,\r
1933         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
1934         0, (float)best_val ) : 0;\r
1935 }\r
1936 \r
1937 \r
1938 void CvDTree::cluster_categories( const int* vectors, int n, int m,\r
1939                                 int* csums, int k, int* labels )\r
1940 {\r
1941     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm\r
1942     int iters = 0, max_iters = 100;\r
1943     int i, j, idx;\r
1944     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );\r
1945     double *v_weights = buf, *c_weights = buf + k;\r
1946     bool modified = true;\r
1947     CvRNG* r = &data->rng;\r
1948 \r
1949     // assign labels randomly\r
1950     for( i = idx = 0; i < n; i++ )\r
1951     {\r
1952         int sum = 0;\r
1953         const int* v = vectors + i*m;\r
1954         labels[i] = idx++;\r
1955         idx &= idx < k ? -1 : 0;\r
1956 \r
1957         // compute weight of each vector\r
1958         for( j = 0; j < m; j++ )\r
1959             sum += v[j];\r
1960         v_weights[i] = sum ? 1./sum : 0.;\r
1961     }\r
1962 \r
1963     for( i = 0; i < n; i++ )\r
1964     {\r
1965         int i1 = cvRandInt(r) % n;\r
1966         int i2 = cvRandInt(r) % n;\r
1967         CV_SWAP( labels[i1], labels[i2], j );\r
1968     }\r
1969 \r
1970     for( iters = 0; iters <= max_iters; iters++ )\r
1971     {\r
1972         // calculate csums\r
1973         for( i = 0; i < k; i++ )\r
1974         {\r
1975             for( j = 0; j < m; j++ )\r
1976                 csums[i*m + j] = 0;\r
1977         }\r
1978 \r
1979         for( i = 0; i < n; i++ )\r
1980         {\r
1981             const int* v = vectors + i*m;\r
1982             int* s = csums + labels[i]*m;\r
1983             for( j = 0; j < m; j++ )\r
1984                 s[j] += v[j];\r
1985         }\r
1986 \r
1987         // exit the loop here, when we have up-to-date csums\r
1988         if( iters == max_iters || !modified )\r
1989             break;\r
1990 \r
1991         modified = false;\r
1992 \r
1993         // calculate weight of each cluster\r
1994         for( i = 0; i < k; i++ )\r
1995         {\r
1996             const int* s = csums + i*m;\r
1997             int sum = 0;\r
1998             for( j = 0; j < m; j++ )\r
1999                 sum += s[j];\r
2000             c_weights[i] = sum ? 1./sum : 0;\r
2001         }\r
2002 \r
2003         // now for each vector determine the closest cluster\r
2004         for( i = 0; i < n; i++ )\r
2005         {\r
2006             const int* v = vectors + i*m;\r
2007             double alpha = v_weights[i];\r
2008             double min_dist2 = DBL_MAX;\r
2009             int min_idx = -1;\r
2010 \r
2011             for( idx = 0; idx < k; idx++ )\r
2012             {\r
2013                 const int* s = csums + idx*m;\r
2014                 double dist2 = 0., beta = c_weights[idx];\r
2015                 for( j = 0; j < m; j++ )\r
2016                 {\r
2017                     double t = v[j]*alpha - s[j]*beta;\r
2018                     dist2 += t*t;\r
2019                 }\r
2020                 if( min_dist2 > dist2 )\r
2021                 {\r
2022                     min_dist2 = dist2;\r
2023                     min_idx = idx;\r
2024                 }\r
2025             }\r
2026 \r
2027             if( min_idx != labels[i] )\r
2028                 modified = true;\r
2029             labels[i] = min_idx;\r
2030         }\r
2031     }\r
2032 }\r
2033 \r
2034 \r
2035 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )\r
2036 {\r
2037     CvDTreeSplit* split;\r
2038     int ci = data->get_var_type(vi);\r
2039     int n = node->sample_count;\r
2040     int m = data->get_num_classes();\r
2041     int _mi = data->cat_count->data.i[ci], mi = _mi;\r
2042 \r
2043     int* labels_buf = data->pred_int_buf;\r
2044     const int* labels = 0;\r
2045     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2046     int *responses_buf = data->resp_int_buf;\r
2047     const int* responses = 0;\r
2048     data->get_class_labels(node, responses_buf, &responses);\r
2049 \r
2050     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));\r
2051     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));\r
2052     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;\r
2053     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );\r
2054     int* cluster_labels = 0;\r
2055     int** int_ptr = 0;\r
2056     int i, j, k, idx;\r
2057     double L = 0, R = 0;\r
2058     double best_val = 0;\r
2059     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;\r
2060     const double* priors = data->priors_mult->data.db;\r
2061 \r
2062     // init array of counters:\r
2063     // c_{jk} - number of samples that have vi-th input variable = j and response = k.\r
2064     for( j = -1; j < mi; j++ )\r
2065         for( k = 0; k < m; k++ )\r
2066             cjk[j*m + k] = 0;\r
2067 \r
2068     for( i = 0; i < n; i++ )\r
2069     {\r
2070         j = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2071         k = responses[i];\r
2072         cjk[j*m + k]++;\r
2073     }\r
2074 \r
2075     if( m > 2 )\r
2076     {\r
2077         if( mi > data->params.max_categories )\r
2078         {\r
2079             mi = MIN(data->params.max_categories, n);\r
2080             cjk += _mi*m;\r
2081             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));\r
2082             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );\r
2083         }\r
2084         subset_i = 1;\r
2085         subset_n = 1 << mi;\r
2086     }\r
2087     else\r
2088     {\r
2089         assert( m == 2 );\r
2090         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );\r
2091         for( j = 0; j < mi; j++ )\r
2092             int_ptr[j] = cjk + j*2 + 1;\r
2093         icvSortIntPtr( int_ptr, mi, 0 );\r
2094         subset_i = 0;\r
2095         subset_n = mi;\r
2096     }\r
2097 \r
2098     for( k = 0; k < m; k++ )\r
2099     {\r
2100         int sum = 0;\r
2101         for( j = 0; j < mi; j++ )\r
2102             sum += cjk[j*m + k];\r
2103         rc[k] = sum;\r
2104         lc[k] = 0;\r
2105     }\r
2106 \r
2107     for( j = 0; j < mi; j++ )\r
2108     {\r
2109         double sum = 0;\r
2110         for( k = 0; k < m; k++ )\r
2111             sum += cjk[j*m + k]*priors[k];\r
2112         c_weights[j] = sum;\r
2113         R += c_weights[j];\r
2114     }\r
2115 \r
2116     for( ; subset_i < subset_n; subset_i++ )\r
2117     {\r
2118         double weight;\r
2119         int* crow;\r
2120         double lsum2 = 0, rsum2 = 0;\r
2121 \r
2122         if( m == 2 )\r
2123             idx = (int)(int_ptr[subset_i] - cjk)/2;\r
2124         else\r
2125         {\r
2126             int graycode = (subset_i>>1)^subset_i;\r
2127             int diff = graycode ^ prevcode;\r
2128 \r
2129             // determine index of the changed bit.\r
2130             Cv32suf u;\r
2131             idx = diff >= (1 << 16) ? 16 : 0;\r
2132             u.f = (float)(((diff >> 16) | diff) & 65535);\r
2133             idx += (u.i >> 23) - 127;\r
2134             subtract = graycode < prevcode;\r
2135             prevcode = graycode;\r
2136         }\r
2137 \r
2138         crow = cjk + idx*m;\r
2139         weight = c_weights[idx];\r
2140         if( weight < FLT_EPSILON )\r
2141             continue;\r
2142 \r
2143         if( !subtract )\r
2144         {\r
2145             for( k = 0; k < m; k++ )\r
2146             {\r
2147                 int t = crow[k];\r
2148                 int lval = lc[k] + t;\r
2149                 int rval = rc[k] - t;\r
2150                 double p = priors[k], p2 = p*p;\r
2151                 lsum2 += p2*lval*lval;\r
2152                 rsum2 += p2*rval*rval;\r
2153                 lc[k] = lval; rc[k] = rval;\r
2154             }\r
2155             L += weight;\r
2156             R -= weight;\r
2157         }\r
2158         else\r
2159         {\r
2160             for( k = 0; k < m; k++ )\r
2161             {\r
2162                 int t = crow[k];\r
2163                 int lval = lc[k] - t;\r
2164                 int rval = rc[k] + t;\r
2165                 double p = priors[k], p2 = p*p;\r
2166                 lsum2 += p2*lval*lval;\r
2167                 rsum2 += p2*rval*rval;\r
2168                 lc[k] = lval; rc[k] = rval;\r
2169             }\r
2170             L -= weight;\r
2171             R += weight;\r
2172         }\r
2173 \r
2174         if( L > FLT_EPSILON && R > FLT_EPSILON )\r
2175         {\r
2176             double val = (lsum2*R + rsum2*L)/((double)L*R);\r
2177             if( best_val < val )\r
2178             {\r
2179                 best_val = val;\r
2180                 best_subset = subset_i;\r
2181             }\r
2182         }\r
2183     }\r
2184 \r
2185     if( best_subset < 0 )\r
2186         return 0;\r
2187 \r
2188     split = data->new_split_cat( vi, (float)best_val );\r
2189 \r
2190     if( m == 2 )\r
2191     {\r
2192         for( i = 0; i <= best_subset; i++ )\r
2193         {\r
2194             idx = (int)(int_ptr[i] - cjk) >> 1;\r
2195             split->subset[idx >> 5] |= 1 << (idx & 31);\r
2196         }\r
2197     }\r
2198     else\r
2199     {\r
2200         for( i = 0; i < _mi; i++ )\r
2201         {\r
2202             idx = cluster_labels ? cluster_labels[i] : i;\r
2203             if( best_subset & (1 << idx) )\r
2204                 split->subset[i >> 5] |= 1 << (i & 31);\r
2205         }\r
2206     }\r
2207 \r
2208     return split;\r
2209 }\r
2210 \r
2211 \r
2212 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )\r
2213 {\r
2214     const float epsilon = FLT_EPSILON*2;\r
2215     int n = node->sample_count;\r
2216     int n1 = node->get_num_valid(vi);\r
2217 \r
2218     float* values_buf = data->pred_float_buf;\r
2219     const float* values = 0;\r
2220     int* indices_buf = data->pred_int_buf;\r
2221     const int* indices = 0;\r
2222     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2223     float* responses_buf =  data->resp_float_buf;\r
2224     const float* responses = 0;\r
2225     data->get_ord_responses( node, responses_buf, &responses );\r
2226 \r
2227     int i, best_i = -1;\r
2228     double best_val = 0, lsum = 0, rsum = node->value*n;\r
2229     int L = 0, R = n1;\r
2230 \r
2231     // compensate for missing values\r
2232     for( i = n1; i < n; i++ )\r
2233         rsum -= responses[indices[i]];\r
2234 \r
2235     // find the optimal split\r
2236     for( i = 0; i < n1 - 1; i++ )\r
2237     {\r
2238         float t = responses[indices[i]];\r
2239         L++; R--;\r
2240         lsum += t;\r
2241         rsum -= t;\r
2242 \r
2243         if( values[i] + epsilon < values[i+1] )\r
2244         {\r
2245             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2246             if( best_val < val )\r
2247             {\r
2248                 best_val = val;\r
2249                 best_i = i;\r
2250             }\r
2251         }\r
2252     }\r
2253 \r
2254     return best_i >= 0 ? data->new_split_ord( vi,\r
2255         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2256         0, (float)best_val ) : 0;\r
2257 }\r
2258 \r
2259 \r
2260 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )\r
2261 {\r
2262     CvDTreeSplit* split;\r
2263     int ci = data->get_var_type(vi);\r
2264     int n = node->sample_count;\r
2265     int mi = data->cat_count->data.i[ci];\r
2266     int* labels_buf = data->pred_int_buf;\r
2267     const int* labels = 0;\r
2268     float* responses_buf = data->resp_float_buf;\r
2269     const float* responses = 0;\r
2270     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2271     data->get_ord_responses(node, responses_buf, &responses);\r
2272 \r
2273     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;\r
2274     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;\r
2275     double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );\r
2276     int i, L = 0, R = 0;\r
2277     double best_val = 0, lsum = 0, rsum = 0;\r
2278     int best_subset = -1, subset_i;\r
2279 \r
2280     for( i = -1; i < mi; i++ )\r
2281         sum[i] = counts[i] = 0;\r
2282 \r
2283     // calculate sum response and weight of each category of the input var\r
2284     for( i = 0; i < n; i++ )\r
2285     {\r
2286         int idx = labels[i];\r
2287         double s = sum[idx] + responses[i];\r
2288         int nc = counts[idx] + 1;\r
2289         sum[idx] = s;\r
2290         counts[idx] = nc;\r
2291     }\r
2292 \r
2293     // calculate average response in each category\r
2294     for( i = 0; i < mi; i++ )\r
2295     {\r
2296         R += counts[i];\r
2297         rsum += sum[i];\r
2298         sum[i] /= MAX(counts[i],1);\r
2299         sum_ptr[i] = sum + i;\r
2300     }\r
2301 \r
2302     icvSortDblPtr( sum_ptr, mi, 0 );\r
2303 \r
2304     // revert back to unnormalized sums\r
2305     // (there should be a very little loss of accuracy)\r
2306     for( i = 0; i < mi; i++ )\r
2307         sum[i] *= counts[i];\r
2308 \r
2309     for( subset_i = 0; subset_i < mi-1; subset_i++ )\r
2310     {\r
2311         int idx = (int)(sum_ptr[subset_i] - sum);\r
2312         int ni = counts[idx];\r
2313 \r
2314         if( ni )\r
2315         {\r
2316             double s = sum[idx];\r
2317             lsum += s; L += ni;\r
2318             rsum -= s; R -= ni;\r
2319 \r
2320             if( L && R )\r
2321             {\r
2322                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);\r
2323                 if( best_val < val )\r
2324                 {\r
2325                     best_val = val;\r
2326                     best_subset = subset_i;\r
2327                 }\r
2328             }\r
2329         }\r
2330     }\r
2331 \r
2332     if( best_subset < 0 )\r
2333         return 0;\r
2334 \r
2335     split = data->new_split_cat( vi, (float)best_val );\r
2336     for( i = 0; i <= best_subset; i++ )\r
2337     {\r
2338         int idx = (int)(sum_ptr[i] - sum);\r
2339         split->subset[idx >> 5] |= 1 << (idx & 31);\r
2340     }\r
2341 \r
2342     return split;\r
2343 }\r
2344 \r
2345 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )\r
2346 {\r
2347     const float epsilon = FLT_EPSILON*2;\r
2348     const char* dir = (char*)data->direction->data.ptr;\r
2349     int n1 = node->get_num_valid(vi);\r
2350     float* values_buf = data->pred_float_buf;\r
2351     const float* values = 0;\r
2352     int* indices_buf = data->pred_int_buf;\r
2353     const int* indices = 0;\r
2354     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2355     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2356     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2357     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2358     // RR - ... both send to the right\r
2359     int i, best_i = -1, best_inversed = 0;\r
2360     double best_val;\r
2361 \r
2362     if( !data->have_priors )\r
2363     {\r
2364         int LL = 0, RL = 0, LR, RR;\r
2365         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;\r
2366         int sum = 0, sum_abs = 0;\r
2367 \r
2368         for( i = 0; i < n1; i++ )\r
2369         {\r
2370             int d = dir[indices[i]];\r
2371             sum += d; sum_abs += d & 1;\r
2372         }\r
2373 \r
2374         // sum_abs = R + L; sum = R - L\r
2375         RR = (sum_abs + sum) >> 1;\r
2376         LR = (sum_abs - sum) >> 1;\r
2377 \r
2378         // initially all the samples are sent to the right by the surrogate split,\r
2379         // LR of them are sent to the left by primary split, and RR - to the right.\r
2380         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2381         for( i = 0; i < n1 - 1; i++ )\r
2382         {\r
2383             int d = dir[indices[i]];\r
2384 \r
2385             if( d < 0 )\r
2386             {\r
2387                 LL++; LR--;\r
2388                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )\r
2389                 {\r
2390                     best_val = LL + RR;\r
2391                     best_i = i; best_inversed = 0;\r
2392                 }\r
2393             }\r
2394             else if( d > 0 )\r
2395             {\r
2396                 RL++; RR--;\r
2397                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )\r
2398                 {\r
2399                     best_val = RL + LR;\r
2400                     best_i = i; best_inversed = 1;\r
2401                 }\r
2402             }\r
2403         }\r
2404         best_val = _best_val;\r
2405     }\r
2406     else\r
2407     {\r
2408         double LL = 0, RL = 0, LR, RR;\r
2409         double worst_val = node->maxlr;\r
2410         double sum = 0, sum_abs = 0;\r
2411         const double* priors = data->priors_mult->data.db;\r
2412         int* responses_buf = data->resp_int_buf;\r
2413         const int* responses = 0;\r
2414         data->get_class_labels(node, responses_buf, &responses);\r
2415         best_val = worst_val;\r
2416 \r
2417         for( i = 0; i < n1; i++ )\r
2418         {\r
2419             int idx = indices[i];\r
2420             double w = priors[responses[idx]];\r
2421             int d = dir[idx];\r
2422             sum += d*w; sum_abs += (d & 1)*w;\r
2423         }\r
2424 \r
2425         // sum_abs = R + L; sum = R - L\r
2426         RR = (sum_abs + sum)*0.5;\r
2427         LR = (sum_abs - sum)*0.5;\r
2428 \r
2429         // initially all the samples are sent to the right by the surrogate split,\r
2430         // LR of them are sent to the left by primary split, and RR - to the right.\r
2431         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.\r
2432         for( i = 0; i < n1 - 1; i++ )\r
2433         {\r
2434             int idx = indices[i];\r
2435             double w = priors[responses[idx]];\r
2436             int d = dir[idx];\r
2437 \r
2438             if( d < 0 )\r
2439             {\r
2440                 LL += w; LR -= w;\r
2441                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )\r
2442                 {\r
2443                     best_val = LL + RR;\r
2444                     best_i = i; best_inversed = 0;\r
2445                 }\r
2446             }\r
2447             else if( d > 0 )\r
2448             {\r
2449                 RL += w; RR -= w;\r
2450                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )\r
2451                 {\r
2452                     best_val = RL + LR;\r
2453                     best_i = i; best_inversed = 1;\r
2454                 }\r
2455             }\r
2456         }\r
2457     }\r
2458 \r
2459     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,\r
2460         (values[best_i] + values[best_i+1])*0.5f, best_i,\r
2461         best_inversed, (float)best_val ) : 0;\r
2462 }\r
2463 \r
2464 \r
2465 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )\r
2466 {\r
2467     const char* dir = (char*)data->direction->data.ptr;\r
2468     int n = node->sample_count;\r
2469     int* labels_buf = data->pred_int_buf;\r
2470     const int* labels = 0;\r
2471     data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2472     // LL - number of samples that both the primary and the surrogate splits send to the left\r
2473     // LR - ... primary split sends to the left and the surrogate split sends to the right\r
2474     // RL - ... primary split sends to the right and the surrogate split sends to the left\r
2475     // RR - ... both send to the right\r
2476     CvDTreeSplit* split = data->new_split_cat( vi, 0 );\r
2477     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;\r
2478     double best_val = 0;\r
2479     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;\r
2480     double* rc = lc + mi + 1;\r
2481 \r
2482     for( i = -1; i < mi; i++ )\r
2483         lc[i] = rc[i] = 0;\r
2484 \r
2485     // for each category calculate the weight of samples\r
2486     // sent to the left (lc) and to the right (rc) by the primary split\r
2487     if( !data->have_priors )\r
2488     {\r
2489         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;\r
2490         int* _rc = _lc + mi + 1;\r
2491 \r
2492         for( i = -1; i < mi; i++ )\r
2493             _lc[i] = _rc[i] = 0;\r
2494 \r
2495         for( i = 0; i < n; i++ )\r
2496         {\r
2497             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];\r
2498             int d = dir[i];\r
2499             int sum = _lc[idx] + d;\r
2500             int sum_abs = _rc[idx] + (d & 1);\r
2501             _lc[idx] = sum; _rc[idx] = sum_abs;\r
2502         }\r
2503 \r
2504         for( i = 0; i < mi; i++ )\r
2505         {\r
2506             int sum = _lc[i];\r
2507             int sum_abs = _rc[i];\r
2508             lc[i] = (sum_abs - sum) >> 1;\r
2509             rc[i] = (sum_abs + sum) >> 1;\r
2510         }\r
2511     }\r
2512     else\r
2513     {\r
2514         const double* priors = data->priors_mult->data.db;\r
2515         int* responses_buf = data->resp_int_buf;\r
2516         const int* responses = 0;\r
2517         data->get_class_labels(node, responses_buf, &responses);\r
2518 \r
2519         for( i = 0; i < n; i++ )\r
2520         {\r
2521             int idx = labels[i];\r
2522             double w = priors[responses[i]];\r
2523             int d = dir[i];\r
2524             double sum = lc[idx] + d*w;\r
2525             double sum_abs = rc[idx] + (d & 1)*w;\r
2526             lc[idx] = sum; rc[idx] = sum_abs;\r
2527         }\r
2528 \r
2529         for( i = 0; i < mi; i++ )\r
2530         {\r
2531             double sum = lc[i];\r
2532             double sum_abs = rc[i];\r
2533             lc[i] = (sum_abs - sum) * 0.5;\r
2534             rc[i] = (sum_abs + sum) * 0.5;\r
2535         }\r
2536     }\r
2537 \r
2538     // 2. now form the split.\r
2539     // in each category send all the samples to the same direction as majority\r
2540     for( i = 0; i < mi; i++ )\r
2541     {\r
2542         double lval = lc[i], rval = rc[i];\r
2543         if( lval > rval )\r
2544         {\r
2545             split->subset[i >> 5] |= 1 << (i & 31);\r
2546             best_val += lval;\r
2547             l_win++;\r
2548         }\r
2549         else\r
2550             best_val += rval;\r
2551     }\r
2552 \r
2553     split->quality = (float)best_val;\r
2554     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )\r
2555         cvSetRemoveByPtr( data->split_heap, split ), split = 0;\r
2556 \r
2557     return split;\r
2558 }\r
2559 \r
2560 \r
2561 void CvDTree::calc_node_value( CvDTreeNode* node )\r
2562 {\r
2563     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;\r
2564     int* cv_labels_buf = data->cv_lables_buf;\r
2565     const int* cv_labels = 0;\r
2566     data->get_cv_labels(node, cv_labels_buf, &cv_labels);\r
2567 \r
2568     if( data->is_classifier )\r
2569     {\r
2570         // in case of classification tree:\r
2571         //  * node value is the label of the class that has the largest weight in the node.\r
2572         //  * node risk is the weighted number of misclassified samples,\r
2573         //  * j-th cross-validation fold value and risk are calculated as above,\r
2574         //    but using the samples with cv_labels(*)!=j.\r
2575         //  * j-th cross-validation fold error is calculated as the weighted number of\r
2576         //    misclassified samples with cv_labels(*)==j.\r
2577 \r
2578         // compute the number of instances of each class\r
2579         int* cls_count = data->counts->data.i;\r
2580         int* responses_buf = data->resp_int_buf;\r
2581         const int* responses = 0;\r
2582         data->get_class_labels(node, responses_buf, &responses);\r
2583         int m = data->get_num_classes();\r
2584         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));\r
2585         double max_val = -1, total_weight = 0;\r
2586         int max_k = -1;\r
2587         double* priors = data->priors_mult->data.db;\r
2588 \r
2589         for( k = 0; k < m; k++ )\r
2590             cls_count[k] = 0;\r
2591 \r
2592         if( cv_n == 0 )\r
2593         {\r
2594             for( i = 0; i < n; i++ )\r
2595                 cls_count[responses[i]]++;\r
2596         }\r
2597         else\r
2598         {\r
2599             for( j = 0; j < cv_n; j++ )\r
2600                 for( k = 0; k < m; k++ )\r
2601                     cv_cls_count[j*m + k] = 0;\r
2602 \r
2603             for( i = 0; i < n; i++ )\r
2604             {\r
2605                 j = cv_labels[i]; k = responses[i];\r
2606                 cv_cls_count[j*m + k]++;\r
2607             }\r
2608 \r
2609             for( j = 0; j < cv_n; j++ )\r
2610                 for( k = 0; k < m; k++ )\r
2611                     cls_count[k] += cv_cls_count[j*m + k];\r
2612         }\r
2613 \r
2614         if( data->have_priors && node->parent == 0 )\r
2615         {\r
2616             // compute priors_mult from priors, take the sample ratio into account.\r
2617             double sum = 0;\r
2618             for( k = 0; k < m; k++ )\r
2619             {\r
2620                 int n_k = cls_count[k];\r
2621                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);\r
2622                 sum += priors[k];\r
2623             }\r
2624             sum = 1./sum;\r
2625             for( k = 0; k < m; k++ )\r
2626                 priors[k] *= sum;\r
2627         }\r
2628 \r
2629         for( k = 0; k < m; k++ )\r
2630         {\r
2631             double val = cls_count[k]*priors[k];\r
2632             total_weight += val;\r
2633             if( max_val < val )\r
2634             {\r
2635                 max_val = val;\r
2636                 max_k = k;\r
2637             }\r
2638         }\r
2639 \r
2640         node->class_idx = max_k;\r
2641         node->value = data->cat_map->data.i[\r
2642             data->cat_ofs->data.i[data->cat_var_count] + max_k];\r
2643         node->node_risk = total_weight - max_val;\r
2644 \r
2645         for( j = 0; j < cv_n; j++ )\r
2646         {\r
2647             double sum_k = 0, sum = 0, max_val_k = 0;\r
2648             max_val = -1; max_k = -1;\r
2649 \r
2650             for( k = 0; k < m; k++ )\r
2651             {\r
2652                 double w = priors[k];\r
2653                 double val_k = cv_cls_count[j*m + k]*w;\r
2654                 double val = cls_count[k]*w - val_k;\r
2655                 sum_k += val_k;\r
2656                 sum += val;\r
2657                 if( max_val < val )\r
2658                 {\r
2659                     max_val = val;\r
2660                     max_val_k = val_k;\r
2661                     max_k = k;\r
2662                 }\r
2663             }\r
2664 \r
2665             node->cv_Tn[j] = INT_MAX;\r
2666             node->cv_node_risk[j] = sum - max_val;\r
2667             node->cv_node_error[j] = sum_k - max_val_k;\r
2668         }\r
2669     }\r
2670     else\r
2671     {\r
2672         // in case of regression tree:\r
2673         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,\r
2674         //    n is the number of samples in the node.\r
2675         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)\r
2676         //  * j-th cross-validation fold value and risk are calculated as above,\r
2677         //    but using the samples with cv_labels(*)!=j.\r
2678         //  * j-th cross-validation fold error is calculated\r
2679         //    using samples with cv_labels(*)==j as the test subset:\r
2680         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),\r
2681         //    where node_value_j is the node value calculated\r
2682         //    as described in the previous bullet, and summation is done\r
2683         //    over the samples with cv_labels(*)==j.\r
2684 \r
2685         double sum = 0, sum2 = 0;\r
2686         float* values_buf = data->resp_float_buf;\r
2687         const float* values = 0;\r
2688         data->get_ord_responses(node, values_buf, &values);\r
2689         double *cv_sum = 0, *cv_sum2 = 0;\r
2690         int* cv_count = 0;\r
2691 \r
2692         if( cv_n == 0 )\r
2693         {\r
2694             for( i = 0; i < n; i++ )\r
2695             {\r
2696                 double t = values[i];\r
2697                 sum += t;\r
2698                 sum2 += t*t;\r
2699             }\r
2700         }\r
2701         else\r
2702         {\r
2703             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );\r
2704             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );\r
2705             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );\r
2706 \r
2707             for( j = 0; j < cv_n; j++ )\r
2708             {\r
2709                 cv_sum[j] = cv_sum2[j] = 0.;\r
2710                 cv_count[j] = 0;\r
2711             }\r
2712 \r
2713             for( i = 0; i < n; i++ )\r
2714             {\r
2715                 j = cv_labels[i];\r
2716                 double t = values[i];\r
2717                 double s = cv_sum[j] + t;\r
2718                 double s2 = cv_sum2[j] + t*t;\r
2719                 int nc = cv_count[j] + 1;\r
2720                 cv_sum[j] = s;\r
2721                 cv_sum2[j] = s2;\r
2722                 cv_count[j] = nc;\r
2723             }\r
2724 \r
2725             for( j = 0; j < cv_n; j++ )\r
2726             {\r
2727                 sum += cv_sum[j];\r
2728                 sum2 += cv_sum2[j];\r
2729             }\r
2730         }\r
2731 \r
2732         node->node_risk = sum2 - (sum/n)*sum;\r
2733         node->value = sum/n;\r
2734 \r
2735         for( j = 0; j < cv_n; j++ )\r
2736         {\r
2737             double s = cv_sum[j], si = sum - s;\r
2738             double s2 = cv_sum2[j], s2i = sum2 - s2;\r
2739             int c = cv_count[j], ci = n - c;\r
2740             double r = si/MAX(ci,1);\r
2741             node->cv_node_risk[j] = s2i - r*r*ci;\r
2742             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;\r
2743             node->cv_Tn[j] = INT_MAX;\r
2744         }\r
2745     }\r
2746 }\r
2747 \r
2748 \r
2749 void CvDTree::complete_node_dir( CvDTreeNode* node )\r
2750 {\r
2751     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;\r
2752     int nz = n - node->get_num_valid(node->split->var_idx);\r
2753     char* dir = (char*)data->direction->data.ptr;\r
2754 \r
2755     // try to complete direction using surrogate splits\r
2756     if( nz && data->params.use_surrogates )\r
2757     {\r
2758         CvDTreeSplit* split = node->split->next;\r
2759         for( ; split != 0 && nz; split = split->next )\r
2760         {\r
2761             int inversed_mask = split->inversed ? -1 : 0;\r
2762             vi = split->var_idx;\r
2763 \r
2764             if( data->get_var_type(vi) >= 0 ) // split on categorical var\r
2765             {\r
2766                 int* labels_buf = data->pred_int_buf;\r
2767                 const int* labels = 0;\r
2768                 data->get_cat_var_data(node, vi, labels_buf, &labels);\r
2769                 const int* subset = split->subset;\r
2770 \r
2771                 for( i = 0; i < n; i++ )\r
2772                 {\r
2773                     int idx = labels[i];\r
2774                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))\r
2775                         \r
2776                     {\r
2777                         int d = CV_DTREE_CAT_DIR(idx,subset);\r
2778                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);\r
2779                         if( --nz )\r
2780                             break;\r
2781                     }\r
2782                 }\r
2783             }\r
2784             else // split on ordered var\r
2785             {\r
2786                 float* values_buf = data->pred_float_buf;\r
2787                 const float* values = 0;\r
2788                 int* indices_buf = data->pred_int_buf;\r
2789                 const int* indices = 0;\r
2790                 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );\r
2791                 int split_point = split->ord.split_point;\r
2792                 int n1 = node->get_num_valid(vi);\r
2793 \r
2794                 assert( 0 <= split_point && split_point < n-1 );\r
2795 \r
2796                 for( i = 0; i < n1; i++ )\r
2797                 {\r
2798                     int idx = indices[i];\r
2799                     if( !dir[idx] )\r
2800                     {\r
2801                         int d = i <= split_point ? -1 : 1;\r
2802                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);\r
2803                         if( --nz )\r
2804                             break;\r
2805                     }\r
2806                 }\r
2807             }\r
2808         }\r
2809     }\r
2810 \r
2811     // find the default direction for the rest\r
2812     if( nz )\r
2813     {\r
2814         for( i = nr = 0; i < n; i++ )\r
2815             nr += dir[i] > 0;\r
2816         nl = n - nr - nz;\r
2817         d0 = nl > nr ? -1 : nr > nl;\r
2818     }\r
2819 \r
2820     // make sure that every sample is directed either to the left or to the right\r
2821     for( i = 0; i < n; i++ )\r
2822     {\r
2823         int d = dir[i];\r
2824         if( !d )\r
2825         {\r
2826             d = d0;\r
2827             if( !d )\r
2828                 d = d1, d1 = -d1;\r
2829         }\r
2830         d = d > 0;\r
2831         dir[i] = (char)d; // remap (-1,1) to (0,1)\r
2832     }\r
2833 }\r
2834 \r
2835 \r
2836 void CvDTree::split_node_data( CvDTreeNode* node )\r
2837 {\r
2838     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;\r
2839     char* dir = (char*)data->direction->data.ptr;\r
2840     CvDTreeNode *left = 0, *right = 0;\r
2841     int* new_idx = data->split_buf->data.i;\r
2842     int new_buf_idx = data->get_child_buf_idx( node );\r
2843     int work_var_count = data->get_work_var_count();\r
2844     CvMat* buf = data->buf;\r
2845     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));\r
2846 \r
2847     complete_node_dir(node);\r
2848 \r
2849     for( i = nl = nr = 0; i < n; i++ )\r
2850     {\r
2851         int d = dir[i];\r
2852         // initialize new indices for splitting ordered variables\r
2853         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li\r
2854         nr += d;\r
2855         nl += d^1;\r
2856     }\r
2857 \r
2858 \r
2859     bool split_input_data;\r
2860     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );\r
2861     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );\r
2862 \r
2863     split_input_data = node->depth + 1 < data->params.max_depth &&\r
2864         (node->left->sample_count > data->params.min_sample_count ||\r
2865         node->right->sample_count > data->params.min_sample_count);\r
2866 \r
2867     // split ordered variables, keep both halves sorted.\r
2868     for( vi = 0; vi < data->var_count; vi++ )\r
2869     {\r
2870         int ci = data->get_var_type(vi);\r
2871         int n1 = node->get_num_valid(vi);\r
2872         int *src_idx_buf = data->pred_int_buf;\r
2873         const int* src_idx = 0;\r
2874         float *src_val_buf = data->pred_float_buf;\r
2875         const float* src_val = 0;\r
2876         \r
2877         if( ci >= 0 || !split_input_data )\r
2878             continue;\r
2879 \r
2880         data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);\r
2881 \r
2882         for(i = 0; i < n; i++)\r
2883             temp_buf[i] = src_idx[i];\r
2884 \r
2885         if (data->is_buf_16u)\r
2886         {\r
2887             unsigned short *ldst, *rdst, *ldst0, *rdst0;\r
2888             //unsigned short tl, tr;\r
2889             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
2890                 vi*scount + left->offset);\r
2891             rdst0 = rdst = (unsigned short*)(ldst + nl);\r
2892 \r
2893             // split sorted\r
2894             for( i = 0; i < n1; i++ )\r
2895             {\r
2896                 int idx = temp_buf[i];\r
2897                 int d = dir[idx];\r
2898                 idx = new_idx[idx];\r
2899                 if (d)\r
2900                 {\r
2901                     *rdst = (unsigned short)idx;\r
2902                     rdst++;\r
2903                 }\r
2904                 else\r
2905                 {\r
2906                     *ldst = (unsigned short)idx;\r
2907                     ldst++;\r
2908                 }\r
2909             }\r
2910 \r
2911             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2912             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2913 \r
2914             // split missing\r
2915             for( ; i < n; i++ )\r
2916             {\r
2917                 int idx = temp_buf[i];\r
2918                 int d = dir[idx];\r
2919                 idx = new_idx[idx];\r
2920                 if (d)\r
2921                 {\r
2922                     *rdst = (unsigned short)idx;\r
2923                     rdst++;\r
2924                 }\r
2925                 else\r
2926                 {\r
2927                     *ldst = (unsigned short)idx;\r
2928                     ldst++;\r
2929                 }\r
2930             }\r
2931         }\r
2932         else\r
2933         {\r
2934             int *ldst0, *ldst, *rdst0, *rdst;\r
2935             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + \r
2936                 vi*scount + left->offset;\r
2937             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + \r
2938                 vi*scount + right->offset;\r
2939 \r
2940             // split sorted\r
2941             for( i = 0; i < n1; i++ )\r
2942             {\r
2943                 int idx = temp_buf[i];\r
2944                 int d = dir[idx];\r
2945                 idx = new_idx[idx];\r
2946                 if (d)\r
2947                 {\r
2948                     *rdst = idx;\r
2949                     rdst++;\r
2950                 }\r
2951                 else\r
2952                 {\r
2953                     *ldst = idx;\r
2954                     ldst++;\r
2955                 }\r
2956             }\r
2957 \r
2958             left->set_num_valid(vi, (int)(ldst - ldst0));\r
2959             right->set_num_valid(vi, (int)(rdst - rdst0));\r
2960 \r
2961             // split missing\r
2962             for( ; i < n; i++ )\r
2963             {\r
2964                 int idx = temp_buf[i];\r
2965                 int d = dir[idx];\r
2966                 idx = new_idx[idx];\r
2967                 if (d)\r
2968                 {\r
2969                     *rdst = idx;\r
2970                     rdst++;\r
2971                 }\r
2972                 else\r
2973                 {\r
2974                     *ldst = idx;\r
2975                     ldst++;\r
2976                 }\r
2977             }\r
2978         }\r
2979     }\r
2980 \r
2981     // split categorical vars, responses and cv_labels using new_idx relocation table\r
2982     for( vi = 0; vi < work_var_count; vi++ )\r
2983     {\r
2984         int ci = data->get_var_type(vi);\r
2985         int n1 = node->get_num_valid(vi), nr1 = 0;\r
2986         \r
2987         if( ci < 0 || (vi < data->var_count && !split_input_data) )\r
2988             continue;\r
2989 \r
2990         int *src_lbls_buf = data->pred_int_buf;\r
2991         const int* src_lbls = 0;\r
2992         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);\r
2993 \r
2994         for(i = 0; i < n; i++)\r
2995             temp_buf[i] = src_lbls[i];\r
2996 \r
2997         if (data->is_buf_16u)\r
2998         {\r
2999             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + \r
3000                 vi*scount + left->offset);\r
3001             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + \r
3002                 vi*scount + right->offset);\r
3003             \r
3004             for( i = 0; i < n; i++ )\r
3005             {\r
3006                 int d = dir[i];\r
3007                 int idx = temp_buf[i];\r
3008                 if (d)\r
3009                 {\r
3010                     *rdst = (unsigned short)idx;\r
3011                     rdst++;\r
3012                     nr1 += (idx != 65535 )&d;\r
3013                 }\r
3014                 else\r
3015                 {\r
3016                     *ldst = (unsigned short)idx;\r
3017                     ldst++;\r
3018                 }\r
3019             }\r
3020 \r
3021             if( vi < data->var_count )\r
3022             {\r
3023                 left->set_num_valid(vi, n1 - nr1);\r
3024                 right->set_num_valid(vi, nr1);\r
3025             }\r
3026         }\r
3027         else\r
3028         {\r
3029             int *ldst = buf->data.i + left->buf_idx*buf->cols + \r
3030                 vi*scount + left->offset;\r
3031             int *rdst = buf->data.i + right->buf_idx*buf->cols + \r
3032                 vi*scount + right->offset;\r
3033             \r
3034             for( i = 0; i < n; i++ )\r
3035             {\r
3036                 int d = dir[i];\r
3037                 int idx = temp_buf[i];\r
3038                 if (d)\r
3039                 {\r
3040                     *rdst = idx;\r
3041                     rdst++;\r
3042                     nr1 += (idx >= 0)&d;\r
3043                 }\r
3044                 else\r
3045                 {\r
3046                     *ldst = idx;\r
3047                     ldst++;\r
3048                 }\r
3049                 \r
3050             }\r
3051 \r
3052             if( vi < data->var_count )\r
3053             {\r
3054                 left->set_num_valid(vi, n1 - nr1);\r
3055                 right->set_num_valid(vi, nr1);\r
3056             }\r
3057         }        \r
3058     }\r
3059 \r
3060 \r
3061     // split sample indices\r
3062     int *sample_idx_src_buf = data->sample_idx_buf;\r
3063     const int* sample_idx_src = 0;\r
3064     data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);\r
3065 \r
3066     for(i = 0; i < n; i++)\r
3067         temp_buf[i] = sample_idx_src[i];\r
3068 \r
3069     int pos = data->get_work_var_count();\r
3070     if (data->is_buf_16u)\r
3071     {\r
3072         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + \r
3073             pos*scount + left->offset);\r
3074         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + \r
3075             pos*scount + right->offset);\r
3076         for (i = 0; i < n; i++)\r
3077         {\r
3078             int d = dir[i];\r
3079             unsigned short idx = (unsigned short)temp_buf[i];\r
3080             if (d)\r
3081             {\r
3082                 *rdst = idx;\r
3083                 rdst++;\r
3084             }\r
3085             else\r
3086             {\r
3087                 *ldst = idx;\r
3088                 ldst++;\r
3089             }\r
3090         }\r
3091     }\r
3092     else\r
3093     {\r
3094         int* ldst = buf->data.i + left->buf_idx*buf->cols + \r
3095             pos*scount + left->offset;\r
3096         int* rdst = buf->data.i + right->buf_idx*buf->cols + \r
3097             pos*scount + right->offset;\r
3098         for (i = 0; i < n; i++)\r
3099         {\r
3100             int d = dir[i];\r
3101             int idx = temp_buf[i];\r
3102             if (d)\r
3103             {\r
3104                 *rdst = idx;\r
3105                 rdst++;\r
3106             }\r
3107             else\r
3108             {\r
3109                 *ldst = idx;\r
3110                 ldst++;\r
3111             }\r
3112         }\r
3113     }\r
3114     \r
3115     // deallocate the parent node data that is not needed anymore\r
3116     data->free_node_data(node);    \r
3117 }\r
3118 \r
3119 \r
3120 void CvDTree::prune_cv()\r
3121 {\r
3122     CvMat* ab = 0;\r
3123     CvMat* temp = 0;\r
3124     CvMat* err_jk = 0;\r
3125 \r
3126     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.\r
3127     // 2. choose the best tree index (if need, apply 1SE rule).\r
3128     // 3. store the best index and cut the branches.\r
3129 \r
3130     CV_FUNCNAME( "CvDTree::prune_cv" );\r
3131 \r
3132     __BEGIN__;\r
3133 \r
3134     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;\r
3135     // currently, 1SE for regression is not implemented\r
3136     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;\r
3137     double* err;\r
3138     double min_err = 0, min_err_se = 0;\r
3139     int min_idx = -1;\r
3140 \r
3141     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));\r
3142 \r
3143     // build the main tree sequence, calculate alpha's\r
3144     for(;;tree_count++)\r
3145     {\r
3146         double min_alpha = update_tree_rnc(tree_count, -1);\r
3147         if( cut_tree(tree_count, -1, min_alpha) )\r
3148             break;\r
3149 \r
3150         if( ab->cols <= tree_count )\r
3151         {\r
3152             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));\r
3153             for( ti = 0; ti < ab->cols; ti++ )\r
3154                 temp->data.db[ti] = ab->data.db[ti];\r
3155             cvReleaseMat( &ab );\r
3156             ab = temp;\r
3157             temp = 0;\r
3158         }\r
3159 \r
3160         ab->data.db[tree_count] = min_alpha;\r
3161     }\r
3162 \r
3163     ab->data.db[0] = 0.;\r
3164 \r
3165     if( tree_count > 0 )\r
3166     {\r
3167         for( ti = 1; ti < tree_count-1; ti++ )\r
3168             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);\r
3169         ab->data.db[tree_count-1] = DBL_MAX*0.5;\r
3170 \r
3171         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));\r
3172         err = err_jk->data.db;\r
3173 \r
3174         for( j = 0; j < cv_n; j++ )\r
3175         {\r
3176             int tj = 0, tk = 0;\r
3177             for( ; tk < tree_count; tj++ )\r
3178             {\r
3179                 double min_alpha = update_tree_rnc(tj, j);\r
3180                 if( cut_tree(tj, j, min_alpha) )\r
3181                     min_alpha = DBL_MAX;\r
3182 \r
3183                 for( ; tk < tree_count; tk++ )\r
3184                 {\r
3185                     if( ab->data.db[tk] > min_alpha )\r
3186                         break;\r
3187                     err[j*tree_count + tk] = root->tree_error;\r
3188                 }\r
3189             }\r
3190         }\r
3191 \r
3192         for( ti = 0; ti < tree_count; ti++ )\r
3193         {\r
3194             double sum_err = 0;\r
3195             for( j = 0; j < cv_n; j++ )\r
3196                 sum_err += err[j*tree_count + ti];\r
3197             if( ti == 0 || sum_err < min_err )\r
3198             {\r
3199                 min_err = sum_err;\r
3200                 min_idx = ti;\r
3201                 if( use_1se )\r
3202                     min_err_se = sqrt( sum_err*(n - sum_err) );\r
3203             }\r
3204             else if( sum_err < min_err + min_err_se )\r
3205                 min_idx = ti;\r
3206         }\r
3207     }\r
3208 \r
3209     pruned_tree_idx = min_idx;\r
3210     free_prune_data(data->params.truncate_pruned_tree != 0);\r
3211 \r
3212     __END__;\r
3213 \r
3214     cvReleaseMat( &err_jk );\r
3215     cvReleaseMat( &ab );\r
3216     cvReleaseMat( &temp );\r
3217 }\r
3218 \r
3219 \r
3220 double CvDTree::update_tree_rnc( int T, int fold )\r
3221 {\r
3222     CvDTreeNode* node = root;\r
3223     double min_alpha = DBL_MAX;\r
3224 \r
3225     for(;;)\r
3226     {\r
3227         CvDTreeNode* parent;\r
3228         for(;;)\r
3229         {\r
3230             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3231             if( t <= T || !node->left )\r
3232             {\r
3233                 node->complexity = 1;\r
3234                 node->tree_risk = node->node_risk;\r
3235                 node->tree_error = 0.;\r
3236                 if( fold >= 0 )\r
3237                 {\r
3238                     node->tree_risk = node->cv_node_risk[fold];\r
3239                     node->tree_error = node->cv_node_error[fold];\r
3240                 }\r
3241                 break;\r
3242             }\r
3243             node = node->left;\r
3244         }\r
3245 \r
3246         for( parent = node->parent; parent && parent->right == node;\r
3247             node = parent, parent = parent->parent )\r
3248         {\r
3249             parent->complexity += node->complexity;\r
3250             parent->tree_risk += node->tree_risk;\r
3251             parent->tree_error += node->tree_error;\r
3252 \r
3253             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)\r
3254                 - parent->tree_risk)/(parent->complexity - 1);\r
3255             min_alpha = MIN( min_alpha, parent->alpha );\r
3256         }\r
3257 \r
3258         if( !parent )\r
3259             break;\r
3260 \r
3261         parent->complexity = node->complexity;\r
3262         parent->tree_risk = node->tree_risk;\r
3263         parent->tree_error = node->tree_error;\r
3264         node = parent->right;\r
3265     }\r
3266 \r
3267     return min_alpha;\r
3268 }\r
3269 \r
3270 \r
3271 int CvDTree::cut_tree( int T, int fold, double min_alpha )\r
3272 {\r
3273     CvDTreeNode* node = root;\r
3274     if( !node->left )\r
3275         return 1;\r
3276 \r
3277     for(;;)\r
3278     {\r
3279         CvDTreeNode* parent;\r
3280         for(;;)\r
3281         {\r
3282             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;\r
3283             if( t <= T || !node->left )\r
3284                 break;\r
3285             if( node->alpha <= min_alpha + FLT_EPSILON )\r
3286             {\r
3287                 if( fold >= 0 )\r
3288                     node->cv_Tn[fold] = T;\r
3289                 else\r
3290                     node->Tn = T;\r
3291                 if( node == root )\r
3292                     return 1;\r
3293                 break;\r
3294             }\r
3295             node = node->left;\r
3296         }\r
3297 \r
3298         for( parent = node->parent; parent && parent->right == node;\r
3299             node = parent, parent = parent->parent )\r
3300             ;\r
3301 \r
3302         if( !parent )\r
3303             break;\r
3304 \r
3305         node = parent->right;\r
3306     }\r
3307 \r
3308     return 0;\r
3309 }\r
3310 \r
3311 \r
3312 void CvDTree::free_prune_data(bool cut_tree)\r
3313 {\r
3314     CvDTreeNode* node = root;\r
3315 \r
3316     for(;;)\r
3317     {\r
3318         CvDTreeNode* parent;\r
3319         for(;;)\r
3320         {\r
3321             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )\r
3322             // as we will clear the whole cross-validation heap at the end\r
3323             node->cv_Tn = 0;\r
3324             node->cv_node_error = node->cv_node_risk = 0;\r
3325             if( !node->left )\r
3326                 break;\r
3327             node = node->left;\r
3328         }\r
3329 \r
3330         for( parent = node->parent; parent && parent->right == node;\r
3331             node = parent, parent = parent->parent )\r
3332         {\r
3333             if( cut_tree && parent->Tn <= pruned_tree_idx )\r
3334             {\r
3335                 data->free_node( parent->left );\r
3336                 data->free_node( parent->right );\r
3337                 parent->left = parent->right = 0;\r
3338             }\r
3339         }\r
3340 \r
3341         if( !parent )\r
3342             break;\r
3343 \r
3344         node = parent->right;\r
3345     }\r
3346 \r
3347     if( data->cv_heap )\r
3348         cvClearSet( data->cv_heap );\r
3349 }\r
3350 \r
3351 \r
3352 void CvDTree::free_tree()\r
3353 {\r
3354     if( root && data && data->shared )\r
3355     {\r
3356         pruned_tree_idx = INT_MIN;\r
3357         free_prune_data(true);\r
3358         data->free_node(root);\r
3359         root = 0;\r
3360     }\r
3361 }\r
3362 \r
3363 \r
3364 #include <conio.h>\r
3365 #include <ctype.h>\r
3366 \r
3367 \r
3368 CvDTreeNode* CvDTree::predict( const CvMat* _sample,\r
3369     const CvMat* _missing, bool preprocessed_input ) const\r
3370 {\r
3371     CvDTreeNode* result = 0;\r
3372     int* catbuf = 0;\r
3373 \r
3374     CV_FUNCNAME( "CvDTree::predict" );\r
3375 \r
3376     __BEGIN__;\r
3377 \r
3378     int i, step, mstep = 0;\r
3379     const float* sample;\r
3380     const uchar* m = 0;\r
3381     CvDTreeNode* node = root;\r
3382     const int* vtype;\r
3383     const int* vidx;\r
3384     const int* cmap;\r
3385     const int* cofs;\r
3386 \r
3387     if( !node )\r
3388         CV_ERROR( CV_StsError, "The tree has not been trained yet" );\r
3389 \r
3390     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||\r
3391         (_sample->cols != 1 && _sample->rows != 1) ||\r
3392         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||\r
3393         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )\r
3394             CV_ERROR( CV_StsBadArg,\r
3395         "the input sample must be 1d floating-point vector with the same "\r
3396         "number of elements as the total number of variables used for training" );\r
3397 \r
3398     sample = _sample->data.fl;\r
3399     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);\r
3400 \r
3401     if( data->cat_count && !preprocessed_input ) // cache for categorical variables\r
3402     {\r
3403         int n = data->cat_count->cols;\r
3404         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));\r
3405         for( i = 0; i < n; i++ )\r
3406             catbuf[i] = -1;\r
3407     }\r
3408 \r
3409     if( _missing )\r
3410     {\r
3411         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||\r
3412         !CV_ARE_SIZES_EQ(_missing, _sample) )\r
3413             CV_ERROR( CV_StsBadArg,\r
3414         "the missing data mask must be 8-bit vector of the same size as input sample" );\r
3415         m = _missing->data.ptr;\r
3416         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);\r
3417     }\r
3418 \r
3419     vtype = data->var_type->data.i;\r
3420     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;\r
3421     cmap = data->cat_map ? data->cat_map->data.i : 0;\r
3422     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;\r
3423 \r
3424     while( node->Tn > pruned_tree_idx && node->left )\r
3425     {\r
3426         CvDTreeSplit* split = node->split;\r
3427         int dir = 0;\r
3428         for( ; !dir && split != 0; split = split->next )\r
3429         {\r
3430             int vi = split->var_idx;\r
3431             int ci = vtype[vi];\r
3432             i = vidx ? vidx[vi] : vi;\r
3433             float val = sample[i*step];\r
3434             if( m && m[i*mstep] )\r
3435                 continue;\r
3436             if( ci < 0 ) // ordered\r
3437                 dir = val <= split->ord.c ? -1 : 1;\r
3438             else // categorical\r
3439             {\r
3440                 int c;\r
3441                 if( preprocessed_input )\r
3442                     c = cvRound(val);\r
3443                 else\r
3444                 {\r
3445                     c = catbuf[ci];\r
3446                     if( c < 0 )\r
3447                     {\r
3448                         int a = c = cofs[ci];\r
3449                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];\r
3450                         \r
3451                         int ival = cvRound(val);\r
3452                         if( ival != val )\r
3453                             CV_ERROR( CV_StsBadArg,\r
3454                             "one of input categorical variable is not an integer" );\r
3455                         \r
3456                         int sh = 0;\r
3457                         while( a < b )\r
3458                         {\r
3459                             sh++;\r
3460                             c = (a + b) >> 1;\r
3461                             if( ival < cmap[c] )\r
3462                                 b = c;\r
3463                             else if( ival > cmap[c] )\r
3464                                 a = c+1;\r
3465                             else\r
3466                                 break;\r
3467                         }\r
3468 \r
3469                         if( c < 0 || ival != cmap[c] )\r
3470                             continue;\r
3471 \r
3472                         catbuf[ci] = c -= cofs[ci];\r
3473                     }\r
3474                 }\r
3475                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;\r
3476                 dir = CV_DTREE_CAT_DIR(c, split->subset);\r
3477             }\r
3478 \r
3479             if( split->inversed )\r
3480                 dir = -dir;\r
3481         }\r
3482 \r
3483         if( !dir )\r
3484         {\r
3485             double diff = node->right->sample_count - node->left->sample_count;\r
3486             dir = diff < 0 ? -1 : 1;\r
3487         }\r
3488         node = dir < 0 ? node->left : node->right;\r
3489     }\r
3490 \r
3491     result = node;\r
3492 \r
3493     __END__;\r
3494 \r
3495     return result;\r
3496 }\r
3497 \r
3498 \r
3499 const CvMat* CvDTree::get_var_importance()\r
3500 {\r
3501     if( !var_importance )\r
3502     {\r
3503         CvDTreeNode* node = root;\r
3504         double* importance;\r
3505         if( !node )\r
3506             return 0;\r
3507         var_importance = cvCreateMat( 1, data->var_count, CV_64F );\r
3508         cvZero( var_importance );\r
3509         importance = var_importance->data.db;\r
3510 \r
3511         for(;;)\r
3512         {\r
3513             CvDTreeNode* parent;\r
3514             for( ;; node = node->left )\r
3515             {\r
3516                 CvDTreeSplit* split = node->split;\r
3517 \r
3518                 if( !node->left || node->Tn <= pruned_tree_idx )\r
3519                     break;\r
3520 \r
3521                 for( ; split != 0; split = split->next )\r
3522                     importance[split->var_idx] += split->quality;\r
3523             }\r
3524 \r
3525             for( parent = node->parent; parent && parent->right == node;\r
3526                 node = parent, parent = parent->parent )\r
3527                 ;\r
3528 \r
3529             if( !parent )\r
3530                 break;\r
3531 \r
3532             node = parent->right;\r
3533         }\r
3534 \r
3535         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );\r
3536     }\r
3537 \r
3538     return var_importance;\r
3539 }\r
3540 \r
3541 \r
3542 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )\r
3543 {\r
3544     int ci;\r
3545 \r
3546     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );\r
3547     cvWriteInt( fs, "var", split->var_idx );\r
3548     cvWriteReal( fs, "quality", split->quality );\r
3549 \r
3550     ci = data->get_var_type(split->var_idx);\r
3551     if( ci >= 0 ) // split on a categorical var\r
3552     {\r
3553         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;\r
3554         for( i = 0; i < n; i++ )\r
3555             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;\r
3556 \r
3557         // ad-hoc rule when to use inverse categorical split notation\r
3558         // to achieve more compact and clear representation\r
3559         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;\r
3560 \r
3561         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?\r
3562                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );\r
3563 \r
3564         for( i = 0; i < n; i++ )\r
3565         {\r
3566             int dir = CV_DTREE_CAT_DIR(i,split->subset);\r
3567             if( dir*default_dir < 0 )\r
3568                 cvWriteInt( fs, 0, i );\r
3569         }\r
3570         cvEndWriteStruct( fs );\r
3571     }\r
3572     else\r
3573         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );\r
3574 \r
3575     cvEndWriteStruct( fs );\r
3576 }\r
3577 \r
3578 \r
3579 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )\r
3580 {\r
3581     CvDTreeSplit* split;\r
3582 \r
3583     cvStartWriteStruct( fs, 0, CV_NODE_MAP );\r
3584 \r
3585     cvWriteInt( fs, "depth", node->depth );\r
3586     cvWriteInt( fs, "sample_count", node->sample_count );\r
3587     cvWriteReal( fs, "value", node->value );\r
3588 \r
3589     if( data->is_classifier )\r
3590         cvWriteInt( fs, "norm_class_idx", node->class_idx );\r
3591 \r
3592     cvWriteInt( fs, "Tn", node->Tn );\r
3593     cvWriteInt( fs, "complexity", node->complexity );\r
3594     cvWriteReal( fs, "alpha", node->alpha );\r
3595     cvWriteReal( fs, "node_risk", node->node_risk );\r
3596     cvWriteReal( fs, "tree_risk", node->tree_risk );\r
3597     cvWriteReal( fs, "tree_error", node->tree_error );\r
3598 \r
3599     if( node->left )\r
3600     {\r
3601         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );\r
3602 \r
3603         for( split = node->split; split != 0; split = split->next )\r
3604             write_split( fs, split );\r
3605 \r
3606         cvEndWriteStruct( fs );\r
3607     }\r
3608 \r
3609     cvEndWriteStruct( fs );\r
3610 }\r
3611 \r
3612 \r
3613 void CvDTree::write_tree_nodes( CvFileStorage* fs )\r
3614 {\r
3615     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );\r
3616 \r
3617     __BEGIN__;\r
3618 \r
3619     CvDTreeNode* node = root;\r
3620 \r
3621     // traverse the tree and save all the nodes in depth-first order\r
3622     for(;;)\r
3623     {\r
3624         CvDTreeNode* parent;\r
3625         for(;;)\r
3626         {\r
3627             write_node( fs, node );\r
3628             if( !node->left )\r
3629                 break;\r
3630             node = node->left;\r
3631         }\r
3632 \r
3633         for( parent = node->parent; parent && parent->right == node;\r
3634             node = parent, parent = parent->parent )\r
3635             ;\r
3636 \r
3637         if( !parent )\r
3638             break;\r
3639 \r
3640         node = parent->right;\r
3641     }\r
3642 \r
3643     __END__;\r
3644 }\r
3645 \r
3646 \r
3647 void CvDTree::write( CvFileStorage* fs, const char* name )\r
3648 {\r
3649     //CV_FUNCNAME( "CvDTree::write" );\r
3650 \r
3651     __BEGIN__;\r
3652 \r
3653     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );\r
3654 \r
3655     get_var_importance();\r
3656     data->write_params( fs );\r
3657     if( var_importance )\r
3658         cvWrite( fs, "var_importance", var_importance );\r
3659     write( fs );\r
3660 \r
3661     cvEndWriteStruct( fs );\r
3662 \r
3663     __END__;\r
3664 }\r
3665 \r
3666 \r
3667 void CvDTree::write( CvFileStorage* fs )\r
3668 {\r
3669     //CV_FUNCNAME( "CvDTree::write" );\r
3670 \r
3671     __BEGIN__;\r
3672 \r
3673     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );\r
3674 \r
3675     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );\r
3676     write_tree_nodes( fs );\r
3677     cvEndWriteStruct( fs );\r
3678 \r
3679     __END__;\r
3680 }\r
3681 \r
3682 \r
3683 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )\r
3684 {\r
3685     CvDTreeSplit* split = 0;\r
3686 \r
3687     CV_FUNCNAME( "CvDTree::read_split" );\r
3688 \r
3689     __BEGIN__;\r
3690 \r
3691     int vi, ci;\r
3692 \r
3693     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3694         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );\r
3695 \r
3696     vi = cvReadIntByName( fs, fnode, "var", -1 );\r
3697     if( (unsigned)vi >= (unsigned)data->var_count )\r
3698         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );\r
3699 \r
3700     ci = data->get_var_type(vi);\r
3701     if( ci >= 0 ) // split on categorical var\r
3702     {\r
3703         int i, n = data->cat_count->data.i[ci], inversed = 0, val;\r
3704         CvSeqReader reader;\r
3705         CvFileNode* inseq;\r
3706         split = data->new_split_cat( vi, 0 );\r
3707         inseq = cvGetFileNodeByName( fs, fnode, "in" );\r
3708         if( !inseq )\r
3709         {\r
3710             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );\r
3711             inversed = 1;\r
3712         }\r
3713         if( !inseq ||\r
3714             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))\r
3715             CV_ERROR( CV_StsParseError,\r
3716             "Either 'in' or 'not_in' tags should be inside a categorical split data" );\r
3717 \r
3718         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )\r
3719         {\r
3720             val = inseq->data.i;\r
3721             if( (unsigned)val >= (unsigned)n )\r
3722                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3723 \r
3724             split->subset[val >> 5] |= 1 << (val & 31);\r
3725         }\r
3726         else\r
3727         {\r
3728             cvStartReadSeq( inseq->data.seq, &reader );\r
3729 \r
3730             for( i = 0; i < reader.seq->total; i++ )\r
3731             {\r
3732                 CvFileNode* inode = (CvFileNode*)reader.ptr;\r
3733                 val = inode->data.i;\r
3734                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )\r
3735                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );\r
3736 \r
3737                 split->subset[val >> 5] |= 1 << (val & 31);\r
3738                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3739             }\r
3740         }\r
3741 \r
3742         // for categorical splits we do not use inversed splits,\r
3743         // instead we inverse the variable set in the split\r
3744         if( inversed )\r
3745             for( i = 0; i < (n + 31) >> 5; i++ )\r
3746                 split->subset[i] ^= -1;\r
3747     }\r
3748     else\r
3749     {\r
3750         CvFileNode* cmp_node;\r
3751         split = data->new_split_ord( vi, 0, 0, 0, 0 );\r
3752 \r
3753         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );\r
3754         if( !cmp_node )\r
3755         {\r
3756             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );\r
3757             split->inversed = 1;\r
3758         }\r
3759 \r
3760         split->ord.c = (float)cvReadReal( cmp_node );\r
3761     }\r
3762 \r
3763     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );\r
3764 \r
3765     __END__;\r
3766 \r
3767     return split;\r
3768 }\r
3769 \r
3770 \r
3771 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )\r
3772 {\r
3773     CvDTreeNode* node = 0;\r
3774 \r
3775     CV_FUNCNAME( "CvDTree::read_node" );\r
3776 \r
3777     __BEGIN__;\r
3778 \r
3779     CvFileNode* splits;\r
3780     int i, depth;\r
3781 \r
3782     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )\r
3783         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );\r
3784 \r
3785     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));\r
3786     depth = cvReadIntByName( fs, fnode, "depth", -1 );\r
3787     if( depth != node->depth )\r
3788         CV_ERROR( CV_StsParseError, "incorrect node depth" );\r
3789 \r
3790     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );\r
3791     node->value = cvReadRealByName( fs, fnode, "value" );\r
3792     if( data->is_classifier )\r
3793         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );\r
3794 \r
3795     node->Tn = cvReadIntByName( fs, fnode, "Tn" );\r
3796     node->complexity = cvReadIntByName( fs, fnode, "complexity" );\r
3797     node->alpha = cvReadRealByName( fs, fnode, "alpha" );\r
3798     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );\r
3799     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );\r
3800     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );\r
3801 \r
3802     splits = cvGetFileNodeByName( fs, fnode, "splits" );\r
3803     if( splits )\r
3804     {\r
3805         CvSeqReader reader;\r
3806         CvDTreeSplit* last_split = 0;\r
3807 \r
3808         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )\r
3809             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );\r
3810 \r
3811         cvStartReadSeq( splits->data.seq, &reader );\r
3812         for( i = 0; i < reader.seq->total; i++ )\r
3813         {\r
3814             CvDTreeSplit* split;\r
3815             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));\r
3816             if( !last_split )\r
3817                 node->split = last_split = split;\r
3818             else\r
3819                 last_split = last_split->next = split;\r
3820 \r
3821             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3822         }\r
3823     }\r
3824 \r
3825     __END__;\r
3826 \r
3827     return node;\r
3828 }\r
3829 \r
3830 \r
3831 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )\r
3832 {\r
3833     CV_FUNCNAME( "CvDTree::read_tree_nodes" );\r
3834 \r
3835     __BEGIN__;\r
3836 \r
3837     CvSeqReader reader;\r
3838     CvDTreeNode _root;\r
3839     CvDTreeNode* parent = &_root;\r
3840     int i;\r
3841     parent->left = parent->right = parent->parent = 0;\r
3842 \r
3843     cvStartReadSeq( fnode->data.seq, &reader );\r
3844 \r
3845     for( i = 0; i < reader.seq->total; i++ )\r
3846     {\r
3847         CvDTreeNode* node;\r
3848 \r
3849         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));\r
3850         if( !parent->left )\r
3851             parent->left = node;\r
3852         else\r
3853             parent->right = node;\r
3854         if( node->split )\r
3855             parent = node;\r
3856         else\r
3857         {\r
3858             while( parent && parent->right )\r
3859                 parent = parent->parent;\r
3860         }\r
3861 \r
3862         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );\r
3863     }\r
3864 \r
3865     root = _root.left;\r
3866 \r
3867     __END__;\r
3868 }\r
3869 \r
3870 \r
3871 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )\r
3872 {\r
3873     CvDTreeTrainData* _data = new CvDTreeTrainData();\r
3874     _data->read_params( fs, fnode );\r
3875 \r
3876     read( fs, fnode, _data );\r
3877     get_var_importance();\r
3878 }\r
3879 \r
3880 \r
3881 // a special entry point for reading weak decision trees from the tree ensembles\r
3882 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )\r
3883 {\r
3884     CV_FUNCNAME( "CvDTree::read" );\r
3885 \r
3886     __BEGIN__;\r
3887 \r
3888     CvFileNode* tree_nodes;\r
3889 \r
3890     clear();\r
3891     data = _data;\r
3892 \r
3893     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );\r
3894     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )\r
3895         CV_ERROR( CV_StsParseError, "nodes tag is missing" );\r
3896 \r
3897     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );\r
3898     read_tree_nodes( fs, tree_nodes );\r
3899 \r
3900     __END__;\r
3901 }\r
3902 \r
3903 /* End of file. */\r