]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/src/ml/mlertrees.cpp
fixed build fail on Mac (ticket #46)
[opencv.git] / opencv / src / ml / mlertrees.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2
3   IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4
5   By downloading, copying, installing or using the software you agree to this license.
6   If you do not agree to this license, do not download, install,
7   copy or use the software.
8
9
10                         Intel License Agreement
11
12  Copyright (C) 2000, Intel Corporation, all rights reserved.
13  Third party copyrights are property of their respective owners.
14
15  Redistribution and use in source and binary forms, with or without modification,
16  are permitted provided that the following conditions are met:
17
18    * Redistribution's of source code must retain the above copyright notice,
19      this list of conditions and the following disclaimer.
20
21    * Redistribution's in binary form must reproduce the above copyright notice,
22      this list of conditions and the following disclaimer in the documentation
23      and/or other materials provided with the distribution.
24
25    * The name of Intel Corporation may not be used to endorse or promote products
26      derived from this software without specific prior written permission.
27
28  This software is provided by the copyright holders and contributors "as is" and
29  any express or implied warranties, including, but not limited to, the implied
30  warranties of merchantability and fitness for a particular purpose are disclaimed.
31  In no event shall the Intel Corporation or contributors be liable for any direct,
32  indirect, incidental, special, exemplary, or consequential damages
33  (including, but not limited to, procurement of substitute goods or services;
34  loss of use, data, or profits; or business interruption) however caused
35  and on any theory of liability, whether in contract, strict liability,
36  or tort (including negligence or otherwise) arising in any way out of
37  the use of this software, even if advised of the possibility of such damage.
38
39 M*/
40
41 #include "_ml.h"
42
43 #ifdef _OPENMP
44 #include "omp.h"
45 #endif
46
47 static const float ord_nan = FLT_MAX*0.5f;
48 static const int min_block_size = 1 << 16;
49 static const int block_size_delta = 1 << 10;
50
51 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
52 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
53
54 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
55 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
56
57 ///
58
59 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
60     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
61     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
62     bool _shared, bool _add_labels, bool _update_data )
63 {
64     CvMat* sample_indices = 0;
65     CvMat* var_type0 = 0;
66     CvMat* tmp_map = 0;
67     int** int_ptr = 0;
68     CvPair16u32s* pair16u32s_ptr = 0;
69     CvDTreeTrainData* data = 0;
70     float *_fdst = 0;
71     int *_idst = 0;
72     unsigned short* udst = 0;
73     int* idst = 0;
74
75     CV_FUNCNAME( "CvERTreeTrainData::set_data" );
76
77     __BEGIN__;
78     
79     int sample_all = 0, r_type = 0, cv_n;
80     int total_c_count = 0;
81     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
82     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
83     int vi, i, size;
84     char err[100];
85     const int *sidx = 0, *vidx = 0;
86     
87     if ( _params.use_surrogates )
88         CV_ERROR(CV_StsBadArg, "CvERTrees do not support surrogate splits");
89         
90     if( _update_data && data_root )
91     {
92         CV_ERROR(CV_StsBadArg, "CvERTrees do not support data update");
93     }
94
95     clear();
96
97     var_all = 0;
98     rng = cvRNG(-1);
99
100     CV_CALL( set_params( _params ));
101
102     // check parameter types and sizes
103     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
104
105     train_data = _train_data;
106     responses = _responses;
107     missing_mask = _missing_mask;
108
109     if( _tflag == CV_ROW_SAMPLE )
110     {
111         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
112         dv_step = 1;
113         if( _missing_mask )
114             ms_step = _missing_mask->step, mv_step = 1;
115     }
116     else
117     {
118         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
119         ds_step = 1;
120         if( _missing_mask )
121             mv_step = _missing_mask->step, ms_step = 1;
122     }
123     tflag = _tflag;
124
125     sample_count = sample_all;
126     var_count = var_all;
127
128     if( _sample_idx )
129     {
130         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
131         sidx = sample_indices->data.i;
132         sample_count = sample_indices->rows + sample_indices->cols - 1;
133     }
134
135     if( _var_idx )
136     {
137         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
138         vidx = var_idx->data.i;
139         var_count = var_idx->rows + var_idx->cols - 1;
140     }
141
142     if( !CV_IS_MAT(_responses) ||
143         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
144          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
145         (_responses->rows != 1 && _responses->cols != 1) ||
146         _responses->rows + _responses->cols - 1 != sample_all )
147         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
148                   "floating-point vector containing as many elements as "
149                   "the total number of samples in the training data matrix" );
150    
151     is_buf_16u = false;
152     if ( sample_count < 65536 )  
153         is_buf_16u = true;                                
154     
155   
156     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
157
158     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
159    
160     
161     cat_var_count = 0;
162     ord_var_count = -1;
163
164     is_classifier = r_type == CV_VAR_CATEGORICAL;
165
166     // step 0. calc the number of categorical vars
167     for( vi = 0; vi < var_count; vi++ )
168     {
169         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
170             cat_var_count++ : ord_var_count--;
171     }
172
173     ord_var_count = ~ord_var_count;
174     cv_n = params.cv_folds;
175     // set the two last elements of var_type array to be able
176     // to locate responses and cross-validation labels using
177     // the corresponding get_* functions.
178     var_type->data.i[var_count] = cat_var_count;
179     var_type->data.i[var_count+1] = cat_var_count+1;
180
181     // in case of single ordered predictor we need dummy cv_labels
182     // for safe split_node_data() operation
183     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
184
185     work_var_count = cat_var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
186     buf_size = (work_var_count + 1)*sample_count;
187     shared = _shared;
188     buf_count = shared ? 2 : 1;
189     
190     if ( is_buf_16u )
191     {
192         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
193         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
194     }
195     else
196     {
197         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
198         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
199     }    
200
201     size = is_classifier ? cat_var_count+1 : cat_var_count;
202     size = !size ? 1 : size;
203     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
204     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
205     
206     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
207     size = !size ? 1 : size;
208     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
209
210     // now calculate the maximum size of split,
211     // create memory storage that will keep nodes and splits of the decision tree
212     // allocate root node and the buffer for the whole training data
213     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
214         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
215     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
216     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
217     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
218     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
219
220     nv_size = var_count*sizeof(int);
221     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
222
223     temp_block_size = nv_size;
224
225     if( cv_n )
226     {
227         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
228             CV_ERROR( CV_StsOutOfRange,
229                 "The many folds in cross-validation for such a small dataset" );
230
231         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
232         temp_block_size = MAX(temp_block_size, cv_size);
233     }
234
235     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
236     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
237     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
238     if( cv_size )
239         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
240
241     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
242
243     max_c_count = 1;
244
245     _fdst = 0;
246     _idst = 0;
247     if (ord_var_count)
248         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
249     if (is_buf_16u && (cat_var_count || is_classifier))
250         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
251
252     // transform the training data to convenient representation
253     for( vi = 0; vi <= var_count; vi++ )
254     {
255         int ci;
256         const uchar* mask = 0;
257         int m_step = 0, step;
258         const int* idata = 0;
259         const float* fdata = 0;
260         int num_valid = 0;
261
262         if( vi < var_count ) // analyze i-th input variable
263         {
264             int vi0 = vidx ? vidx[vi] : vi;
265             ci = get_var_type(vi);
266             step = ds_step; m_step = ms_step;
267             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
268                 idata = _train_data->data.i + vi0*dv_step;
269             else
270                 fdata = _train_data->data.fl + vi0*dv_step;
271             if( _missing_mask )
272                 mask = _missing_mask->data.ptr + vi0*mv_step;
273         }
274         else // analyze _responses
275         {
276             ci = cat_var_count;
277             step = CV_IS_MAT_CONT(_responses->type) ?
278                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
279             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
280                 idata = _responses->data.i;
281             else
282                 fdata = _responses->data.fl;
283         }
284
285         if( (vi < var_count && ci>=0) ||
286             (vi == var_count && is_classifier) ) // process categorical variable or response
287         {
288             int c_count, prev_label;
289             int* c_map;
290             
291             if (is_buf_16u)
292                 udst = (unsigned short*)(buf->data.s + ci*sample_count);
293             else
294                 idst = buf->data.i + ci*sample_count;
295             
296             // copy data
297             for( i = 0; i < sample_count; i++ )
298             {
299                 int val = INT_MAX, si = sidx ? sidx[i] : i;
300                 if( !mask || !mask[si*m_step] )
301                 {
302                     if( idata )
303                         val = idata[si*step];
304                     else
305                     {
306                         float t = fdata[si*step];
307                         val = cvRound(t);
308                         if( val != t )
309                         {
310                             sprintf( err, "%d-th value of %d-th (categorical) "
311                                 "variable is not an integer", i, vi );
312                             CV_ERROR( CV_StsBadArg, err );
313                         }
314                     }
315
316                     if( val == INT_MAX )
317                     {
318                         sprintf( err, "%d-th value of %d-th (categorical) "
319                             "variable is too large", i, vi );
320                         CV_ERROR( CV_StsBadArg, err );
321                     }
322                     num_valid++;
323                 }
324                 if (is_buf_16u)
325                 {
326                     _idst[i] = val;
327                     pair16u32s_ptr[i].u = udst + i;
328                     pair16u32s_ptr[i].i = _idst + i;
329                 }   
330                 else
331                 {
332                     idst[i] = val;
333                     int_ptr[i] = idst + i;
334                 }
335             }
336
337             c_count = num_valid > 0;
338
339             if (is_buf_16u)
340             {
341                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
342                 // count the categories
343                 for( i = 1; i < num_valid; i++ )
344                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
345                         c_count ++ ;
346             }
347             else
348             {
349                 icvSortIntPtr( int_ptr, sample_count, 0 );
350                 // count the categories
351                 for( i = 1; i < num_valid; i++ )
352                     c_count += *int_ptr[i] != *int_ptr[i-1];
353             }
354
355             if( vi > 0 )
356                 max_c_count = MAX( max_c_count, c_count );
357             cat_count->data.i[ci] = c_count;
358             cat_ofs->data.i[ci] = total_c_count;
359
360             // resize cat_map, if need
361             if( cat_map->cols < total_c_count + c_count )
362             {
363                 tmp_map = cat_map;
364                 CV_CALL( cat_map = cvCreateMat( 1,
365                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
366                 for( i = 0; i < total_c_count; i++ )
367                     cat_map->data.i[i] = tmp_map->data.i[i];
368                 cvReleaseMat( &tmp_map );
369             }
370
371             c_map = cat_map->data.i + total_c_count;
372             total_c_count += c_count;
373
374             c_count = -1;
375             if (is_buf_16u)
376             {
377                 // compact the class indices and build the map
378                 prev_label = ~*pair16u32s_ptr[0].i;
379                 for( i = 0; i < num_valid; i++ )
380                 {
381                     int cur_label = *pair16u32s_ptr[i].i;
382                     if( cur_label != prev_label )
383                         c_map[++c_count] = prev_label = cur_label;
384                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
385                 }
386                 // replace labels for missing values with 65535
387                 for( ; i < sample_count; i++ )
388                     *pair16u32s_ptr[i].u = 65535;
389             }
390             else
391             {
392                 // compact the class indices and build the map
393                 prev_label = ~*int_ptr[0];
394                 for( i = 0; i < num_valid; i++ )
395                 {
396                     int cur_label = *int_ptr[i];
397                     if( cur_label != prev_label )
398                         c_map[++c_count] = prev_label = cur_label;
399                     *int_ptr[i] = c_count;
400                 }
401                 // replace labels for missing values with -1
402                 for( ; i < sample_count; i++ )
403                     *int_ptr[i] = -1;
404             }           
405         }
406         else if( ci < 0 ) // process ordered variable
407         {
408             for( i = 0; i < sample_count; i++ )
409             {
410                 float val = ord_nan;
411                 int si = sidx ? sidx[i] : i;
412                 if( !mask || !mask[si*m_step] )
413                 {
414                     if( idata )
415                         val = (float)idata[si*step];
416                     else
417                         val = fdata[si*step];
418
419                     if( fabs(val) >= ord_nan )
420                     {
421                         sprintf( err, "%d-th value of %d-th (ordered) "
422                             "variable (=%g) is too large", i, vi, val );
423                         CV_ERROR( CV_StsBadArg, err );
424                     }
425                 }
426                 num_valid++;
427             }
428         }
429         if( vi < var_count )
430             data_root->set_num_valid(vi, num_valid);
431     }
432
433     // set sample labels
434     if (is_buf_16u)
435         udst = (unsigned short*)(buf->data.s + get_work_var_count()*sample_count);
436     else
437         idst = buf->data.i + get_work_var_count()*sample_count;
438
439     for (i = 0; i < sample_count; i++)
440     {
441         if (udst)
442             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
443         else
444             idst[i] = sidx ? sidx[i] : i;
445     }
446
447     if( cv_n )
448     {
449         unsigned short* udst = 0;
450         int* idst = 0;
451         CvRNG* r = &rng;
452
453         if (is_buf_16u)
454         {
455             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
456             for( i = vi = 0; i < sample_count; i++ )
457             {
458                 udst[i] = (unsigned short)vi++;
459                 vi &= vi < cv_n ? -1 : 0;
460             }
461
462             for( i = 0; i < sample_count; i++ )
463             {
464                 int a = cvRandInt(r) % sample_count;
465                 int b = cvRandInt(r) % sample_count;
466                 unsigned short unsh = (unsigned short)vi;
467                 CV_SWAP( udst[a], udst[b], unsh );
468             }
469         }
470         else
471         {
472             idst = buf->data.i + (get_work_var_count()-1)*sample_count;
473             for( i = vi = 0; i < sample_count; i++ )
474             {
475                 idst[i] = vi++;
476                 vi &= vi < cv_n ? -1 : 0;
477             }
478
479             for( i = 0; i < sample_count; i++ )
480             {
481                 int a = cvRandInt(r) % sample_count;
482                 int b = cvRandInt(r) % sample_count;
483                 CV_SWAP( idst[a], idst[b], vi );
484             }
485         }
486     }
487
488     if ( cat_map ) 
489         cat_map->cols = MAX( total_c_count, 1 );
490
491     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
492         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
493     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
494
495     have_priors = is_classifier && params.priors;
496     if( is_classifier )
497     {
498         int m = get_num_classes();
499         double sum = 0;
500         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
501         for( i = 0; i < m; i++ )
502         {
503             double val = have_priors ? params.priors[i] : 1.;
504             if( val <= 0 )
505                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
506             priors->data.db[i] = val;
507             sum += val;
508         }
509
510         // normalize weights
511         if( have_priors )
512             cvScale( priors, priors, 1./sum );
513
514         CV_CALL( priors_mult = cvCloneMat( priors ));
515         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
516     }
517
518     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
519     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
520
521     {
522         int maxNumThreads = 1;
523 #ifdef _OPENMP
524         maxNumThreads = omp_get_num_procs();
525 #endif
526         pred_float_buf.resize(maxNumThreads);
527         pred_int_buf.resize(maxNumThreads);
528         resp_float_buf.resize(maxNumThreads);
529         resp_int_buf.resize(maxNumThreads);
530         cv_lables_buf.resize(maxNumThreads);
531         sample_idx_buf.resize(maxNumThreads);
532         for( int ti = 0; ti < maxNumThreads; ti++ )
533         {
534             pred_float_buf[ti].resize(sample_count);
535             pred_int_buf[ti].resize(sample_count);
536             resp_float_buf[ti].resize(sample_count);
537             resp_int_buf[ti].resize(sample_count);
538             cv_lables_buf[ti].resize(sample_count);
539             sample_idx_buf[ti].resize(sample_count);
540         }
541     }
542     
543     __END__;
544
545     if( data )
546         delete data;
547
548     if (_fdst)
549         cvFree( &_fdst );
550     if (_idst)
551         cvFree( &_idst );
552     cvFree( &int_ptr );
553     cvReleaseMat( &var_type0 );
554     cvReleaseMat( &sample_indices );
555     cvReleaseMat( &tmp_map );
556 }
557
558 int CvERTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, const float** ord_values, const int** missing )
559 {
560     int vidx = var_idx ? var_idx->data.i[vi] : vi;
561     int node_sample_count = n->sample_count; 
562     int* sample_indices_buf = get_sample_idx_buf();
563     const int* sample_indices = 0;
564
565     get_sample_indices(n, sample_indices_buf, &sample_indices);
566
567     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
568     int m_step = missing_mask ? missing_mask->step/CV_ELEM_SIZE(missing_mask->type) : 1;
569     if( tflag == CV_ROW_SAMPLE )
570     {
571         for( int i = 0; i < node_sample_count; i++ )
572         {
573             int idx = sample_indices[i];
574             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + idx * m_step + vi) : 0;
575             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
576         }
577     }
578     else
579         for( int i = 0; i < node_sample_count; i++ )
580         {
581             int idx = sample_indices[i];
582             missing_buf[i] = missing_mask ? *(missing_mask->data.ptr + vi* m_step + idx) : 0;
583             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
584         }
585     *ord_values = ord_values_buf;
586     *missing = missing_buf;
587     return 0; //TODO: return the number of non-missing values
588 }
589
590
591 void CvERTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
592 {
593     get_cat_var_data( n, var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0), indices_buf, indices );
594 }
595
596
597 void CvERTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
598 {
599     if (have_labels)
600         get_cat_var_data( n, var_count + (is_classifier ? 1 : 0), labels_buf, labels );
601 }
602
603
604 int CvERTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
605 {
606     int ci = get_var_type( vi);
607     if( !is_buf_16u )
608         *cat_values = buf->data.i + n->buf_idx*buf->cols + 
609         ci*sample_count + n->offset;
610     else {
611         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
612             ci*sample_count + n->offset);
613         for( int i = 0; i < n->sample_count; i++ )
614             cat_values_buf[i] = short_values[i];
615         *cat_values = cat_values_buf;
616     }
617
618     return 0; //TODO: return the number of non-missing values
619 }
620
621 void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
622                                     float* values, uchar* missing,
623                                     float* responses, bool get_class_idx )
624 {
625     CvMat* subsample_idx = 0;
626     CvMat* subsample_co = 0;
627
628     CV_FUNCNAME( "CvERTreeTrainData::get_vectors" );
629
630     __BEGIN__;
631
632     int i, vi, total = sample_count, count = total, cur_ofs = 0;
633     int* sidx = 0;
634     int* co = 0;
635
636     if( _subsample_idx )
637     {
638         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
639         sidx = subsample_idx->data.i;
640         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
641         co = subsample_co->data.i;
642         cvZero( subsample_co );
643         count = subsample_idx->cols + subsample_idx->rows - 1;
644         for( i = 0; i < count; i++ )
645             co[sidx[i]*2]++;
646         for( i = 0; i < total; i++ )
647         {
648             int count_i = co[i*2];
649             if( count_i )
650             {
651                 co[i*2+1] = cur_ofs*var_count;
652                 cur_ofs += count_i;
653             }
654         }
655     }
656
657     if( missing )
658         memset( missing, 1, count*var_count );
659
660     for( vi = 0; vi < var_count; vi++ )
661     {
662         int ci = get_var_type(vi);
663         if( ci >= 0 ) // categorical
664         {
665             float* dst = values + vi;
666             uchar* m = missing ? missing + vi : 0;
667             int* src_buf = get_pred_int_buf();
668             const int* src = 0; 
669             get_cat_var_data(data_root, vi, src_buf, &src);
670
671             for( i = 0; i < count; i++, dst += var_count )
672             {
673                 int idx = sidx ? sidx[i] : i;
674                 int val = src[idx];
675                 *dst = (float)val;
676                 if( m )
677                 {
678                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
679                     m += var_count;
680                 }
681             }
682         }
683         else // ordered
684         {
685             float* dst_buf = values + vi;
686             int* m_buf = get_pred_int_buf();
687             const float *dst = 0;
688             const int* m = 0;
689             get_ord_var_data(data_root, vi, dst_buf, m_buf, &dst, &m);
690             for (int si = 0; si < total; si++)
691                 *(missing + vi + si) = m[si] == 0 ? 0 : 1;
692         }
693     }
694
695     // copy responses
696     if( responses )
697     {
698         if( is_classifier )
699         {
700             int* src_buf = get_resp_int_buf();
701             const int* src = 0;
702             get_class_labels(data_root, src_buf, &src);
703             for( i = 0; i < count; i++ )
704             {
705                 int idx = sidx ? sidx[i] : i;
706                 int val = get_class_idx ? src[idx] :
707                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
708                 responses[i] = (float)val;
709             }
710         }
711         else           
712         {
713             float *_values_buf = get_resp_float_buf();
714             const float* _values = 0;
715             get_ord_responses(data_root, _values_buf, &_values);
716             for( i = 0; i < count; i++ )
717             {
718                 int idx = sidx ? sidx[i] : i;
719                 responses[i] = _values[idx];
720             }
721         }
722     }
723
724     __END__;
725
726     cvReleaseMat( &subsample_idx );
727     cvReleaseMat( &subsample_co );
728 }
729
730 CvDTreeNode* CvERTreeTrainData::subsample_data( const CvMat* _subsample_idx )
731 {
732     CvDTreeNode* root = 0;
733     
734     CV_FUNCNAME( "CvERTreeTrainData::subsample_data" );
735
736     __BEGIN__;
737
738     if( !data_root )
739         CV_ERROR( CV_StsError, "No training data has been set" );
740
741     if( !_subsample_idx )
742     {
743         // make a copy of the root node
744         CvDTreeNode temp;
745         int i;
746         root = new_node( 0, 1, 0, 0 );
747         temp = *root;
748         *root = *data_root;
749         root->num_valid = temp.num_valid;
750         if( root->num_valid )
751         {
752             for( i = 0; i < var_count; i++ )
753                 root->num_valid[i] = data_root->num_valid[i];
754         }
755         root->cv_Tn = temp.cv_Tn;
756         root->cv_node_risk = temp.cv_node_risk;
757         root->cv_node_error = temp.cv_node_error;
758     }
759     else
760         CV_ERROR( CV_StsError, "_subsample_idx must be null for extra-trees" );
761     __END__;
762
763     return root;
764 }
765
766 double CvForestERTree::calc_node_dir( CvDTreeNode* node )
767 {
768     char* dir = (char*)data->direction->data.ptr;
769     int i, n = node->sample_count, vi = node->split->var_idx;
770     double L, R;
771
772     assert( !node->split->inversed );
773
774     if( data->get_var_type(vi) >= 0 ) // split on categorical var
775     {
776         int* labels_buf = data->get_pred_int_buf();
777         const int* labels = 0;
778         const int* subset = node->split->subset;
779         data->get_cat_var_data( node, vi, labels_buf, &labels );
780         if( !data->have_priors )
781         {
782             int sum = 0, sum_abs = 0;
783
784             for( i = 0; i < n; i++ )
785             {
786                 int idx = labels[i];
787                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
788                     CV_DTREE_CAT_DIR(idx,subset) : 0;
789                 sum += d; sum_abs += d & 1;
790                 dir[i] = (char)d;
791             }
792
793             R = (sum_abs + sum) >> 1;
794             L = (sum_abs - sum) >> 1;
795         }
796         else
797         {
798             const double* priors = data->priors_mult->data.db;
799             double sum = 0, sum_abs = 0;
800             int *responses_buf = data->get_resp_int_buf();
801             const int* responses;
802             data->get_class_labels(node, responses_buf, &responses);
803
804             for( i = 0; i < n; i++ )
805             {
806                 int idx = labels[i];
807                 double w = priors[responses[i]];
808                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
809                 sum += d*w; sum_abs += (d & 1)*w;
810                 dir[i] = (char)d;
811             }
812
813             R = (sum_abs + sum) * 0.5;
814             L = (sum_abs - sum) * 0.5;
815         }
816     }
817     else // split on ordered var
818     {
819         float split_val = node->split->ord.c;
820         float* val_buf = data->get_pred_float_buf();
821         const float* val = 0;
822         int* missing_buf = data->get_pred_int_buf();
823         const int* missing = 0;
824         data->get_ord_var_data( node, vi, val_buf, missing_buf, &val, &missing );
825
826         if( !data->have_priors )
827         {
828             L = R = 0;
829             for( i = 0; i < n; i++ )
830             {
831                 if ( missing[i] )
832                     dir[i] = (char)0;
833                 else
834                 {
835                     if ( val[i] < split_val)
836                     {
837                         dir[i] = (char)-1;
838                         L++;
839                     }
840                     else
841                     {
842                         dir[i] = (char)1;
843                         R++;
844                     }
845                 }
846             }
847         }
848         else
849         {
850             const double* priors = data->priors_mult->data.db;
851             int* responses_buf = data->get_resp_int_buf();
852             const int* responses = 0;
853             data->get_class_labels(node, responses_buf, &responses);
854             L = R = 0;
855             for( i = 0; i < n; i++ )
856             {
857                 if ( missing[i] )
858                     dir[i] = (char)0;
859                 else
860                 {
861                     double w = priors[responses[i]];
862                     if ( val[i] < split_val)
863                     {
864                         dir[i] = (char)-1;
865                          L += w;
866                     }
867                     else
868                     {
869                         dir[i] = (char)1;
870                         R += w;
871                     }
872                 }
873             }
874         }
875     }
876
877     node->maxlr = MAX( L, R );
878     return node->split->quality/(L + R);
879 }
880
881 CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
882 {
883     const float epsilon = FLT_EPSILON*2;
884     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
885
886     int n = node->sample_count;
887     int m = data->get_num_classes();
888
889     float* values_buf = data->get_pred_float_buf();
890     const float* values = 0;
891     int* missing_buf = data->get_pred_int_buf();
892     const int* missing = 0;
893     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
894     int* responses_buf =  data->get_resp_int_buf();
895     const int* responses = 0;
896     data->get_class_labels( node, responses_buf, &responses );
897
898     double lbest_val = 0, rbest_val = 0, best_val = init_quality, split_val = 0;
899     
900     int i;
901
902     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
903
904     bool is_find_split = false;
905     
906     float pmin, pmax;
907     int smpi = 0;
908     while ( missing[smpi] && (smpi < n) )
909         smpi++;
910     assert(smpi < n);
911
912     pmin = values[smpi];
913     pmax = pmin;
914     for (; smpi < n; smpi++)
915     {
916         float ptemp = values[smpi];
917         int m = missing[smpi];
918         if (m) continue;
919         if ( ptemp < pmin)
920             pmin = ptemp;
921         if ( ptemp > pmax)
922             pmax = ptemp;
923     }
924     float fdiff = pmax-pmin;
925     if (fdiff > epsilon)
926     {
927         is_find_split = true;
928         CvRNG* rng = &data->rng;
929         split_val = pmin + cvRandReal(rng) * fdiff ;
930         if (split_val - pmin <= FLT_EPSILON)
931             split_val = pmin + split_delta;
932         if (pmax - split_val <= FLT_EPSILON)
933             split_val = pmax - split_delta;       
934
935         // calculate Gini index
936         if ( !priors )
937         {
938             int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
939             int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
940             int L = 0, R = 0;
941     
942             // init arrays of class instance counters on both sides of the split
943             for( i = 0; i < m; i++ )
944             {
945                 lc[i] = 0;
946                 rc[i] = 0;
947             }
948             for( int si = 0; si < n; si++ )
949             {
950                 int r = responses[si];
951                 float val = values[si];
952                 int m = missing[si];
953                 if (m) continue;
954                 if ( val < split_val )
955                 {
956                     lc[r]++;
957                     L++;
958                 }
959                 else
960                 {
961                     rc[r]++;
962                     R++;
963                 }
964             }
965             for (int i = 0; i < m; i++)
966             {
967                 lbest_val += lc[i]*lc[i];
968                 rbest_val += rc[i]*rc[i];
969             }
970             best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
971         }
972         else
973         {
974             double* lc = (double*)cvStackAlloc(m*sizeof(lc[0]));
975             double* rc = (double*)cvStackAlloc(m*sizeof(rc[0]));
976             double L = 0, R = 0;
977     
978             // init arrays of class instance counters on both sides of the split
979             for( i = 0; i < m; i++ )
980             {
981                 lc[i] = 0;
982                 rc[i] = 0;
983             }
984             for( int si = 0; si < n; si++ )
985             {
986                 int r = responses[si];
987                 float val = values[si];
988                 int m = missing[si];
989                 double p = priors[si];
990                 if (m) continue;
991                 if ( val < split_val )
992                 {
993                     lc[r] += p;
994                     L += p;
995                 }
996                 else
997                 {
998                     rc[r] += p;
999                     R += p;
1000                 }
1001             }
1002             for (int i = 0; i < m; i++)
1003             {
1004                 lbest_val += lc[i]*lc[i];
1005                 rbest_val += rc[i]*rc[i];
1006             }
1007             best_val = (lbest_val*R + rbest_val*L) / (L*R);
1008         }
1009         
1010     }
1011
1012     CvDTreeSplit* split = 0;
1013     if( is_find_split )
1014     {
1015         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1016         split->var_idx = vi;
1017         split->ord.c = (float)split_val;
1018         split->ord.split_point = -1;
1019         split->inversed = 0;
1020         split->quality = (float)best_val;
1021     }
1022     return split;
1023 }
1024
1025 CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1026 {
1027     int ci = data->get_var_type(vi);
1028     int n = node->sample_count;
1029     int cm = data->get_num_classes(); 
1030     int vm = data->cat_count->data.i[ci];
1031     double best_val = init_quality;
1032     CvDTreeSplit *split = 0;
1033
1034     if ( vm > 1 )
1035     {
1036         int* labels_buf = data->get_pred_int_buf();
1037         const int* labels = 0;
1038         data->get_cat_var_data( node, vi, labels_buf, &labels );
1039
1040         int* responses_buf =  data->get_resp_int_buf();
1041         const int* responses = 0;
1042         data->get_class_labels( node, responses_buf, &responses );
1043     
1044         const double* priors = data->have_priors ? data->priors_mult->data.db : 0;       
1045
1046         // create random class mask
1047         int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1048         for (int i = 0; i < vm; i++)
1049         {
1050             valid_cidx[i] = -1;
1051         }
1052         for (int si = 0; si < n; si++)
1053         {
1054             int c = labels[si];
1055             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1056                 continue;
1057             valid_cidx[c]++;
1058         }
1059
1060         int valid_ccount = 0;
1061         for (int i = 0; i < vm; i++)
1062             if (valid_cidx[i] >= 0)
1063             {
1064                 valid_cidx[i] = valid_ccount;
1065                 valid_ccount++;
1066             }
1067         if (valid_ccount > 1)
1068         {
1069             CvRNG* rng = forest->get_rng();
1070             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1071
1072             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1073             CvMat submask;
1074             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1075             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1076             cvSet( &submask, cvScalar(1) );
1077             for (int i = 0; i < valid_ccount; i++)
1078             {
1079                 uchar temp;
1080                 int i1 = cvRandInt( rng ) % valid_ccount;
1081                 int i2 = cvRandInt( rng ) % valid_ccount;
1082                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1083             }
1084
1085             split = _split ? _split : data->new_split_cat( 0, -1.0f );
1086             split->var_idx = vi;
1087             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1088
1089             // calculate Gini index
1090             double lbest_val = 0, rbest_val = 0;
1091             if( !priors )
1092             {
1093                 int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
1094                 int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
1095                 int L = 0, R = 0;
1096                 // init arrays of class instance counters on both sides of the split
1097                 for(int i = 0; i < cm; i++ )
1098                 {
1099                     lc[i] = 0;
1100                     rc[i] = 0;
1101                 }
1102                 for( int si = 0; si < n; si++ )
1103                 {
1104                     int r = responses[si];
1105                     int var_class_idx = labels[si];
1106                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1107                         continue;
1108                     int mask_class_idx = valid_cidx[var_class_idx];
1109                     if (var_class_mask->data.ptr[mask_class_idx])
1110                     {
1111                         lc[r]++;
1112                         L++;                 
1113                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1114                     }
1115                     else
1116                     {
1117                         rc[r]++;
1118                         R++;
1119                     }
1120                 }
1121                 for (int i = 0; i < cm; i++)
1122                 {
1123                     lbest_val += lc[i]*lc[i];
1124                     rbest_val += rc[i]*rc[i];
1125                 }                
1126                 best_val = (lbest_val*R + rbest_val*L) / ((double)(L*R));
1127             }
1128             else
1129             {
1130                 double* lc = (double*)cvStackAlloc(cm*sizeof(lc[0]));
1131                 double* rc = (double*)cvStackAlloc(cm*sizeof(rc[0]));
1132                 double L = 0, R = 0;
1133                 // init arrays of class instance counters on both sides of the split
1134                 for(int i = 0; i < cm; i++ )
1135                 {
1136                     lc[i] = 0;
1137                     rc[i] = 0;
1138                 }
1139                 for( int si = 0; si < n; si++ )
1140                 {
1141                     int r = responses[si];
1142                     int var_class_idx = labels[si];
1143                     if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1144                         continue;
1145                     double p = priors[si];
1146                     int mask_class_idx = valid_cidx[var_class_idx];
1147                     
1148                     if (var_class_mask->data.ptr[mask_class_idx])
1149                     {
1150                         lc[r]+=p;
1151                         L+=p;                 
1152                         split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1153                     }
1154                     else
1155                     {
1156                         rc[r]+=p;
1157                         R+=p;
1158                     }
1159                 }
1160                 for (int i = 0; i < cm; i++)
1161                 {
1162                     lbest_val += lc[i]*lc[i];
1163                     rbest_val += rc[i]*rc[i];
1164                 }
1165                 best_val = (lbest_val*R + rbest_val*L) / (L*R);
1166             }
1167             split->quality = (float)best_val;
1168
1169             cvReleaseMat(&var_class_mask);
1170         }   
1171     }  
1172
1173     return split;
1174 }
1175
1176 CvDTreeSplit* CvForestERTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1177 {
1178     const float epsilon = FLT_EPSILON*2;
1179     const float split_delta = (1 + FLT_EPSILON) * FLT_EPSILON;
1180     int n = node->sample_count;
1181     float* values_buf = data->get_pred_float_buf();
1182     const float* values = 0;
1183     int* missing_buf = data->get_pred_int_buf();
1184     const int* missing = 0;
1185     data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
1186     float* responses_buf =  data->get_resp_float_buf();
1187     const float* responses = 0;
1188     data->get_ord_responses( node, responses_buf, &responses );
1189
1190     double best_val = init_quality, split_val = 0, lsum = 0, rsum = 0;
1191     int L = 0, R = 0;
1192
1193     bool is_find_split = false;
1194     float pmin, pmax;
1195     int smpi = 0;
1196     while ( missing[smpi] && (smpi < n) )
1197         smpi++;
1198
1199     assert(smpi < n);
1200
1201     pmin = values[smpi];
1202     pmax = pmin;
1203     for (; smpi < n; smpi++)
1204     {
1205         float ptemp = values[smpi];
1206         int m = missing[smpi];
1207         if (m) continue;
1208         if ( ptemp < pmin)
1209             pmin = ptemp;
1210         if ( ptemp > pmax)
1211             pmax = ptemp;
1212     }
1213     float fdiff = pmax-pmin;
1214     if (fdiff > epsilon)
1215     {
1216         is_find_split = true;
1217         CvRNG* rng = &data->rng;
1218         split_val = pmin + cvRandReal(rng) * fdiff ;
1219         if (split_val - pmin <= FLT_EPSILON)
1220             split_val = pmin + split_delta;
1221         if (pmax - split_val <= FLT_EPSILON)
1222             split_val = pmax - split_delta;    
1223
1224         for (int si = 0; si < n; si++)
1225         {
1226             float r = responses[si];
1227             float val = values[si];
1228             int m = missing[si];
1229             if (m) continue;
1230             if (val < split_val)
1231             {
1232                 lsum += r;
1233                 L++;
1234             }
1235             else
1236             {
1237                 rsum += r;
1238                 R++;            
1239             }
1240         }
1241         best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1242     }
1243
1244     CvDTreeSplit* split = 0;
1245     if( is_find_split )
1246     {
1247         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
1248         split->var_idx = vi;
1249         split->ord.c = (float)split_val;
1250         split->ord.split_point = -1;
1251         split->inversed = 0;
1252         split->quality = (float)best_val;
1253     }
1254     return split;
1255 }
1256
1257 CvDTreeSplit* CvForestERTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
1258 {
1259     int ci = data->get_var_type(vi);
1260     int n = node->sample_count;
1261     int vm = data->cat_count->data.i[ci];
1262     double best_val = init_quality;
1263     CvDTreeSplit *split = 0;
1264     float lsum = 0, rsum = 0;
1265
1266     if ( vm > 1 )
1267     {
1268         int* labels_buf = data->get_pred_int_buf();
1269         const int* labels = 0;
1270         data->get_cat_var_data( node, vi, labels_buf, &labels );
1271
1272         float* responses_buf =  data->get_resp_float_buf();
1273         const float* responses = 0;
1274         data->get_ord_responses( node, responses_buf, &responses );
1275
1276         // create random class mask
1277         int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
1278         for (int i = 0; i < vm; i++)
1279         {
1280             valid_cidx[i] = -1;
1281         }
1282         for (int si = 0; si < n; si++)
1283         {
1284             int c = labels[si];
1285             if ( ((c == 65535) && data->is_buf_16u) || ((c<0) && (!data->is_buf_16u)) )
1286                         continue;
1287             valid_cidx[c]++;
1288         }
1289
1290         int valid_ccount = 0;
1291         for (int i = 0; i < vm; i++)
1292             if (valid_cidx[i] >= 0)
1293             {
1294                 valid_cidx[i] = valid_ccount;
1295                 valid_ccount++;
1296             }
1297         if (valid_ccount > 1)
1298         {
1299             CvRNG* rng = forest->get_rng();
1300             int l_cval_count = 1 + cvRandInt(rng) % (valid_ccount-1);
1301
1302             CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
1303             CvMat submask;
1304             memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
1305             cvGetCols( var_class_mask, &submask, 0, l_cval_count );
1306             cvSet( &submask, cvScalar(1) );
1307             for (int i = 0; i < valid_ccount; i++)
1308             {
1309                 uchar temp;
1310                 int i1 = cvRandInt( rng ) % valid_ccount;
1311                 int i2 = cvRandInt( rng ) % valid_ccount;
1312                 CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
1313             }
1314
1315             split = _split ? _split : data->new_split_cat( 0, -1.0f);
1316             split->var_idx = vi;
1317             memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
1318
1319             int L = 0, R = 0;
1320             for( int si = 0; si < n; si++ )
1321             {
1322                 float r = responses[si];
1323                 int var_class_idx = labels[si];
1324                 if ( ((var_class_idx == 65535) && data->is_buf_16u) || ((var_class_idx<0) && (!data->is_buf_16u)) )
1325                         continue;
1326                 int mask_class_idx = valid_cidx[var_class_idx];
1327                 if (var_class_mask->data.ptr[mask_class_idx])
1328                 {
1329                     lsum += r;
1330                     L++;                 
1331                     split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
1332                 }
1333                 else
1334                 {
1335                     rsum += r;
1336                     R++;
1337                 }
1338             }
1339             best_val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
1340
1341             split->quality = (float)best_val;
1342
1343             cvReleaseMat(&var_class_mask);
1344         }   
1345     }  
1346
1347     return split;
1348 }
1349
1350 //void CvForestERTree::complete_node_dir( CvDTreeNode* node )
1351 //{
1352 //    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
1353 //    int nz = n - node->get_num_valid(node->split->var_idx);
1354 //    char* dir = (char*)data->direction->data.ptr;
1355 //
1356 //    // try to complete direction using surrogate splits
1357 //    if( nz && data->params.use_surrogates )
1358 //    {
1359 //        CvDTreeSplit* split = node->split->next;
1360 //        for( ; split != 0 && nz; split = split->next )
1361 //        {
1362 //            int inversed_mask = split->inversed ? -1 : 0;
1363 //            vi = split->var_idx;
1364 //
1365 //            if( data->get_var_type(vi) >= 0 ) // split on categorical var
1366 //            {
1367 //                int* labels_buf = data->pred_int_buf;
1368 //                const int* labels = 0;
1369 //                data->get_cat_var_data(node, vi, labels_buf, &labels);
1370 //                const int* subset = split->subset;
1371 //
1372 //                for( i = 0; i < n; i++ )
1373 //                {
1374 //                    int idx = labels[i];
1375 //                    if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
1376 //                        
1377 //                    {
1378 //                        int d = CV_DTREE_CAT_DIR(idx,subset);
1379 //                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
1380 //                        if( --nz )
1381 //                            break;
1382 //                    }
1383 //                }
1384 //            }
1385 //            else // split on ordered var
1386 //            {
1387 //                float* values_buf = data->pred_float_buf;
1388 //                const float* values = 0;
1389 //                uchar* missing_buf = (uchar*)data->pred_int_buf;
1390 //                const uchar* missing = 0;
1391 //                data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
1392 //                float split_val = node->split->ord.c;
1393 //                
1394 //                for( i = 0; i < n; i++ )
1395 //                {
1396 //                    if( !dir[i] && !missing[i])
1397 //                    {
1398 //                        int d = values[i] <= split_val ? -1 : 1;
1399 //                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
1400 //                        if( --nz )
1401 //                            break;
1402 //                    }
1403 //                }
1404 //            }
1405 //        }
1406 //    }
1407 //
1408 //    // find the default direction for the rest
1409 //    if( nz )
1410 //    {
1411 //        for( i = nr = 0; i < n; i++ )
1412 //            nr += dir[i] > 0;
1413 //        nl = n - nr - nz;
1414 //        d0 = nl > nr ? -1 : nr > nl;
1415 //    }
1416 //
1417 //    // make sure that every sample is directed either to the left or to the right
1418 //    for( i = 0; i < n; i++ )
1419 //    {
1420 //        int d = dir[i];
1421 //        if( !d )
1422 //        {
1423 //            d = d0;
1424 //            if( !d )
1425 //                d = d1, d1 = -d1;
1426 //        }
1427 //        d = d > 0;
1428 //        dir[i] = (char)d; // remap (-1,1) to (0,1)
1429 //    }
1430 //}
1431
1432 void CvForestERTree::split_node_data( CvDTreeNode* node )
1433 {
1434     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
1435     char* dir = (char*)data->direction->data.ptr;
1436     CvDTreeNode *left = 0, *right = 0;
1437     int new_buf_idx = data->get_child_buf_idx( node );
1438     CvMat* buf = data->buf;
1439     int* temp_buf = (int*)cvStackAlloc(n*sizeof(temp_buf[0]));
1440
1441     complete_node_dir(node);
1442
1443     for( i = nl = nr = 0; i < n; i++ )
1444     {
1445         int d = dir[i];
1446         nr += d;
1447         nl += d^1;
1448     }
1449
1450     bool split_input_data;
1451     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
1452     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
1453
1454     split_input_data = node->depth + 1 < data->params.max_depth &&
1455         (node->left->sample_count > data->params.min_sample_count ||
1456         node->right->sample_count > data->params.min_sample_count);
1457
1458     // split ordered vars
1459     for( vi = 0; vi < data->var_count; vi++ )
1460     {
1461         int ci = data->get_var_type(vi);
1462         if (ci >= 0) continue;
1463         
1464         int n1 = node->get_num_valid(vi), nr1 = 0;
1465
1466         float* values_buf = data->get_pred_float_buf();
1467         const float* values = 0;
1468         int* missing_buf = data->get_pred_int_buf();
1469         const int* missing = 0;
1470         data->get_ord_var_data( node, vi, values_buf, missing_buf, &values, &missing );
1471
1472         for( i = 0; i < n; i++ )
1473             nr1 += (!missing[i] & dir[i]);
1474         left->set_num_valid(vi, n1 - nr1);
1475         right->set_num_valid(vi, nr1);                
1476     }
1477     // split categorical vars, responses and cv_labels using new_idx relocation table
1478     for( vi = 0; vi < data->get_work_var_count() + data->ord_var_count; vi++ )
1479     {
1480         int ci = data->get_var_type(vi);
1481         if (ci < 0) continue;
1482
1483         int n1 = node->get_num_valid(vi), nr1 = 0;
1484         
1485         int *src_lbls_buf = data->get_pred_int_buf();
1486         const int* src_lbls = 0;
1487         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);
1488
1489         for(i = 0; i < n; i++)
1490             temp_buf[i] = src_lbls[i];
1491
1492         if (data->is_buf_16u)
1493         {
1494             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + 
1495                 ci*scount + left->offset);
1496             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + 
1497                 ci*scount + right->offset);
1498             
1499             for( i = 0; i < n; i++ )
1500             {
1501                 int d = dir[i];
1502                 int idx = temp_buf[i];
1503                 if (d)
1504                 {
1505                     *rdst = (unsigned short)idx;
1506                     rdst++;
1507                     nr1 += (idx != 65535);
1508                 }
1509                 else
1510                 {
1511                     *ldst = (unsigned short)idx;
1512                     ldst++;
1513                 }
1514             }
1515
1516             if( vi < data->var_count )
1517             {
1518                 left->set_num_valid(vi, n1 - nr1);
1519                 right->set_num_valid(vi, nr1);
1520             }
1521         }
1522         else
1523         {
1524             int *ldst = buf->data.i + left->buf_idx*buf->cols + 
1525                 ci*scount + left->offset;
1526             int *rdst = buf->data.i + right->buf_idx*buf->cols + 
1527                 ci*scount + right->offset;
1528             
1529             for( i = 0; i < n; i++ )
1530             {
1531                 int d = dir[i];
1532                 int idx = temp_buf[i];
1533                 if (d)
1534                 {
1535                     *rdst = idx;
1536                     rdst++;
1537                     nr1 += (idx >= 0);
1538                 }
1539                 else
1540                 {
1541                     *ldst = idx;
1542                     ldst++;
1543                 }
1544                 
1545             }
1546
1547             if( vi < data->var_count )
1548             {
1549                 left->set_num_valid(vi, n1 - nr1);
1550                 right->set_num_valid(vi, nr1);
1551             }
1552         }        
1553     }
1554
1555
1556     // split sample indices
1557     int *sample_idx_src_buf = data->get_sample_idx_buf();
1558     const int* sample_idx_src = 0;
1559     if (split_input_data)
1560     {
1561         data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);
1562
1563         for(i = 0; i < n; i++)
1564             temp_buf[i] = sample_idx_src[i];
1565
1566         int pos = data->get_work_var_count();
1567        
1568         if (data->is_buf_16u)
1569         {
1570             unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
1571                 pos*scount + left->offset);
1572             unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + 
1573                 pos*scount + right->offset);
1574             
1575             for (i = 0; i < n; i++)
1576             {
1577                 int d = dir[i];
1578                 unsigned short idx = (unsigned short)temp_buf[i];
1579                 if (d)
1580                 {
1581                     *rdst = idx;
1582                     rdst++;
1583                 }
1584                 else
1585                 {
1586                     *ldst = idx;
1587                     ldst++;
1588                 }
1589             }
1590         }
1591         else
1592         {
1593             int* ldst = buf->data.i + left->buf_idx*buf->cols + 
1594                 pos*scount + left->offset;
1595             int* rdst = buf->data.i + right->buf_idx*buf->cols + 
1596                 pos*scount + right->offset;
1597             for (i = 0; i < n; i++)
1598             {
1599                 int d = dir[i];
1600                 int idx = temp_buf[i];
1601                 if (d)
1602                 {
1603                     *rdst = idx;
1604                     rdst++;
1605                 }
1606                 else
1607                 {
1608                     *ldst = idx;
1609                     ldst++;
1610                 }
1611             }
1612         }
1613     }
1614     
1615     // deallocate the parent node data that is not needed anymore
1616     data->free_node_data(node);    
1617 }
1618
1619 CvERTrees::CvERTrees()
1620 {
1621 }
1622
1623 CvERTrees::~CvERTrees()
1624 {
1625 }
1626
1627 bool CvERTrees::train( const CvMat* _train_data, int _tflag,
1628                         const CvMat* _responses, const CvMat* _var_idx,
1629                         const CvMat* _sample_idx, const CvMat* _var_type,
1630                         const CvMat* _missing_mask, CvRTParams params )
1631 {
1632     bool result = false;
1633
1634     CV_FUNCNAME("CvERTrees::train");
1635     __BEGIN__
1636     int var_count = 0;
1637
1638     clear();
1639
1640     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
1641         params.regression_accuracy, params.use_surrogates, params.max_categories,
1642         params.cv_folds, params.use_1se_rule, false, params.priors );
1643
1644     data = new CvERTreeTrainData();
1645     CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
1646         _sample_idx, _var_type, _missing_mask, tree_params, true));
1647
1648     var_count = data->var_count;
1649     if( params.nactive_vars > var_count )
1650         params.nactive_vars = var_count;
1651     else if( params.nactive_vars == 0 )
1652         params.nactive_vars = (int)sqrt((double)var_count);
1653     else if( params.nactive_vars < 0 )
1654         CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
1655
1656     // Create mask of active variables at the tree nodes
1657     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
1658     if( params.calc_var_importance )
1659     {
1660         CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
1661         cvZero(var_importance);
1662     }
1663     { // initialize active variables mask
1664         CvMat submask1, submask2;
1665         CV_Assert( (active_var_mask->cols >= 1) && (params.nactive_vars > 0) && (params.nactive_vars <= active_var_mask->cols) );
1666         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
1667         cvSet( &submask1, cvScalar(1) );
1668         if( params.nactive_vars < active_var_mask->cols )
1669         {
1670             cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
1671             cvZero( &submask2 );
1672         }
1673     }
1674
1675     CV_CALL(result = grow_forest( params.term_crit ));
1676
1677     result = true;
1678
1679     __END__
1680     return result;
1681     
1682 }
1683
1684 bool CvERTrees::train( CvMLData* data, CvRTParams params)
1685 {
1686    bool result = false;
1687
1688     CV_FUNCNAME( "CvERTrees::train" );
1689
1690     __BEGIN__;
1691
1692     CV_CALL( result = CvRTrees::train( data, params) );
1693
1694     __END__;
1695
1696     return result;
1697 }
1698
1699 bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
1700 {
1701     bool result = false;
1702
1703     CvMat* sample_idx_for_tree      = 0;
1704
1705     CV_FUNCNAME("CvERTrees::grow_forest");
1706     __BEGIN__;
1707
1708     const int max_ntrees = term_crit.max_iter;
1709     const double max_oob_err = term_crit.epsilon;
1710
1711     const int dims = data->var_count;
1712     float maximal_response = 0;
1713
1714     CvMat* oob_sample_votes        = 0;
1715     CvMat* oob_responses       = 0;
1716
1717     float* oob_samples_perm_ptr= 0;
1718
1719     float* samples_ptr     = 0;
1720     uchar* missing_ptr     = 0;
1721     float* true_resp_ptr   = 0;
1722     bool is_oob_or_vimportance = ((max_oob_err > 0) && (term_crit.type != CV_TERMCRIT_ITER)) || var_importance;
1723
1724     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
1725     // oob_num_of_predictions[i] = number of summands
1726     //                            (number of predictions for the i-th sample)
1727     // initialize these variable to avoid warning C4701
1728     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
1729     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
1730      
1731     nsamples = data->sample_count;
1732     nclasses = data->get_num_classes();
1733
1734     if ( is_oob_or_vimportance )
1735     {
1736         if( data->is_classifier )
1737         {
1738             CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
1739             cvZero(oob_sample_votes);
1740         }
1741         else
1742         {
1743             // oob_responses[0,i] = oob_predictions_sum[i]
1744             //    = sum of predicted values for the i-th sample
1745             // oob_responses[1,i] = oob_num_of_predictions[i]
1746             //    = number of summands (number of predictions for the i-th sample)
1747             CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
1748             cvZero(oob_responses);
1749             cvGetRow( oob_responses, &oob_predictions_sum, 0 );
1750             cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
1751         }
1752         
1753         CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1754         CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
1755         CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
1756         CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));            
1757
1758         CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
1759         {
1760             double minval, maxval;
1761             CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
1762             cvMinMaxLoc( &responses, &minval, &maxval );
1763             maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
1764         }
1765     }
1766    
1767     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
1768     memset( trees, 0, sizeof(trees[0])*max_ntrees );
1769
1770     CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));
1771
1772     for (int i = 0; i < nsamples; i++)
1773         sample_idx_for_tree->data.i[i] = i;
1774     ntrees = 0;
1775     while( ntrees < max_ntrees )
1776     {
1777         int i, oob_samples_count = 0;
1778         double ncorrect_responses = 0; // used for estimation of variable importance
1779         CvForestTree* tree = 0;
1780
1781         trees[ntrees] = new CvForestERTree();
1782         tree = (CvForestERTree*)trees[ntrees];
1783         CV_CALL(tree->train( data, 0, this ));
1784
1785         if ( is_oob_or_vimportance )
1786         {
1787             CvMat sample, missing;
1788             // form array of OOB samples indices and get these samples
1789             sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
1790             missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
1791
1792             oob_error = 0;
1793             for( i = 0; i < nsamples; i++,
1794                 sample.data.fl += dims, missing.data.ptr += dims )
1795             {
1796                 CvDTreeNode* predicted_node = 0;
1797                 
1798                 // predict oob samples
1799                 if( !predicted_node )
1800                     CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
1801
1802                 if( !data->is_classifier ) //regression
1803                 {
1804                     double avg_resp, resp = predicted_node->value;
1805                     oob_predictions_sum.data.fl[i] += (float)resp;
1806                     oob_num_of_predictions.data.fl[i] += 1;
1807
1808                     // compute oob error
1809                     avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
1810                     avg_resp -= true_resp_ptr[i];
1811                     oob_error += avg_resp*avg_resp;
1812                     resp = (resp - true_resp_ptr[i])/maximal_response;
1813                     ncorrect_responses += exp( -resp*resp );
1814                 }
1815                 else //classification
1816                 {
1817                     double prdct_resp;
1818                     CvPoint max_loc;
1819                     CvMat votes;
1820
1821                     cvGetRow(oob_sample_votes, &votes, i);
1822                     votes.data.i[predicted_node->class_idx]++;
1823
1824                     // compute oob error
1825                     cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
1826
1827                     prdct_resp = data->cat_map->data.i[max_loc.x];
1828                     oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
1829
1830                     ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
1831                 }
1832                 oob_samples_count++;
1833             }
1834             if( oob_samples_count > 0 )
1835                 oob_error /= (double)oob_samples_count;
1836
1837             // estimate variable importance
1838             if( var_importance && oob_samples_count > 0 )
1839             {
1840                 int m;
1841
1842                 memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
1843                 for( m = 0; m < dims; m++ )
1844                 {
1845                     double ncorrect_responses_permuted = 0;
1846                     // randomly permute values of the m-th variable in the oob samples
1847                     float* mth_var_ptr = oob_samples_perm_ptr + m;
1848
1849                     for( i = 0; i < nsamples; i++ )
1850                     {
1851                         int i1, i2;
1852                         float temp;
1853
1854                         i1 = cvRandInt( &rng ) % nsamples;
1855                         i2 = cvRandInt( &rng ) % nsamples;
1856                         CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
1857
1858                         // turn values of (m-1)-th variable, that were permuted
1859                         // at the previous iteration, untouched
1860                         if( m > 1 )
1861                             oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
1862                     }
1863
1864                     // predict "permuted" cases and calculate the number of votes for the
1865                     // correct class in the variable-m-permuted oob data
1866                     sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
1867                     missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
1868                     for( i = 0; i < nsamples; i++,
1869                         sample.data.fl += dims, missing.data.ptr += dims )
1870                     {
1871                         double predct_resp, true_resp;
1872
1873                         predct_resp = tree->predict(&sample, &missing, true)->value;
1874                         true_resp   = true_resp_ptr[i];
1875                         if( data->is_classifier )
1876                             ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
1877                         else
1878                         {
1879                             true_resp = (true_resp - predct_resp)/maximal_response;
1880                             ncorrect_responses_permuted += exp( -true_resp*true_resp );
1881                         }
1882                     }
1883                     var_importance->data.fl[m] += (float)(ncorrect_responses
1884                         - ncorrect_responses_permuted);
1885                 }
1886             }
1887         }
1888         ntrees++;
1889         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
1890             break;
1891     }
1892     if( var_importance )
1893     {
1894         for ( int vi = 0; vi < var_importance->cols; vi++ )
1895                 var_importance->data.fl[vi] = ( var_importance->data.fl[vi] > 0 ) ?
1896                     var_importance->data.fl[vi] : 0;
1897         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
1898     }
1899
1900     result = true;
1901     
1902     cvFree( &oob_samples_perm_ptr );
1903     cvFree( &samples_ptr );
1904     cvFree( &missing_ptr );
1905     cvFree( &true_resp_ptr );
1906     
1907     cvReleaseMat( &sample_idx_for_tree );
1908
1909     cvReleaseMat( &oob_sample_votes );
1910     cvReleaseMat( &oob_responses );
1911
1912     __END__;
1913
1914     return result;
1915 }
1916
1917 using namespace cv;
1918
1919 bool CvERTrees::train( const Mat& _train_data, int _tflag,
1920                       const Mat& _responses, const Mat& _var_idx,
1921                       const Mat& _sample_idx, const Mat& _var_type,
1922                       const Mat& _missing_mask, CvRTParams params )
1923 {
1924     CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
1925     sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
1926     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
1927                  sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
1928                  mmask.data.ptr ? &mmask : 0, params);
1929 }
1930
1931 // End of file.
1932