]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mlertrees.cpp
converted ml (dtree, rtrees, ertrees, boost) from OpenMP to TBB.
[opencv.git] / opencv / src / ml / mlertrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2
3   IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
5   By downloading, copying, installing or using the software you agree to this license.
6   If you do not agree to this license, do not download, install,
7   copy or use the software.
8
9
10                         Intel License Agreement
11
12  Copyright (C) 2000, Intel Corporation, all rights reserved.
13  Third party copyrights are property of their respective owners.
14
15  Redistribution and use in source and binary forms, with or without modification,
16  are permitted provided that the following conditions are met:
17
18    * Redistribution's of source code must retain the above copyright notice,
19      this list of conditions and the following disclaimer.
20
21    * Redistribution's in binary form must reproduce the above copyright notice,
22      this list of conditions and the following disclaimer in the documentation
23      and/or other materials provided with the distribution.
24
25    * The name of Intel Corporation may not be used to endorse or promote products
26      derived from this software without specific prior written permission.
27
28  This software is provided by the copyright holders and contributors "as is" and
29  any express or implied warranties, including, but not limited to, the implied
30  warranties of merchantability and fitness for a particular purpose are disclaimed.
31  In no event shall the Intel Corporation or contributors be liable for any direct,
32  indirect, incidental, special, exemplary, or consequential damages
33  (including, but not limited to, procurement of substitute goods or services;
34  loss of use, data, or profits; or business interruption) however caused
35  and on any theory of liability, whether in contract, strict liability,
36  or tort (including negligence or otherwise) arising in any way out of
37  the use of this software, even if advised of the possibility of such damage.
38
39 M*/
40
41 #include "_ml.h"
42
43 static const float ord_nan = FLT_MAX*0.5f;
44 static const int min_block_size = 1 << 16;
45 static const int block_size_delta = 1 << 10;
46
47 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
48 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
49
50 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
51 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
52
53 ///
54
55 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
56     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
57     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
58     bool _shared, bool _add_labels, bool _update_data )
59 {
60     CvMat* sample_indices = 0;
61     CvMat* var_type0 = 0;
62     CvMat* tmp_map = 0;
63     int** int_ptr = 0;
64     CvPair16u32s* pair16u32s_ptr = 0;
65     CvDTreeTrainData* data = 0;
66     float *_fdst = 0;
67     int *_idst = 0;
68     unsigned short* udst = 0;
69     int* idst = 0;
70
71     CV_FUNCNAME( "CvERTreeTrainData::set_data" );
72
73     __BEGIN__;
74     
75     int sample_all = 0, r_type = 0, cv_n;
76     int total_c_count = 0;
77     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
78     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
79     int vi, i, size;
80     char err[100];
81     const int *sidx = 0, *vidx = 0;
82     
83     if ( _params.use_surrogates )
84         CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
85         
86     if( _update_data && data_root )
87     {
88         CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
89     }
90
91     clear();
92
93     var_all = 0;
94     rng = cvRNG(-1);
95
96     CV_CALL( set_params( _params ));
97
98     // check parameter types and sizes
99     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
100
101     train_data = _train_data;
102     responses = _responses;
103     missing_mask = _missing_mask;
104
105     if( _tflag == CV_ROW_SAMPLE )
106     {
107         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
108         dv_step = 1;
109         if( _missing_mask )
110             ms_step = _missing_mask->step, mv_step = 1;
111     }
112     else
113     {
114         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
115         ds_step = 1;
116         if( _missing_mask )
117             mv_step = _missing_mask->step, ms_step = 1;
118     }
119     tflag = _tflag;
120
121     sample_count = sample_all;
122     var_count = var_all;
123
124     if( _sample_idx )
125     {
126         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
127         sidx = sample_indices->data.i;
128         sample_count = sample_indices->rows + sample_indices->cols - 1;
129     }
130
131     if( _var_idx )
132     {
133         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
134         vidx = var_idx->data.i;
135         var_count = var_idx->rows + var_idx->cols - 1;
136     }
137
138     if( !CV_IS_MAT(_responses) ||
139         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
140          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
141         (_responses->rows != 1 && _responses->cols != 1) ||
142         _responses->rows + _responses->cols - 1 != sample_all )
143         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
144                   "floating-point vector containing as many elements as "
145                   "the total number of samples in the training data matrix" );
146    
147     is_buf_16u = false;
148     if ( sample_count < 65536 )  
149         is_buf_16u = true;                                
150     
151   
152     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
153
154     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
155    
156     
157     cat_var_count = 0;
158     ord_var_count = -1;
159
160     is_classifier = r_type == CV_VAR_CATEGORICAL;
161
162     // step 0. calc the number of categorical vars
163     for( vi = 0; vi < var_count; vi++ )
164     {
165         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
166             cat_var_count++ : ord_var_count--;
167     }
168
169     ord_var_count = ~ord_var_count;
170     cv_n = params.cv_folds;
171     // set the two last elements of var_type array to be able
172     // to locate responses and cross-validation labels using
173     // the corresponding get_* functions.
174     var_type->data.i[var_count] = cat_var_count;
175     var_type->data.i[var_count+1] = cat_var_count+1;
176
177     // in case of single ordered predictor we need dummy cv_labels
178     // for safe split_node_data() operation
179     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
180
181     work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
182     buf_size = (work_var_count + 1)*sample_count;
183     shared = _shared;
184     buf_count = shared ? 2 : 1;
185     
186     if ( is_buf_16u )
187     {
188         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
189         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
190     }
191     else
192     {
193         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
194         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
195     }    
196
197     size = is_classifier ? cat_var_count+1 : cat_var_count;
198     size = !size ? 1 : size;
199     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
200     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
201     
202     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
203     size = !size ? 1 : size;
204     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
205
206     // now calculate the maximum size of split,
207     // create memory storage that will keep nodes and splits of the decision tree
208     // allocate root node and the buffer for the whole training data
209     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
210         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
211     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
212     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
213     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
214     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
215
216     nv_size = var_count*sizeof(int);
217     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
218
219     temp_block_size = nv_size;
220
221     if( cv_n )
222     {
223         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
224             CV_ERROR( CV_StsOutOfRange,
225                 "The many folds in cross-validation for such a small dataset" );
226
227         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
228         temp_block_size = MAX(temp_block_size, cv_size);
229     }
230
231     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
232     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
233     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
234     if( cv_size )
235         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
236
237     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
238
239     max_c_count = 1;
240
241     _fdst = 0;
242     _idst = 0;
243     if (ord_var_count)
244         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
245     if (is_buf_16u && (cat_var_count || is_classifier))
246         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
247
248     // transform the training data to convenient representation
249     for( vi = 0; vi <= var_count; vi++ )
250     {
251         int ci;
252         const uchar* mask = 0;
253         int m_step = 0, step;
254         const int* idata = 0;
255         const float* fdata = 0;
256         int num_valid = 0;
257
258         if( vi < var_count ) // analyze i-th input variable
259         {
260             int vi0 = vidx ? vidx[vi] : vi;
261             ci = get_var_type(vi);
262             step = ds_step; m_step = ms_step;
263             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
264                 idata = _train_data->data.i + vi0*dv_step;
265             else
266                 fdata = _train_data->data.fl + vi0*dv_step;
267             if( _missing_mask )
268                 mask = _missing_mask->data.ptr + vi0*mv_step;
269         }
270         else // analyze _responses
271         {
272             ci = cat_var_count;
273             step = CV_IS_MAT_CONT(_responses->type) ?
274                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
275             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
276                 idata = _responses->data.i;
277             else
278                 fdata = _responses->data.fl;
279         }
280
281         if( (vi < var_count && ci>=0) ||
282             (vi == var_count && is_classifier) ) // process categorical variable or response
283         {
284             int c_count, prev_label;
285             int* c_map;
286             
287             if (is_buf_16u)
288                 udst = (unsigned short*)(buf->data.s + ci*sample_count);
289             else
290                 idst = buf->data.i + ci*sample_count;
291             
292             // copy data
293             for( i = 0; i < sample_count; i++ )
294             {
295                 int val = INT_MAX, si = sidx ? sidx[i] : i;
296                 if( !mask || !mask[si*m_step] )
297                 {
298                     if( idata )
299                         val = idata[si*step];
300                     else
301                     {
302                         float t = fdata[si*step];
303                         val = cvRound(t);
304                         if( val != t )
305                         {
306                             sprintf( err, "%d-th value of %d-th (categorical) "
307                                 "variable is not an integer", i, vi );
308                             CV_ERROR( CV_StsBadArg, err );
309                         }
310                     }
311
312                     if( val == INT_MAX )
313                     {
314                         sprintf( err, "%d-th value of %d-th (categorical) "
315                             "variable is too large", i, vi );
316                         CV_ERROR( CV_StsBadArg, err );
317                     }
318                     num_valid++;
319                 }
320                 if (is_buf_16u)
321                 {
322                     _idst[i] = val;
323                     pair16u32s_ptr[i].u = udst + i;
324                     pair16u32s_ptr[i].i = _idst + i;
325                 }   
326                 else
327                 {
328                     idst[i] = val;
329                     int_ptr[i] = idst + i;
330                 }
331             }
332
333             c_count = num_valid > 0;
334
335             if (is_buf_16u)
336             {
337                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
338                 // count the categories
339                 for( i = 1; i < num_valid; i++ )
340                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
341                         c_count ++ ;
342             }
343             else
344             {
345                 icvSortIntPtr( int_ptr, sample_count, 0 );
346                 // count the categories
347                 for( i = 1; i < num_valid; i++ )
348                     c_count += *int_ptr[i] != *int_ptr[i-1];
349             }
350
351             if( vi > 0 )
352                 max_c_count = MAX( max_c_count, c_count );
353             cat_count->data.i[ci] = c_count;
354             cat_ofs->data.i[ci] = total_c_count;
355
356             // resize cat_map, if need
357             if( cat_map->cols < total_c_count + c_count )
358             {
359                 tmp_map = cat_map;
360                 CV_CALL( cat_map = cvCreateMat( 1,
361                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
362                 for( i = 0; i < total_c_count; i++ )
363                     cat_map->data.i[i] = tmp_map->data.i[i];
364                 cvReleaseMat( &tmp_map );
365             }
366
367             c_map = cat_map->data.i + total_c_count;
368             total_c_count += c_count;
369
370             c_count = -1;
371             if (is_buf_16u)
372             {
373                 // compact the class indices and build the map
374                 prev_label = ~*pair16u32s_ptr[0].i;
375                 for( i = 0; i < num_valid; i++ )
376                 {
377                     int cur_label = *pair16u32s_ptr[i].i;
378                     if( cur_label != prev_label )
379                         c_map[++c_count] = prev_label = cur_label;
380                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
381                 }
382                 // replace labels for missing values with 65535
383                 for( ; i < sample_count; i++ )
384                     *pair16u32s_ptr[i].u = 65535;
385             }
386             else
387             {
388                 // compact the class indices and build the map
389                 prev_label = ~*int_ptr[0];
390                 for( i = 0; i < num_valid; i++ )
391                 {
392                     int cur_label = *int_ptr[i];
393                     if( cur_label != prev_label )
394                         c_map[++c_count] = prev_label = cur_label;
395                     *int_ptr[i] = c_count;
396                 }
397                 // replace labels for missing values with -1
398                 for( ; i < sample_count; i++ )
399                     *int_ptr[i] = -1;
400             }           
401         }
402         else if( ci < 0 ) // process ordered variable
403         {
404             for( i = 0; i < sample_count; i++ )
405             {
406                 float val = ord_nan;
407                 int si = sidx ? sidx[i] : i;
408                 if( !mask || !mask[si*m_step] )
409                 {
410                     if( idata )
411                         val = (float)idata[si*step];
412                     else
413                         val = fdata[si*step];
414
415                     if( fabs(val) >= ord_nan )
416                     {
417                         sprintf( err, "%d-th value of %d-th (ordered) "
418                             "variable (=%g) is too large", i, vi, val );
419                         CV_ERROR( CV_StsBadArg, err );
420                     }
421                 }
422                 num_valid++;
423             }
424         }
425         if( vi < var_count )
426             data_root->set_num_valid(vi, num_valid);
427     }
428
429     // set sample labels
430     if (is_buf_16u)
431         udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
432     else
433         idst = buf->data.i + get_work_var_count()*sample_count;
434
435     for (i = 0; i < sample_count; i++)
436     {
437         if (udst)
438             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
439         else
440             idst[i] = sidx ? sidx[i] : i;
441     }
442
443     if( cv_n )
444     {
445         unsigned short* udst = 0;
446         int* idst = 0;
447         CvRNG* r = &rng;
448
449         if (is_buf_16u)
450         {
451             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
452             for( i = vi = 0; i < sample_count; i++ )
453             {
454                 udst[i] = (unsigned short)vi++;
455                 vi &= vi < cv_n ? -1 : 0;
456             }
457
458             for( i = 0; i < sample_count; i++ )
459             {
460                 int a = cvRandInt(r) % sample_count;
461                 int b = cvRandInt(r) % sample_count;
462                 unsigned short unsh = (unsigned short)vi;
463                 CV_SWAP( udst[a], udst[b], unsh );
464             }
465         }
466         else
467         {
468             idst = buf->data.i + (get_work_var_count()-1)*sample_count;
469             for( i = vi = 0; i < sample_count; i++ )
470             {
471                 idst[i] = vi++;
472                 vi &= vi < cv_n ? -1 : 0;
473             }
474
475             for( i = 0; i < sample_count; i++ )
476             {
477                 int a = cvRandInt(r) % sample_count;
478                 int b = cvRandInt(r) % sample_count;
479                 CV_SWAP( idst[a], idst[b], vi );
480             }
481         }
482     }
483
484     if ( cat_map ) 
485         cat_map->cols = MAX( total_c_count, 1 );
486
487     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
488         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
489     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
490
491     have_priors = is_classifier && params.priors;
492     if( is_classifier )
493     {
494         int m = get_num_classes();
495         double sum = 0;
496         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
497         for( i = 0; i < m; i++ )
498         {
499             double val = have_priors ? params.priors[i] : 1.;
500             if( val <= 0 )
501                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
502             priors->data.db[i] = val;
503             sum += val;
504         }
505
506         // normalize weights
507         if( have_priors )
508             cvScale( priors, priors, 1./sum );
509
510         CV_CALL( priors_mult = cvCloneMat( priors ));
511         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
512     }
513
514     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
515     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
516
517     __END__;
518
519     if( data )
520         delete data;
521
522     if (_fdst)
523         cvFree( &_fdst );
524     if (_idst)
525         cvFree( &_idst );
526     cvFree( &int_ptr );
527     cvReleaseMat( &var_type0 );
528     cvReleaseMat( &sample_indices );
529     cvReleaseMat( &tmp_map );
530 }
531
532 void CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
533                                           const float** ord_values, const int** missing, int* sample_indices_buf )
534 {
535     int vidx = var_idx ? var_idx->data.i[vi] : vi;
536     int node_sample_count = n->sample_count; 
537     // may use missing_buf as buffer for sample indices!
538     const int* sample_indices = get_sample_indices(n, sample_indices_buf ? sample_indices_buf : missing_buf);
539
540     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
541     int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
542     if( tflag == CV_ROW_SAMPLE )
543     {
544         for( int i = 0; i < node_sample_count; i++ )
545         {
546             int idx = sample_indices[i];
547             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
548             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
549         }
550     }
551     else
552         for( int i = 0; i < node_sample_count; i++ )
553         {
554             int idx = sample_indices[i];
555             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
556             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
557         }
558     *ord_values = ord_values_buf;
559     *missing = missing_buf;
560 }
561
562
563 const int* CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
564 {
565     return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf );
566 }
567
568
569 const int* CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
570 {
571     if (have_labels)
572         return get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf );
573     return 0;
574 }
575
576
577 const int* CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf )
578 {
579     int ci = get_var_type( vi);
580     const int* cat_values = 0;
581     if( !is_buf_16u )
582         cat_values = buf->data.i + n->buf_idx*buf->cols + ci*sample_count + n->offset;
583     else {
584         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
585             ci*sample_count + n->offset);
586         for( int i = 0; i < n->sample_count; i++ )
587             cat_values_buf[i] = short_values[i];
588         cat_values = cat_values_buf;
589     }
590     return cat_values;
591 }
592
593 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
594                                     float* values, uchar* missing,
595                                     float* responses, bool get_class_idx )
596 {
597     CvMat* subsample_idx = 0;
598     CvMat* subsample_co = 0;
599
600     cv::AutoBuffer<uchar> inn_buf(sample_count*(sizeof(float) + sizeof(int)));
601
602     CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
603
604     __BEGIN__;
605
606     int i, vi, total = sample_count, count = total, cur_ofs = 0;
607     int* sidx = 0;
608     int* co = 0;
609
610     if( _subsample_idx )
611     {
612         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
613         sidx = subsample_idx->data.i;
614         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
615         co = subsample_co->data.i;
616         cvZero( subsample_co );
617         count = subsample_idx->cols + subsample_idx->rows - 1;
618         for( i = 0; i < count; i++ )
619             co[sidx[i]*2]++;
620         for( i = 0; i < total; i++ )
621         {
622             int count_i = co[i*2];
623             if( count_i )
624             {
625                 co[i*2+1] = cur_ofs*var_count;
626                 cur_ofs += count_i;
627             }
628         }
629     }
630
631     if( missing )
632         memset( missing, 1, count*var_count );
633
634     for( vi = 0; vi < var_count; vi++ )
635     {
636         int ci = get_var_type(vi);
637         if( ci >= 0 ) // categorical
638         {
639             float* dst = values + vi;
640             uchar* m = missing ? missing + vi : 0;
641             int* lbls_buf = (int*)(uchar*)inn_buf;
642             const int* src = get_cat_var_data(data_root, vi, lbls_buf);
643
644             for( i = 0; i < count; i++, dst += var_count )
645             {
646                 int idx = sidx ? sidx[i] : i;
647                 int val = src[idx];
648                 *dst = (float)val;
649                 if( m )
650                 {
651                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
652                     m += var_count;
653                 }
654             }
655         }
656         else // ordered
657         {
658             int* mis_buf = (int*)(uchar*)inn_buf;
659             const float *dst = 0;
660             const int* mis = 0;
661             get_ord_var_data(data_root, vi, values + vi, mis_buf, &dst, &mis, 0);
662             for (int si = 0; si < total; si++)
663                 *(missing + vi + si) = mis[si] == 0 ? 0 : 1;
664         }
665     }
666
667     // copy responses
668     if( responses )
669     {
670         if( is_classifier )
671         {
672             int* lbls_buf = (int*)(uchar*)inn_buf;
673             const int* src = get_class_labels(data_root, lbls_buf);
674             for( i = 0; i < count; i++ )
675             {
676                 int idx = sidx ? sidx[i] : i;
677                 int val = get_class_idx ? src[idx] :
678                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
679                 responses[i] = (float)val;
680             }
681         }
682         else           
683         {
684             float* _values_buf = (float*)(uchar*)inn_buf;
685             int* sample_idx_buf = (int*)(_values_buf + sample_count);
686             const float* _values = get_ord_responses(data_root, _values_buf, sample_idx_buf);
687             for( i = 0; i < count; i++ )
688             {
689                 int idx = sidx ? sidx[i] : i;
690                 responses[i] = _values[idx];
691             }
692         }
693     }
694
695     __END__;
696
697     cvReleaseMat( &subsample_idx );
698     cvReleaseMat( &subsample_co );
699 }
700
701 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
702 {
703     CvDTreeNode* root = 0;
704     
705     CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
706
707     __BEGIN__;
708
709     if( !data_root )
710         CV_ERROR( CV_StsError, "No training data has been set" );
711
712     if( !_subsample_idx )
713     {
714         // make a copy of the root node
715         CvDTreeNode temp;
716         int i;
717         root = new_node( 0, 1, 0, 0 );
718         temp = *root;
719         *root = *data_root;
720         root->num_valid = temp.num_valid;
721         if( root->num_valid )
722         {
723             for( i = 0; i < var_count; i++ )
724                 root->num_valid[i] = data_root->num_valid[i];
725         }
726         root->cv_Tn = temp.cv_Tn;
727         root->cv_node_risk = temp.cv_node_risk;
728         root->cv_node_error = temp.cv_node_error;
729     }
730     else
731         CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
732     __END__;
733
734     return root;
735 }
736
737 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
738 {
739     char* dir = (char*)data->direction->data.ptr;
740     int i, n = node->sample_count, vi = node->split->var_idx;
741     double L, R;
742
743     assert( !node->split->inversed );
744
745     if( data->get_var_type(vi) >= 0 ) // split on categorical var
746     {
747         cv::AutoBuffer<uchar> inn_buf(n*sizeof(int)*(!data->have_priors ? 1 : 2));
748         int* labels_buf = (int*)(uchar*)inn_buf;
749         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
750         const int* subset = node->split->subset;
751         if( !data->have_priors )
752         {
753             int sum = 0, sum_abs = 0;
754
755             for( i = 0; i < n; i++ )
756             {
757                 int idx = labels[i];
758                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
759                     CV_DTREE_CAT_DIR(idx,subset) : 0;
760                 sum += d; sum_abs += d & 1;
761                 dir[i] = (char)d;
762             }
763
764             R = (sum_abs + sum) >> 1;
765             L = (sum_abs - sum) >> 1;
766         }
767         else
768         {
769             const double* priors = data->priors_mult->data.db;
770             double sum = 0, sum_abs = 0;
771             int *responses_buf = labels_buf + n;
772             const int* responses = data->get_class_labels(node, responses_buf);
773
774             for( i = 0; i < n; i++ )
775             {
776                 int idx = labels[i];
777                 double w = priors[responses[i]];
778                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
779                 sum += d*w; sum_abs += (d & 1)*w;
780                 dir[i] = (char)d;
781             }
782
783             R = (sum_abs + sum) * 0.5;
784             L = (sum_abs - sum) * 0.5;
785         }
786     }
787     else // split on ordered var
788     {
789         float split_val = node->split->ord.c;
790         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(!data->have_priors ? 1 : 2) + sizeof(float)));
791         float* val_buf = (float*)(uchar*)inn_buf;
792         int* missing_buf = (int*)(val_buf + n);
793         const float* val = 0;
794         const int* missing = 0;
795         data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing, 0 );
796
797         if( !data->have_priors )
798         {
799             L = R = 0;
800             for( i = 0; i < n; i++ )
801             {
802                 if ( missing[i] )
803                     dir[i] = (char)0;
804                 else
805                 {
806                     if ( val[i] < split_val)
807                     {
808                         dir[i] = (char)-1;
809                         L++;
810                     }
811                     else
812                     {
813                         dir[i] = (char)1;
814                         R++;
815                     }
816                 }
817             }
818         }
819         else
820         {
821             const double* priors = data->priors_mult->data.db;
822             int* responses_buf = missing_buf + n;
823             const int* responses = data->get_class_labels(node, responses_buf);
824             L = R = 0;
825             for( i = 0; i < n; i++ )
826             {
827                 if ( missing[i] )
828                     dir[i] = (char)0;
829                 else
830                 {
831                     double w = priors[responses[i]];
832                     if ( val[i] < split_val)
833                     {
834                         dir[i] = (char)-1;
835                          L += w;
836                     }
837                     else
838                     {
839                         dir[i] = (char)1;
840                         R += w;
841                     }
842                 }
843             }
844         }
845     }
846
847     node->maxlr = MAX( L, R );
848     return node->split->quality/(L + R);
849 }
850
851 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
852                                                     uchar* _ext_buf )
853 {
854     const float epsilon = FLT_EPSILON*2;
855     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
856
857     int n = node->sample_count, i;
858     int m = data->get_num_classes();
859
860     cv::AutoBuffer<uchar> inn_buf;
861     if( !_ext_buf )
862         inn_buf.allocate(n*(2*sizeof(int) + sizeof(float)));
863     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
864     float* values_buf = (float*)ext_buf;
865     int* missing_buf = (int*)(values_buf + n);
866     const float* values = 0;
867     const int* missing = 0;
868     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
869     int* responses_buf = missing_buf + n;
870     const int* responses = data->get_class_labels( node, responses_buf );
871
872     double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
873     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
874     bool is_find_split = false;
875     float pmin, pmax;
876     int smpi = 0;
877     while ( missing[smpi] && (smpi < n) )
878         smpi++;
879     assert(smpi < n);
880
881     pmin = values[smpi];
882     pmax = pmin;
883     for (; smpi < n; smpi++)
884     {
885         float ptemp = values[smpi];
886         int m = missing[smpi];
887         if (m) continue;
888         if ( ptemp < pmin)
889             pmin = ptemp;
890         if ( ptemp > pmax)
891             pmax = ptemp;
892     }
893     float fdiff = pmax-pmin;
894     if (fdiff > epsilon)
895     {
896         is_find_split = true;
897         CvRNG* rng = &data->rng;
898         split_val = pmin + cvRandReal(rng) * fdiff ;
899         if (split_val - pmin <= FLT_EPSILON)
900             split_val = pmin + split_delta;
901         if (pmax - split_val <= FLT_EPSILON)
902             split_val = pmax - split_delta;       
903
904         // calculate Gini index
905         if ( !priors )
906         {
907             int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
908             int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
909             int L = 0, R = 0;
910     
911             // init arrays of class instance counters on both sides of the split
912             for( i = 0; i < m; i++ )
913             {
914                 lc[i] = 0;
915                 rc[i] = 0;
916             }
917             for( int si = 0; si < n; si++ )
918             {
919                 int r = responses[si];
920                 float val = values[si];
921                 int m = missing[si];
922                 if (m) continue;
923                 if ( val < split_val )
924                 {
925                     lc[r]++;
926                     L++;
927                 }
928                 else
929                 {
930                     rc[r]++;
931                     R++;
932                 }
933             }
934             for (int i = 0; i < m; i++)
935             {
936                 lbest_val += lc[i]*lc[i];
937                 rbest_val += rc[i]*rc[i];
938             }
939             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
940         }
941         else
942         {
943             double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
944             double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
945             double L = 0, R = 0;
946     
947             // init arrays of class instance counters on both sides of the split
948             for( i = 0; i < m; i++ )
949             {
950                 lc[i] = 0;
951                 rc[i] = 0;
952             }
953             for( int si = 0; si < n; si++ )
954             {
955                 int r = responses[si];
956                 float val = values[si];
957                 int m = missing[si];
958                 double p = priors[si];
959                 if (m) continue;
960                 if ( val < split_val )
961                 {
962                     lc[r] += p;
963                     L += p;
964                 }
965                 else
966                 {
967                     rc[r] += p;
968                     R += p;
969                 }
970             }
971             for (int i = 0; i < m; i++)
972             {
973                 lbest_val += lc[i]*lc[i];
974                 rbest_val += rc[i]*rc[i];
975             }
976             best_val = (lbest_val*R + rbest_val*L) / (L*R);
977         }
978         
979     }
980
981     CvDTreeSplit* split = 0;
982     if( is_find_split )
983     {
984         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
985         split->var_idx = vi;
986         split->ord.c = (float)split_val;
987         split->ord.split_point = -1;
988         split->inversed = 0;
989         split->quality = (float)best_val;
990     }
991     return split;
992 }
993
994 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
995                                                     uchar* _ext_buf )
996 {
997     int ci = data->get_var_type(vi);
998     int n = node->sample_count;
999     int cm = data->get_num_classes(); 
1000     int vm = data->cat_count->data.i[ci];
1001     double best_val = init_quality;
1002     CvDTreeSplit *split = 0;
1003
1004     if ( vm > 1 )
1005     {
1006         cv::AutoBuffer<int> inn_buf;
1007         if( !_ext_buf )
1008             inn_buf.allocate(2*n);
1009         int* ext_buf = _ext_buf ? (int*)_ext_buf : (int*)inn_buf;
1010
1011         const int* labels = data->get_cat_var_data( node, vi, ext_buf );
1012         const int* responses = data->get_class_labels( node, ext_buf + n );
1013     
1014         const double* priors = data->have_priors ? data->priors_mult->data.db : 0;       
1015
1016         // create random class mask
1017         int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1018         for (int i = 0; i < vm; i++)
1019         {
1020             valid_cidx[i] = -1;
1021         }
1022         for (int si = 0; si < n; si++)
1023         {
1024             int c = labels[si];
1025             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1026                 continue;
1027             valid_cidx[c]++;
1028         }
1029
1030         int valid_ccount = 0;
1031         for (int i = 0; i < vm; i++)
1032             if (valid_cidx[i] >= 0)
1033             {
1034                 valid_cidx[i] = valid_ccount;
1035                 valid_ccount++;
1036             }
1037         if (valid_ccount > 1)
1038         {
1039             CvRNG* rng = forest->get_rng();
1040             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1041
1042             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1043             CvMat submask;
1044             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1045             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1046             cvSet( &submask, cvScalar(1) );
1047             for (int i = 0; i < valid_ccount; i++)
1048             {
1049                 uchar temp;
1050                 int i1 = cvRandInt( rng ) % valid_ccount;
1051                 int i2 = cvRandInt( rng ) % valid_ccount;
1052                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1053             }
1054
1055             split = _split ? _split : data->new_split_cat( 0, -1.0f );
1056             split->var_idx = vi;
1057             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1058
1059             // calculate Gini index
1060             double lbest_val = 0, rbest_val = 0;
1061             if( !priors )
1062             {
1063                 int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
1064                 int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
1065                 int L = 0, R = 0;
1066                 // init arrays of class instance counters on both sides of the split
1067                 for(int i = 0; i < cm; i++ )
1068                 {
1069                     lc[i] = 0;
1070                     rc[i] = 0;
1071                 }
1072                 for( int si = 0; si < n; si++ )
1073                 {
1074                     int r = responses[si];
1075                     int var_class_idx = labels[si];
1076                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1077                         continue;
1078                     int mask_class_idx = valid_cidx[var_class_idx];
1079                     if (var_class_mask->data.ptr[mask_class_idx])
1080                     {
1081                         lc[r]++;
1082                         L++;                 
1083                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1084                     }
1085                     else
1086                     {
1087                         rc[r]++;
1088                         R++;
1089                     }
1090                 }
1091                 for (int i = 0; i < cm; i++)
1092                 {
1093                     lbest_val += lc[i]*lc[i];
1094                     rbest_val += rc[i]*rc[i];
1095                 }                
1096                 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
1097             }
1098             else
1099             {
1100                 double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
1101                 double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
1102                 double L = 0, R = 0;
1103                 // init arrays of class instance counters on both sides of the split
1104                 for(int i = 0; i < cm; i++ )
1105                 {
1106                     lc[i] = 0;
1107                     rc[i] = 0;
1108                 }
1109                 for( int si = 0; si < n; si++ )
1110                 {
1111                     int r = responses[si];
1112                     int var_class_idx = labels[si];
1113                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1114                         continue;
1115                     double p = priors[si];
1116                     int mask_class_idx = valid_cidx[var_class_idx];
1117                     
1118                     if (var_class_mask->data.ptr[mask_class_idx])
1119                     {
1120                         lc[r]+=p;
1121                         L+=p;                 
1122                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1123                     }
1124                     else
1125                     {
1126                         rc[r]+=p;
1127                         R+=p;
1128                     }
1129                 }
1130                 for (int i = 0; i < cm; i++)
1131                 {
1132                     lbest_val += lc[i]*lc[i];
1133                     rbest_val += rc[i]*rc[i];
1134                 }
1135                 best_val = (lbest_val*R + rbest_val*L) / (L*R);
1136             }
1137             split->quality = (float)best_val;
1138
1139             cvReleaseMat(&var_class_mask);
1140         }   
1141     }  
1142
1143     return split;
1144 }
1145
1146 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1147                                                   uchar* _ext_buf )
1148 {
1149     const float epsilon = FLT_EPSILON*2;
1150     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1151     int n = node->sample_count;
1152     cv::AutoBuffer<uchar> inn_buf;
1153     if( !_ext_buf )
1154         inn_buf.allocate(n*(2*sizeof(int) + 2*sizeof(float)));
1155     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
1156     float* values_buf = (float*)ext_buf;
1157     int* missing_buf = (int*)(values_buf + n);
1158     const float* values = 0;
1159     const int* missing = 0;
1160     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1161     float* responses_buf =  (float*)(missing_buf + n);
1162     int* sample_indices_buf =  (int*)(responses_buf + n);
1163     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1164
1165     double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1166     int L = 0, R = 0;
1167
1168     bool is_find_split = false;
1169     float pmin, pmax;
1170     int smpi = 0;
1171     while ( missing[smpi] && (smpi < n) )
1172         smpi++;
1173
1174     assert(smpi < n);
1175
1176     pmin = values[smpi];
1177     pmax = pmin;
1178     for (; smpi < n; smpi++)
1179     {
1180         float ptemp = values[smpi];
1181         int m = missing[smpi];
1182         if (m) continue;
1183         if ( ptemp < pmin)
1184             pmin = ptemp;
1185         if ( ptemp > pmax)
1186             pmax = ptemp;
1187     }
1188     float fdiff = pmax-pmin;
1189     if (fdiff > epsilon)
1190     {
1191         is_find_split = true;
1192         CvRNG* rng = &data->rng;
1193         split_val = pmin + cvRandReal(rng) * fdiff ;
1194         if (split_val - pmin <= FLT_EPSILON)
1195             split_val = pmin + split_delta;
1196         if (pmax - split_val <= FLT_EPSILON)
1197             split_val = pmax - split_delta;    
1198
1199         for (int si = 0; si < n; si++)
1200         {
1201             float r = responses[si];
1202             float val = values[si];
1203             int m = missing[si];
1204             if (m) continue;
1205             if (val < split_val)
1206             {
1207                 lsum += r;
1208                 L++;
1209             }
1210             else
1211             {
1212                 rsum += r;
1213                 R++;            
1214             }
1215         }
1216         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1217     }
1218
1219     CvDTreeSplit* split = 0;
1220     if( is_find_split )
1221     {
1222         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1223         split->var_idx = vi;
1224         split->ord.c = (float)split_val;
1225         split->ord.split_point = -1;
1226         split->inversed = 0;
1227         split->quality = (float)best_val;
1228     }
1229     return split;
1230 }
1231
1232 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split,
1233                                                   uchar* _ext_buf )
1234 {
1235     int ci = data->get_var_type(vi);
1236     int n = node->sample_count;
1237     int vm = data->cat_count->data.i[ci];
1238     double best_val = init_quality;
1239     CvDTreeSplit *split = 0;
1240     float lsum = 0, rsum = 0;
1241
1242     if ( vm > 1 )
1243     {
1244         int base_size =  vm*sizeof(int);
1245         cv::AutoBuffer<uchar> inn_buf(base_size);
1246         if( !_ext_buf )
1247             inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
1248         uchar* base_buf = (uchar*)inn_buf;
1249         uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1250         int* labels_buf = (int*)ext_buf;
1251         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1252         float* responses_buf =  (float*)(labels_buf + n);
1253         int* sample_indices_buf = (int*)(responses_buf + n);
1254         const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
1255
1256         // create random class mask
1257         int *valid_cidx = (int*)base_buf;
1258         for (int i = 0; i < vm; i++)
1259         {
1260             valid_cidx[i] = -1;
1261         }
1262         for (int si = 0; si < n; si++)
1263         {
1264             int c = labels[si];
1265             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1266                         continue;
1267             valid_cidx[c]++;
1268         }
1269
1270         int valid_ccount = 0;
1271         for (int i = 0; i < vm; i++)
1272             if (valid_cidx[i] >= 0)
1273             {
1274                 valid_cidx[i] = valid_ccount;
1275                 valid_ccount++;
1276             }
1277         if (valid_ccount > 1)
1278         {
1279             CvRNG* rng = forest->get_rng();
1280             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1281
1282             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1283             CvMat submask;
1284             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1285             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1286             cvSet( &submask, cvScalar(1) );
1287             for (int i = 0; i < valid_ccount; i++)
1288             {
1289                 uchar temp;
1290                 int i1 = cvRandInt( rng ) % valid_ccount;
1291                 int i2 = cvRandInt( rng ) % valid_ccount;
1292                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1293             }
1294
1295             split = _split ? _split : data->new_split_cat( 0, -1.0f);
1296             split->var_idx = vi;
1297             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1298
1299             int L = 0, R = 0;
1300             for( int si = 0; si < n; si++ )
1301             {
1302                 float r = responses[si];
1303                 int var_class_idx = labels[si];
1304                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1305                         continue;
1306                 int mask_class_idx = valid_cidx[var_class_idx];
1307                 if (var_class_mask->data.ptr[mask_class_idx])
1308                 {
1309                     lsum += r;
1310                     L++;                 
1311                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1312                 }
1313                 else
1314                 {
1315                     rsum += r;
1316                     R++;
1317                 }
1318             }
1319             best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1320
1321             split->quality = (float)best_val;
1322
1323             cvReleaseMat(&var_class_mask);
1324         }   
1325     }  
1326
1327     return split;
1328 }
1329
1330 void CvForestERTree::split_node_data( CvDTreeNode* node )
1331 {
1332     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1333     char* dir = (char*)data->direction->data.ptr;
1334     CvDTreeNode *left = 0, *right = 0;
1335     int new_buf_idx = data->get_child_buf_idx( node );
1336     CvMat* buf = data->buf;
1337     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));
1338
1339     complete_node_dir(node);
1340
1341     for( i = nl = nr = 0; i < n; i++ )
1342     {
1343         int d = dir[i];
1344         nr += d;
1345         nl += d^1;
1346     }
1347
1348     bool split_input_data;
1349     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1350     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
1351
1352     split_input_data = node->depth + 1 < data->params.max_depth &&
1353         (node->left->sample_count > data->params.min_sample_count ||
1354         node->right->sample_count > data->params.min_sample_count);
1355
1356     cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)+sizeof(float)));
1357     // split ordered vars
1358     for( vi = 0; vi < data->var_count; vi++ )
1359     {
1360         int ci = data->get_var_type(vi);
1361         if (ci >= 0) continue;
1362         
1363         int n1 = node->get_num_valid(vi), nr1 = 0;
1364         float* values_buf = (float*)(uchar*)inn_buf;
1365         int* missing_buf = (int*)(values_buf + n);
1366         const float* values = 0;
1367         const int* missing = 0;
1368         data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing, 0 );
1369
1370         for( i = 0; i < n; i++ )
1371             nr1 += (!missing[i] & dir[i]);
1372         left->set_num_valid(vi, n1 - nr1);
1373         right->set_num_valid(vi, nr1);                
1374     }
1375     // split categorical vars, responses and cv_labels using new_idx relocation table
1376     for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1377     {
1378         int ci = data->get_var_type(vi);
1379         if (ci < 0) continue;
1380
1381         int n1 = node->get_num_valid(vi), nr1 = 0;
1382         const int* src_lbls = data->get_cat_var_data(node, vi, (int*)(uchar*)inn_buf);
1383
1384         for(i = 0; i < n; i++)
1385             temp_buf[i] = src_lbls[i];
1386
1387         if (data->is_buf_16u)
1388         {
1389             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + 
1390                 ci*scount + left->offset);
1391             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + 
1392                 ci*scount + right->offset);
1393             
1394             for( i = 0; i < n; i++ )
1395             {
1396                 int d = dir[i];
1397                 int idx = temp_buf[i];
1398                 if (d)
1399                 {
1400                     *rdst = (unsigned short)idx;
1401                     rdst++;
1402                     nr1 += (idx != 65535);
1403                 }
1404                 else
1405                 {
1406                     *ldst = (unsigned short)idx;
1407                     ldst++;
1408                 }
1409             }
1410
1411             if( vi < data->var_count )
1412             {
1413                 left->set_num_valid(vi, n1 - nr1);
1414                 right->set_num_valid(vi, nr1);
1415             }
1416         }
1417         else
1418         {
1419             int *ldst = buf->data.i + left->buf_idx*buf->cols + 
1420                 ci*scount + left->offset;
1421             int *rdst = buf->data.i + right->buf_idx*buf->cols + 
1422                 ci*scount + right->offset;
1423             
1424             for( i = 0; i < n; i++ )
1425             {
1426                 int d = dir[i];
1427                 int idx = temp_buf[i];
1428                 if (d)
1429                 {
1430                     *rdst = idx;
1431                     rdst++;
1432                     nr1 += (idx >= 0);
1433                 }
1434                 else
1435                 {
1436                     *ldst = idx;
1437                     ldst++;
1438                 }
1439                 
1440             }
1441
1442             if( vi < data->var_count )
1443             {
1444                 left->set_num_valid(vi, n1 - nr1);
1445                 right->set_num_valid(vi, nr1);
1446             }
1447         }        
1448     }
1449
1450     // split sample indices
1451     int *sample_idx_src_buf = (int*)(uchar*)inn_buf;
1452     const int* sample_idx_src = 0;
1453     if (split_input_data)
1454     {
1455         sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
1456
1457         for(i = 0; i < n; i++)
1458             temp_buf[i] = sample_idx_src[i];
1459
1460         int pos = data->get_work_var_count();
1461        
1462         if (data->is_buf_16u)
1463         {
1464             unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
1465                 pos*scount + left->offset);
1466             unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + 
1467                 pos*scount + right->offset);
1468             
1469             for (i = 0; i < n; i++)
1470             {
1471                 int d = dir[i];
1472                 unsigned short idx = (unsigned short)temp_buf[i];
1473                 if (d)
1474                 {
1475                     *rdst = idx;
1476                     rdst++;
1477                 }
1478                 else
1479                 {
1480                     *ldst = idx;
1481                     ldst++;
1482                 }
1483             }
1484         }
1485         else
1486         {
1487             int* ldst = buf->data.i + left->buf_idx*buf->cols + 
1488                 pos*scount + left->offset;
1489             int* rdst = buf->data.i + right->buf_idx*buf->cols + 
1490                 pos*scount + right->offset;
1491             for (i = 0; i < n; i++)
1492             {
1493                 int d = dir[i];
1494                 int idx = temp_buf[i];
1495                 if (d)
1496                 {
1497                     *rdst = idx;
1498                     rdst++;
1499                 }
1500                 else
1501                 {
1502                     *ldst = idx;
1503                     ldst++;
1504                 }
1505             }
1506         }
1507     }
1508     
1509     // deallocate the parent node data that is not needed anymore
1510     data->free_node_data(node);    
1511 }
1512
1513 CvERTrees::CvERTrees()
1514 {
1515 }
1516
1517 CvERTrees::~CvERTrees()
1518 {
1519 }
1520
1521 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1522                         const CvMat* _responses, const CvMat* _var_idx,
1523                         const CvMat* _sample_idx, const CvMat* _var_type,
1524                         const CvMat* _missing_mask, CvRTParams params )
1525 {
1526     bool result = false;
1527
1528     CV_FUNCNAME("CvERTrees::train");
1529     __BEGIN__
1530     int var_count = 0;
1531
1532     clear();
1533
1534     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1535         params.regression_accuracy, params.use_surrogates, params.max_categories,
1536         params.cv_folds, params.use_1se_rule, false, params.priors );
1537
1538     data = new CvERTreeTrainData();
1539     CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1540         _sample_idx, _var_type, _missing_mask, tree_params, true));
1541
1542     var_count = data->var_count;
1543     if( params.nactive_vars > var_count )
1544         params.nactive_vars = var_count;
1545     else if( params.nactive_vars == 0 )
1546         params.nactive_vars = (int)sqrt((double)var_count);
1547     else if( params.nactive_vars < 0 )
1548         CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1549
1550     // Create mask of active variables at the tree nodes
1551     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1552     if( params.calc_var_importance )
1553     {
1554         CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
1555         cvZero(var_importance);
1556     }
1557     { // initialize active variables mask
1558         CvMat submask1, submask2;
1559         CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
1560         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1561         cvSet( &submask1, cvScalar(1) );
1562         if( params.nactive_vars < active_var_mask->cols )
1563         {
1564             cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1565             cvZero( &submask2 );
1566         }
1567     }
1568
1569     CV_CALL(result = grow_forest( params.term_crit ));
1570
1571     result = true;
1572
1573     __END__
1574     return result;
1575     
1576 }
1577
1578 bool CvERTrees::train( CvMLData* data, CvRTParams params)
1579 {
1580    bool result = false;
1581
1582     CV_FUNCNAME( "CvERTrees::train" );
1583
1584     __BEGIN__;
1585
1586     CV_CALL( result = CvRTrees::train( data, params) );
1587
1588     __END__;
1589
1590     return result;
1591 }
1592
1593 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1594 {
1595     bool result = false;
1596
1597     CvMat* sample_idx_for_tree      = 0;
1598
1599     CV_FUNCNAME("CvERTrees::grow_forest");
1600     __BEGIN__;
1601
1602     const int max_ntrees = term_crit.max_iter;
1603     const double max_oob_err = term_crit.epsilon;
1604
1605     const int dims = data->var_count;
1606     float maximal_response = 0;
1607
1608     CvMat* oob_sample_votes        = 0;
1609     CvMat* oob_responses       = 0;
1610
1611     float* oob_samples_perm_ptr= 0;
1612
1613     float* samples_ptr     = 0;
1614     uchar* missing_ptr     = 0;
1615     float* true_resp_ptr   = 0;
1616     bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1617
1618     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1619     // oob_num_of_predictions[i] = number of summands
1620     //                            (number of predictions for the i-th sample)
1621     // initialize these variable to avoid warning C4701
1622     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1623     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1624      
1625     nsamples = data->sample_count;
1626     nclasses = data->get_num_classes();
1627
1628     if ( is_oob_or_vimportance )
1629     {
1630         if( data->is_classifier )
1631         {
1632             CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1633             cvZero(oob_sample_votes);
1634         }
1635         else
1636         {
1637             // oob_responses[0,i] = oob_predictions_sum[i]
1638             //    = sum of predicted values for the i-th sample
1639             // oob_responses[1,i] = oob_num_of_predictions[i]
1640             //    = number of summands (number of predictions for the i-th sample)
1641             CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1642             cvZero(oob_responses);
1643             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1644             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
1645         }
1646         
1647         CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1648         CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1649         CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1650         CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));            
1651
1652         CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
1653         {
1654             double minval, maxval;
1655             CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1656             cvMinMaxLoc( &responses, &minval, &maxval );
1657             maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
1658         }
1659     }
1660    
1661     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1662     memset( trees, 0, sizeof(trees[0])*max_ntrees );
1663
1664     CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
1665
1666     for (int i = 0; i < nsamples; i++)
1667         sample_idx_for_tree->data.i[i] = i;
1668     ntrees = 0;
1669     while( ntrees < max_ntrees )
1670     {
1671         int i, oob_samples_count = 0;
1672         double ncorrect_responses = 0; // used for estimation of variable importance
1673         CvForestTree* tree = 0;
1674
1675         trees[ntrees] = new CvForestERTree();
1676         tree = (CvForestERTree*)trees[ntrees];
1677         CV_CALL(tree->train( data, 0, this ));
1678
1679         if ( is_oob_or_vimportance )
1680         {
1681             CvMat sample, missing;
1682             // form array of OOB samples indices and get these samples
1683             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
1684             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
1685
1686             oob_error = 0;
1687             for( i = 0; i < nsamples; i++,
1688                 sample.data.fl += dims, missing.data.ptr += dims )
1689             {
1690                 CvDTreeNode* predicted_node = 0;
1691                 
1692                 // predict oob samples
1693                 if( !predicted_node )
1694                     CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1695
1696                 if( !data->is_classifier ) //regression
1697                 {
1698                     double avg_resp, resp = predicted_node->value;
1699                     oob_predictions_sum.data.fl[i] += (float)resp;
1700                     oob_num_of_predictions.data.fl[i] += 1;
1701
1702                     // compute oob error
1703                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1704                     avg_resp -= true_resp_ptr[i];
1705                     oob_error += avg_resp*avg_resp;
1706                     resp = (resp - true_resp_ptr[i])/maximal_response;
1707                     ncorrect_responses += exp( -resp*resp );
1708                 }
1709                 else //classification
1710                 {
1711                     double prdct_resp;
1712                     CvPoint max_loc;
1713                     CvMat votes;
1714
1715                     cvGetRow(oob_sample_votes, &votes, i);
1716                     votes.data.i[predicted_node->class_idx]++;
1717
1718                     // compute oob error
1719                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
1720
1721                     prdct_resp = data->cat_map->data.i[max_loc.x];
1722                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1723
1724                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1725                 }
1726                 oob_samples_count++;
1727             }
1728             if( oob_samples_count > 0 )
1729                 oob_error /= (double)oob_samples_count;
1730
1731             // estimate variable importance
1732             if( var_importance && oob_samples_count > 0 )
1733             {
1734                 int m;
1735
1736                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1737                 for( m = 0; m < dims; m++ )
1738                 {
1739                     double ncorrect_responses_permuted = 0;
1740                     // randomly permute values of the m-th variable in the oob samples
1741                     float* mth_var_ptr = oob_samples_perm_ptr + m;
1742
1743                     for( i = 0; i < nsamples; i++ )
1744                     {
1745                         int i1, i2;
1746                         float temp;
1747
1748                         i1 = cvRandInt( &rng ) % nsamples;
1749                         i2 = cvRandInt( &rng ) % nsamples;
1750                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1751
1752                         // turn values of (m-1)-th variable, that were permuted
1753                         // at the previous iteration, untouched
1754                         if( m > 1 )
1755                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1756                     }
1757
1758                     // predict "permuted" cases and calculate the number of votes for the
1759                     // correct class in the variable-m-permuted oob data
1760                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1761                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1762                     for( i = 0; i < nsamples; i++,
1763                         sample.data.fl += dims, missing.data.ptr += dims )
1764                     {
1765                         double predct_resp, true_resp;
1766
1767                         predct_resp = tree->predict(&sample, &missing, true)->value;
1768                         true_resp   = true_resp_ptr[i];
1769                         if( data->is_classifier )
1770                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
1771                         else
1772                         {
1773                             true_resp = (true_resp - predct_resp)/maximal_response;
1774                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
1775                         }
1776                     }
1777                     var_importance->data.fl[m] += (float)(ncorrect_responses
1778                         - ncorrect_responses_permuted);
1779                 }
1780             }
1781         }
1782         ntrees++;
1783         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
1784             break;
1785     }
1786     if( var_importance )
1787     {
1788         for ( int vi = 0; vi < var_importance->cols; vi++ )
1789                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1790                     var_importance->data.fl[vi] : 0;
1791         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
1792     }
1793
1794     result = true;
1795     
1796     cvFree( &oob_samples_perm_ptr );
1797     cvFree( &samples_ptr );
1798     cvFree( &missing_ptr );
1799     cvFree( &true_resp_ptr );
1800     
1801     cvReleaseMat( &sample_idx_for_tree );
1802
1803     cvReleaseMat( &oob_sample_votes );
1804     cvReleaseMat( &oob_responses );
1805
1806     __END__;
1807
1808     return result;
1809 }
1810
1811 using namespace cv;
1812
1813 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1814                       const Mat& _responses, const Mat& _var_idx,
1815                       const Mat& _sample_idx, const Mat& _var_type,
1816                       const Mat& _missing_mask, CvRTParams params )
1817 {
1818     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1819     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1820     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1821                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1822                  mmask.data.ptr ? &mmask : 0, params);
1823 }
1824
1825 // End of file.
1826