//M*/
#include "_ml.h"
+#include <ctype.h>
+
+using namespace cv;
+// ord_nan: sentinel magnitude for ordered values; values whose magnitude reaches
+// it are treated as out of range (assumed from the range checks in set_data() --
+// confirm). min_block_size: presumably a lower bound for memory-storage block
+// sizes -- TODO confirm against its uses.
static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
CvDTreeTrainData::CvDTreeTrainData()
{
var_idx = var_type = cat_count = cat_ofs = cat_map =
- priors = counts = buf = direction = split_buf = 0;
+ priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
tree_storage = temp_storage = 0;
clear();
+// Builds the training data and immediately ingests the given training set by
+// forwarding every argument to set_data().
+// _add_labels forces an extra per-sample label row in the work buffer even when
+// cross-validation is not requested (it is OR-ed into have_labels in set_data()).
CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx,
const CvMat* _sample_idx, const CvMat* _var_type,
- const CvMat* _missing_mask,
- CvDTreeParams _params, bool _shared, bool _add_weights )
+ const CvMat* _missing_mask, const CvDTreeParams& _params,
+ bool _shared, bool _add_labels )
{
+ // null every owned pointer before set_data(): its cleanup path calls clear(),
+ // which presumably releases these, so they must not hold garbage
var_idx = var_type = cat_count = cat_ofs = cat_map =
- priors = counts = buf = direction = split_buf = 0;
+ priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
+
tree_storage = temp_storage = 0;
-
+
set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
- _var_type, _missing_mask, _params, _shared, _add_weights );
+ _var_type, _missing_mask, _params, _shared, _add_labels );
}
bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
bool ok = false;
-
+
CV_FUNCNAME( "CvDTreeTrainData::set_params" );
__BEGIN__;
return ok;
}
-
+// qsort instantiations; icvSortIntPtr, icvSortPairs and the *Aux variants are the
+// ones set_data() uses to sort training values.
+// CV_CMP_NUM_PTR: orders pointers by the value they point to.
#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
-#define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
-static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
+// CV_CMP_NUM_IDX: orders an array of sample indices by the float keys they index in
+// `aux`, the user-data argument of the generated sort (set_data() passes the
+// ordered values there). NOTE(review): relies on CV_IMPLEMENT_QSORT_EX naming its
+// user-data parameter `aux` -- confirm if that macro changes.
+#define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
+static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
+static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )
+
+// CV_CMP_PAIRS: orders CvPair16u32s entries by the int their `i` member points to
+// (used for categorical values while the work buffer is 16-bit).
+#define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
+static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
- const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
- bool _shared, bool _add_weights )
+ const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
+ bool _shared, bool _add_labels, bool _update_data )
{
- CvMat* sample_idx = 0;
+ CvMat* sample_indices = 0;
CvMat* var_type0 = 0;
CvMat* tmp_map = 0;
int** int_ptr = 0;
+ CvPair16u32s* pair16u32s_ptr = 0;
+ CvDTreeTrainData* data = 0;
+ float *_fdst = 0;
+ int *_idst = 0;
+ unsigned short* udst = 0;
+ int* idst = 0;
CV_FUNCNAME( "CvDTreeTrainData::set_data" );
int total_c_count = 0;
int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
- int vi, i;
+ int vi, i, size;
char err[100];
const int *sidx = 0, *vidx = 0;
+
+ if( _update_data && data_root )
+ {
+ data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
+ _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
+
+ // compare new and old train data
+ if( !(data->var_count == var_count &&
+ cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
+ cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
+ cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
+ CV_ERROR( CV_StsBadArg,
+ "The new training data must have the same types and the input and output variables "
+ "and the same categories for categorical variables" );
+
+ cvReleaseMat( &priors );
+ cvReleaseMat( &priors_mult );
+ cvReleaseMat( &buf );
+ cvReleaseMat( &direction );
+ cvReleaseMat( &split_buf );
+ cvReleaseMemStorage( &temp_storage );
+
+ priors = data->priors; data->priors = 0;
+ priors_mult = data->priors_mult; data->priors_mult = 0;
+ buf = data->buf; data->buf = 0;
+ buf_count = data->buf_count; buf_size = data->buf_size;
+ sample_count = data->sample_count;
+
+ direction = data->direction; data->direction = 0;
+ split_buf = data->split_buf; data->split_buf = 0;
+ temp_storage = data->temp_storage; data->temp_storage = 0;
+ nv_heap = data->nv_heap; cv_heap = data->cv_heap;
+
+ data_root = new_node( 0, sample_count, 0, 0 );
+ EXIT;
+ }
clear();
// check parameter types and sizes
CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
+
+ train_data = _train_data;
+ responses = _responses;
+
if( _tflag == CV_ROW_SAMPLE )
{
ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
if( _missing_mask )
mv_step = _missing_mask->step, ms_step = 1;
}
+ tflag = _tflag;
sample_count = sample_all;
var_count = var_all;
-
+
if( _sample_idx )
{
- CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
- sidx = sample_idx->data.i;
- sample_count = sample_idx->rows + sample_idx->cols - 1;
+ CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
+ sidx = sample_indices->data.i;
+ sample_count = sample_indices->rows + sample_indices->cols - 1;
}
if( _var_idx )
var_count = var_idx->rows + var_idx->cols - 1;
}
+ is_buf_16u = false;
+ if ( sample_count < 65536 )
+ is_buf_16u = true;
+
if( !CV_IS_MAT(_responses) ||
(CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
- _responses->rows != 1 && _responses->cols != 1 ||
+ (_responses->rows != 1 && _responses->cols != 1) ||
_responses->rows + _responses->cols - 1 != sample_all )
CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
"floating-point vector containing as many elements as "
"the total number of samples in the training data matrix" );
+
+
+ CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
- CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
-
+
+
cat_var_count = 0;
ord_var_count = -1;
// in case of single ordered predictor we need dummy cv_labels
// for safe split_node_data() operation
- have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
- have_weights = _add_weights;
+ have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
- buf_size = (ord_var_count*2 + cat_var_count + 1 +
- (have_cv_labels ? 1 : 0) + (have_weights ? 1 : 0))*sample_count + 2;
+ work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
+ + (have_labels ? 1 : 0); // for cv_labels
+
+ buf_size = (work_var_count + 1 /*for sample_indices*/) * sample_count;
shared = _shared;
- buf_count = shared ? 3 : 2;
- CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
- CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
- CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
- CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
+ buf_count = shared ? 2 : 1;
+
+ if ( is_buf_16u )
+ {
+ CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
+ CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
+ }
+ else
+ {
+ CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
+ CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
+ }
+
+ size = is_classifier ? (cat_var_count+1) : cat_var_count;
+ size = !size ? 1 : size;
+ CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
+ CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
+
+ size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
+ size = !size ? 1 : size;
+ CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
// now calculate the maximum size of split,
// create memory storage that will keep nodes and splits of the decision tree
CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
- temp_block_size = nv_size = var_count*sizeof(int);
+ nv_size = var_count*sizeof(int);
+ nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
+
+ temp_block_size = nv_size;
+
if( cv_n )
{
if( sample_count < cv_n*MAX(params.min_sample_count,10) )
CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
- CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
max_c_count = 1;
+ _fdst = 0;
+ _idst = 0;
+ if (ord_var_count)
+ _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
+ if (is_buf_16u && (cat_var_count || is_classifier))
+ _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
+
// transform the training data to convenient representation
for( vi = 0; vi <= var_count; vi++ )
{
fdata = _responses->data.fl;
}
- if( vi < var_count && ci >= 0 ||
- vi == var_count && is_classifier ) // process categorical variable or response
+ if( (vi < var_count && ci>=0) ||
+ (vi == var_count && is_classifier) ) // process categorical variable or response
{
int c_count, prev_label;
- int* c_map, *dst = get_cat_var_data( data_root, vi );
-
+ int* c_map;
+
+ if (is_buf_16u)
+ udst = (unsigned short*)(buf->data.s + vi*sample_count);
+ else
+ idst = buf->data.i + vi*sample_count;
+
// copy data
for( i = 0; i < sample_count; i++ )
{
{
float t = fdata[si*step];
val = cvRound(t);
- if( val != t )
+ if( fabs(t - val) > FLT_EPSILON )
{
sprintf( err, "%d-th value of %d-th (categorical) "
"variable is not an integer", i, vi );
}
num_valid++;
}
- dst[i] = val;
- int_ptr[i] = dst + i;
+ if (is_buf_16u)
+ {
+ _idst[i] = val;
+ pair16u32s_ptr[i].u = udst + i;
+ pair16u32s_ptr[i].i = _idst + i;
+ }
+ else
+ {
+ idst[i] = val;
+ int_ptr[i] = idst + i;
+ }
}
- // sort all the values, including the missing measurements
- // that should all move to the end
- icvSortIntPtr( int_ptr, sample_count, 0 );
- //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
-
c_count = num_valid > 0;
-
- // count the categories
- for( i = 1; i < num_valid; i++ )
- c_count += *int_ptr[i] != *int_ptr[i-1];
+ if (is_buf_16u)
+ {
+ icvSortPairs( pair16u32s_ptr, sample_count, 0 );
+ // count the categories
+ for( i = 1; i < num_valid; i++ )
+ if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
+ c_count ++ ;
+ }
+ else
+ {
+ icvSortIntPtr( int_ptr, sample_count, 0 );
+ // count the categories
+ for( i = 1; i < num_valid; i++ )
+ c_count += *int_ptr[i] != *int_ptr[i-1];
+ }
if( vi > 0 )
max_c_count = MAX( max_c_count, c_count );
c_map = cat_map->data.i + total_c_count;
total_c_count += c_count;
- // compact the class indices and build the map
- prev_label = ~*int_ptr[0];
c_count = -1;
-
- for( i = 0; i < num_valid; i++ )
+ if (is_buf_16u)
{
- int cur_label = *int_ptr[i];
- if( cur_label != prev_label )
- c_map[++c_count] = prev_label = cur_label;
- *int_ptr[i] = c_count;
+ // compact the class indices and build the map
+ prev_label = ~*pair16u32s_ptr[0].i;
+ for( i = 0; i < num_valid; i++ )
+ {
+ int cur_label = *pair16u32s_ptr[i].i;
+ if( cur_label != prev_label )
+ c_map[++c_count] = prev_label = cur_label;
+ *pair16u32s_ptr[i].u = (unsigned short)c_count;
+ }
+ // replace labels for missing values with -1
+ for( ; i < sample_count; i++ )
+ *pair16u32s_ptr[i].u = 65535;
}
-
- // replace labels for missing values with -1
- for( ; i < sample_count; i++ )
- *int_ptr[i] = -1;
+ else
+ {
+ // compact the class indices and build the map
+ prev_label = ~*int_ptr[0];
+ for( i = 0; i < num_valid; i++ )
+ {
+ int cur_label = *int_ptr[i];
+ if( cur_label != prev_label )
+ c_map[++c_count] = prev_label = cur_label;
+ *int_ptr[i] = c_count;
+ }
+ // replace labels for missing values with -1
+ for( ; i < sample_count; i++ )
+ *int_ptr[i] = -1;
+ }
}
else if( ci < 0 ) // process ordered variable
{
- CvPair32s32f* dst = get_ord_var_data( data_root, vi );
+ if (is_buf_16u)
+ udst = (unsigned short*)(buf->data.s + vi*sample_count);
+ else
+ idst = buf->data.i + vi*sample_count;
for( i = 0; i < sample_count; i++ )
{
"variable (=%g) is too large", i, vi, val );
CV_ERROR( CV_StsBadArg, err );
}
- num_valid++;
}
- dst[i].i = i;
- dst[i].val = val;
- }
-
- icvSortPairs( dst, sample_count, 0 );
- }
- else // special case: process ordered response,
- // it will be stored similarly to categorical vars (i.e. no pairs)
- {
- float* dst = get_ord_responses( data_root );
-
- for( i = 0; i < sample_count; i++ )
- {
- float val = ord_nan;
- int si = sidx ? sidx[i] : i;
- if( idata )
- val = (float)idata[si*step];
+ num_valid++;
+ if (is_buf_16u)
+ udst[i] = (unsigned short)i;
else
- val = fdata[si*step];
-
- if( fabs(val) >= ord_nan )
- {
- sprintf( err, "%d-th value of %d-th (ordered) "
- "variable (=%g) is out of range", i, vi, val );
- CV_ERROR( CV_StsBadArg, err );
- }
- dst[i] = val;
+ idst[i] = i;
+ _fdst[i] = val;
+
}
-
- cat_count->data.i[cat_var_count] = 0;
- cat_ofs->data.i[cat_var_count] = total_c_count;
- num_valid = sample_count;
+ if (is_buf_16u)
+ icvSortUShAux( udst, num_valid, _fdst);
+ else
+ icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );
}
-
+
if( vi < var_count )
data_root->set_num_valid(vi, num_valid);
}
+ // set sample labels
+ if (is_buf_16u)
+ udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
+ else
+ idst = buf->data.i + work_var_count*sample_count;
+
+ for (i = 0; i < sample_count; i++)
+ {
+ if (udst)
+ udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
+ else
+ idst[i] = sidx ? sidx[i] : i;
+ }
+
if( cv_n )
{
- int* dst = get_cv_labels(data_root);
+ unsigned short* udst = 0;
+ int* idst = 0;
CvRNG* r = &rng;
- for( i = vi = 0; i < sample_count; i++ )
+ if (is_buf_16u)
{
- dst[i] = vi++;
- vi &= vi < cv_n ? -1 : 0;
- }
+ udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
+ for( i = vi = 0; i < sample_count; i++ )
+ {
+ udst[i] = (unsigned short)vi++;
+ vi &= vi < cv_n ? -1 : 0;
+ }
- for( i = 0; i < sample_count; i++ )
+ for( i = 0; i < sample_count; i++ )
+ {
+ int a = cvRandInt(r) % sample_count;
+ int b = cvRandInt(r) % sample_count;
+ unsigned short unsh = (unsigned short)vi;
+ CV_SWAP( udst[a], udst[b], unsh );
+ }
+ }
+ else
{
- int a = cvRandInt(r) % sample_count;
- int b = cvRandInt(r) % sample_count;
- CV_SWAP( dst[a], dst[b], vi );
+ idst = buf->data.i + (get_work_var_count()-1)*sample_count;
+ for( i = vi = 0; i < sample_count; i++ )
+ {
+ idst[i] = vi++;
+ vi &= vi < cv_n ? -1 : 0;
+ }
+
+ for( i = 0; i < sample_count; i++ )
+ {
+ int a = cvRandInt(r) % sample_count;
+ int b = cvRandInt(r) % sample_count;
+ CV_SWAP( idst[a], idst[b], vi );
+ }
}
}
- cat_map->cols = MAX( total_c_count, 1 );
+ if ( cat_map )
+ cat_map->cols = MAX( total_c_count, 1 );
max_split_size = cvAlign(sizeof(CvDTreeSplit) +
(MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
have_priors = is_classifier && params.priors;
if( is_classifier )
{
- int m = get_num_classes(), rows = 4;
+ int m = get_num_classes();
double sum = 0;
CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
for( i = 0; i < m; i++ )
priors->data.db[i] = val;
sum += val;
}
+
// normalize weights
if( have_priors )
cvScale( priors, priors, 1./sum );
- if( cat_var_count > 0 || params.cv_folds > 0 )
- {
- // need storage for cjk (see find_split_cat_gini) and risks/errors
- rows += MAX( max_c_count, params.cv_folds ) + 1;
- // add buffer for k-means clustering
- if( m > 2 && max_c_count > params.max_categories )
- rows += params.max_categories + (max_c_count+m-1)/m;
- }
-
- CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
+ CV_CALL( priors_mult = cvCloneMat( priors ));
+ CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
}
+
CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
__END__;
+ if( data )
+ delete data;
+
+ if (_fdst)
+ cvFree( &_fdst );
+ if (_idst)
+ cvFree( &_idst );
cvFree( &int_ptr );
- cvReleaseMat( &sample_idx );
+ cvFree( &pair16u32s_ptr);
cvReleaseMat( &var_type0 );
+ cvReleaseMat( &sample_indices );
cvReleaseMat( &tmp_map );
}
+// Makes `responses` point at a private deep copy (held in responses_copy, which is
+// released together with the other temporary buffers). After this call the
+// training data no longer depends on the lifetime of the caller's matrix.
+// Calling it again re-copies from the current copy and frees the previous copy
+// instead of leaking it. It is a no-op when no responses are attached.
+void CvDTreeTrainData::do_responses_copy()
+{
+ if( !responses )
+ return;
+
+ // Copy first: `responses` may itself point at the old responses_copy, so the
+ // old buffer can only be released once its contents have been duplicated.
+ CvMat* copy = cvCreateMat( responses->rows, responses->cols, responses->type );
+ cvCopy( responses, copy );
+ cvReleaseMat( &responses_copy );
+ responses_copy = copy;
+ responses = responses_copy;
+}
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
CvDTreeNode* root = 0;
CvMat* isubsample_idx = 0;
CvMat* subsample_co = 0;
-
+
CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
__BEGIN__;
if( !data_root )
CV_ERROR( CV_StsError, "No training data has been set" );
-
+
if( _subsample_idx )
CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
int* sidx = isubsample_idx->data.i;
// co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
int* co, cur_ofs = 0;
- int vi, i, total = data_root->sample_count;
+ int vi, i;
+ int work_var_count = get_work_var_count();
int count = isubsample_idx->rows + isubsample_idx->cols - 1;
+
root = new_node( 0, count, 1, 0 );
- CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
+ CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
cvZero( subsample_co );
co = subsample_co->data.i;
for( i = 0; i < count; i++ )
co[sidx[i]*2]++;
- for( i = 0; i < total; i++ )
+ for( i = 0; i < sample_count; i++ )
{
if( co[i*2] )
{
co[i*2+1] = -1;
}
- for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
+ cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
+ for( vi = 0; vi < work_var_count; vi++ )
{
int ci = get_var_type(vi);
if( ci >= 0 || vi >= var_count )
{
- const int* src = get_cat_var_data( data_root, vi );
- int* dst = get_cat_var_data( root, vi );
int num_valid = 0;
+ const int* src = get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
- for( i = 0; i < count; i++ )
+ if (is_buf_16u)
+ {
+ unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+ vi*sample_count + root->offset);
+ for( i = 0; i < count; i++ )
+ {
+ int val = src[sidx[i]];
+ udst[i] = (unsigned short)val;
+ num_valid += val >= 0;
+ }
+ }
+ else
{
- int val = src[sidx[i]];
- dst[i] = val;
- num_valid += val >= 0;
+ int* idst = buf->data.i + root->buf_idx*buf->cols +
+ vi*sample_count + root->offset;
+ for( i = 0; i < count; i++ )
+ {
+ int val = src[sidx[i]];
+ idst[i] = val;
+ num_valid += val >= 0;
+ }
}
if( vi < var_count )
}
else
{
- const CvPair32s32f* src = get_ord_var_data( data_root, vi );
- CvPair32s32f* dst = get_ord_var_data( root, vi );
+ int *src_idx_buf = (int*)(uchar*)inn_buf;
+ float *src_val_buf = (float*)(src_idx_buf + sample_count);
+ int* sample_indices_buf = (int*)(src_val_buf + sample_count);
+ const int* src_idx = 0;
+ const float* src_val = 0;
+ get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
int j = 0, idx, count_i;
int num_valid = data_root->get_num_valid(vi);
- for( i = 0; i < num_valid; i++ )
+ if (is_buf_16u)
{
- idx = src[i].i;
- count_i = co[idx*2];
- if( count_i )
+ unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+ vi*sample_count + data_root->offset);
+ for( i = 0; i < num_valid; i++ )
{
- float val = src[i].val;
- for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
- {
- dst[j].val = val;
- dst[j].i = cur_ofs;
- }
+ idx = src_idx[i];
+ count_i = co[idx*2];
+ if( count_i )
+ for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+ udst_idx[j] = (unsigned short)cur_ofs;
}
- }
- root->set_num_valid(vi, j);
+ root->set_num_valid(vi, j);
- for( ; i < total; i++ )
+ for( ; i < sample_count; i++ )
+ {
+ idx = src_idx[i];
+ count_i = co[idx*2];
+ if( count_i )
+ for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+ udst_idx[j] = (unsigned short)cur_ofs;
+ }
+ }
+ else
{
- idx = src[i].i;
- count_i = co[idx*2];
- if( count_i )
+ int* idst_idx = buf->data.i + root->buf_idx*buf->cols +
+ vi*sample_count + root->offset;
+ for( i = 0; i < num_valid; i++ )
{
- float val = src[i].val;
- for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
- {
- dst[j].val = val;
- dst[j].i = cur_ofs;
- }
+ idx = src_idx[i];
+ count_i = co[idx*2];
+ if( count_i )
+ for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+ idst_idx[j] = cur_ofs;
+ }
+
+ root->set_num_valid(vi, j);
+
+ for( ; i < sample_count; i++ )
+ {
+ idx = src_idx[i];
+ count_i = co[idx*2];
+ if( count_i )
+ for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
+ idst_idx[j] = cur_ofs;
}
}
}
}
+ // sample indices subsampling
+ const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
+ if (is_buf_16u)
+ {
+ unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols +
+ get_work_var_count()*sample_count + root->offset);
+ for (i = 0; i < count; i++)
+ sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
+ }
+ else
+ {
+ int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols +
+ get_work_var_count()*sample_count + root->offset;
+ for (i = 0; i < count; i++)
+ sample_idx_dst[i] = sample_idx_src[sidx[i]];
+ }
}
__END__;
{
CvMat* subsample_idx = 0;
CvMat* subsample_co = 0;
-
+
CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
__BEGIN__;
int* sidx = 0;
int* co = 0;
+ cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
if( _subsample_idx )
{
CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
}
}
- memset( missing, 1, count*var_count );
+ if( missing )
+ memset( missing, 1, count*var_count );
for( vi = 0; vi < var_count; vi++ )
{
if( ci >= 0 ) // categorical
{
float* dst = values + vi;
- uchar* m = missing + vi;
- const int* src = get_cat_var_data(data_root, vi);
+ uchar* m = missing ? missing + vi : 0;
+ const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);
- for( i = 0; i < count; i++, dst += var_count, m += var_count )
+ for( i = 0; i < count; i++, dst += var_count )
{
int idx = sidx ? sidx[i] : i;
int val = src[idx];
*dst = (float)val;
- *m = val < 0;
+ if( m )
+ {
+ *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
+ m += var_count;
+ }
}
}
else // ordered
{
float* dst = values + vi;
- uchar* m = missing + vi;
- const CvPair32s32f* src = get_ord_var_data(data_root, vi);
+ uchar* m = missing ? missing + vi : 0;
int count1 = data_root->get_num_valid(vi);
+ float *src_val_buf = (float*)(uchar*)inn_buf;
+ int* src_idx_buf = (int*)(src_val_buf + sample_count);
+ int* sample_indices_buf = src_idx_buf + sample_count;
+ const float *src_val = 0;
+ const int* src_idx = 0;
+ get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
for( i = 0; i < count1; i++ )
{
- int idx = src[i].i;
+ int idx = src_idx[i];
int count_i = 1;
if( co )
{
cur_ofs = idx*var_count;
if( count_i )
{
- float val = src[i].val;
+ float val = src_val[i];
for( ; count_i > 0; count_i--, cur_ofs += var_count )
{
dst[cur_ofs] = val;
- m[cur_ofs] = 0;
+ if( m )
+ m[cur_ofs] = 0;
}
}
}
}
// copy responses
- if( is_classifier )
+ if( responses )
{
- const int* src = get_class_labels(data_root);
- for( i = 0; i < count; i++ )
+ if( is_classifier )
{
- int idx = sidx ? sidx[i] : i;
- int val = get_class_idx ? src[idx] :
- cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
- responses[i] = (float)val;
+ const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
+ for( i = 0; i < count; i++ )
+ {
+ int idx = sidx ? sidx[i] : i;
+ int val = get_class_idx ? src[idx] :
+ cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
+ responses[i] = (float)val;
+ }
}
- }
- else
- {
- const float* src = get_ord_responses(data_root);
- for( i = 0; i < count; i++ )
+ else
{
- int idx = sidx ? sidx[i] : i;
- responses[i] = src[idx];
+ float* val_buf = (float*)(uchar*)inn_buf;
+ int* sample_idx_buf = (int*)(val_buf + sample_count);
+ const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
+ for( i = 0; i < count; i++ )
+ {
+ int idx = sidx ? sidx[i] : i;
+ responses[i] = _values[idx];
+ }
}
}
{
CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
split->var_idx = vi;
+ split->condensed_idx = INT_MIN;
split->ord.c = cmp_val;
split->ord.split_point = split_point;
split->inversed = inversed;
int i, n = (max_c_count + 31)/32;
split->var_idx = vi;
+ split->condensed_idx = INT_MIN;
split->inversed = 0;
split->quality = quality;
for( i = 0; i < n; i++ )
cvReleaseMat( &direction );
cvReleaseMat( &split_buf );
cvReleaseMemStorage( &temp_storage );
+ cvReleaseMat( &responses_copy );
cv_heap = nv_heap = 0;
}
cvReleaseMat( &cat_ofs );
cvReleaseMat( &cat_map );
cvReleaseMat( &priors );
-
+ cvReleaseMat( &priors_mult );
+
node_heap = split_heap = 0;
sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
- have_cv_labels = have_priors = is_classifier = false;
+ have_labels = have_priors = is_classifier = false;
buf_count = buf_size = 0;
shared = false;
-
+
data_root = 0;
rng = cvRNG(-1);
return var_type->data.i[vi];
}
-
-CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
+// Fetches the values of ordered variable `vi` for the samples of node `n`, in the
+// node's sorted order for that variable.
+// On return:
+// - *sorted_indices points to the node-relative sorted sample positions. For
+//   32-bit storage it points directly into `buf`; for 16-bit storage they are
+//   widened into sorted_indices_buf.
+// - *ord_values points to ord_values_buf, filled with the matching raw values read
+//   from train_data.
+// Filling stops at the first missing-value marker (a negative index, or 65535 in
+// 16-bit mode), so only the leading valid entries of ord_values_buf are written.
+// All three buffers are caller-provided scratch space, assumed to hold at least
+// n->sample_count elements (TODO confirm against all callers).
+// NOTE(review): train_data is read through data.fl, i.e. assumed to be CV_32FC1.
+void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
+ const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
{
- int oi = ~get_var_type(vi);
- assert( 0 <= oi && oi < ord_var_count );
- return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
- n->offset + oi*n->sample_count*2);
+ // vidx: position of this variable inside the original train_data matrix
+ int vidx = var_idx ? var_idx->data.i[vi] : vi;
+ int node_sample_count = n->sample_count;
+ int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
+
+ // maps a node-relative position to the sample's index in train_data
+ const int* sample_indices = get_sample_indices(n, sample_indices_buf);
+
+ if( !is_buf_16u )
+ *sorted_indices = buf->data.i + n->buf_idx*buf->cols +
+ vi*sample_count + n->offset;
+ else {
+ const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
+ vi*sample_count + n->offset );
+ for( int i = 0; i < node_sample_count; i++ )
+ sorted_indices_buf[i] = short_indices[i];
+ *sorted_indices = sorted_indices_buf;
+ }
+
+ // CV_ROW_SAMPLE: one sample per row of train_data; otherwise one sample per column.
+ // Both loops stop at the first missing-value marker.
+ if( tflag == CV_ROW_SAMPLE )
+ {
+ for( int i = 0; i < node_sample_count &&
+ ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
+ {
+ int idx = (*sorted_indices)[i];
+ idx = sample_indices[idx];
+ ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
+ }
+ }
+ else
+ for( int i = 0; i < node_sample_count &&
+ ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
+ {
+ int idx = (*sorted_indices)[i];
+ idx = sample_indices[idx];
+ ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
+ }
+
+ *ord_values = ord_values_buf;
}
-int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
+// Returns the compacted class index (an offset into this response's cat_map
+// section) of every sample in node `n`. These are stored in work row var_count.
+// Returns 0 for regression problems.
+// labels_buf is scratch space used when the work buffer is 16-bit.
+const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
{
- return get_cat_var_data( n, var_count );
+ if (is_classifier)
+ return get_cat_var_data( n, var_count, labels_buf);
+ return 0;
}
-
-float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
+// Returns, for each sample of node `n`, its index in the original training set.
+// These are stored in work row get_work_var_count(), which set_data() fills from
+// _sample_idx (or 0..sample_count-1).
+// indices_buf is scratch space used when the work buffer is 16-bit.
+const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
{
- return (float*)get_cat_var_data( n, var_count );
+ return get_cat_var_data( n, get_work_var_count(), indices_buf );
}
-
-int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
+// Gathers the ordered (regression) response of every sample in node `n` into
+// values_buf, in node order, and returns values_buf.
+// values_buf and sample_indices_buf must each hold at least n->sample_count
+// elements. sample_indices_buf is scratch space used to unpack 16-bit indices.
+// Responses are read straight from the `responses` matrix. set_data() accepts it
+// as either CV_32SC1 or CV_32FC1, so integer responses are converted to float
+// rather than having their bits reinterpreted.
+const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf )
{
- return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
+ // named node_sample_count so it does not shadow the sample_count member
+ int node_sample_count = n->sample_count;
+ int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
+ bool int_responses = CV_MAT_TYPE(responses->type) == CV_32SC1;
+ const int* indices = get_sample_indices(n, sample_indices_buf);
+
+ // stop at the first missing-index marker (negative for 32-bit buffers, 65535 for 16-bit ones)
+ for( int i = 0; i < node_sample_count &&
+ (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
+ {
+ int idx = indices[i];
+ values_buf[i] = int_responses ? (float)responses->data.i[idx * r_step] :
+ responses->data.fl[idx * r_step];
+ }
+
+ return values_buf;
}
-int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
+// Returns the cross-validation fold label of every sample in node `n`. set_data()
+// stores these in work row get_work_var_count()-1 as shuffled values 0..cv_n-1.
+// Returns 0 when no label row was allocated (have_labels is false).
+// labels_buf is scratch space used when the work buffer is 16-bit.
+const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
{
- int ci = get_var_type(vi);
- assert( 0 <= ci && ci <= cat_var_count + 1 );
- return buf->data.i + n->buf_idx*buf->cols + n->offset +
- (ord_var_count*2 + ci)*n->sample_count;
+ if (have_labels)
+ return get_cat_var_data( n, get_work_var_count()- 1, labels_buf);
+ return 0;
}
-float* CvDTreeTrainData::get_weights( CvDTreeNode* n )
+// Returns the stored integer values of work row `vi` for the samples of node `n`.
+// The row may be a categorical variable, the class-label row, the cv-label row or
+// the sample-index row.
+// With 32-bit storage the result points directly into `buf`. With 16-bit storage
+// the values are widened into cat_values_buf (which must hold n->sample_count
+// ints) and that buffer is returned. In 16-bit mode a missing categorical value
+// therefore reads back as 65535 rather than a negative number.
+const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
{
- return have_weights ?
- (float*)get_cat_var_data( n, var_count + 1 + (params.cv_folds > 0) ) : 0;
+ const int* cat_values = 0;
+ if( !is_buf_16u )
+ cat_values = buf->data.i + n->buf_idx*buf->cols +
+ vi*sample_count + n->offset;
+ else {
+ const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols +
+ vi*sample_count + n->offset);
+ for( int i = 0; i < n->sample_count; i++ )
+ cat_values_buf[i] = short_values[i];
+ cat_values = cat_values_buf;
+ }
+ return cat_values;
}
}
-/////////////////////// Decision Tree /////////////////////////
-
-CvDTree::CvDTree()
+// Serialises the training-data description to `fs`:
+// - the problem type and variable counts;
+// - a "training_params" map with the CvDTreeParams fields, plus priors when present;
+// - the optional var_idx;
+// - a per-variable categorical/ordered flag list ("var_type");
+// - cat_count/cat_map when categorical variables or class labels exist.
+// The training samples themselves are not written.
+void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
{
- data = 0;
- var_importance = 0;
- default_model_name = "my_tree";
+ CV_FUNCNAME( "CvDTreeTrainData::write_params" );
- clear();
-}
+ __BEGIN__;
+ int vi, vcount = var_count;
-void CvDTree::clear()
-{
- cvReleaseMat( &var_importance );
- if( data )
+ cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
+ cvWriteInt( fs, "var_all", var_all );
+ cvWriteInt( fs, "var_count", var_count );
+ cvWriteInt( fs, "ord_var_count", ord_var_count );
+ cvWriteInt( fs, "cat_var_count", cat_var_count );
+
+ cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
+ cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
+
+ if( is_classifier )
{
- if( !data->shared )
- delete data;
- else
- free_tree();
- data = 0;
+ cvWriteInt( fs, "max_categories", params.max_categories );
+ }
+ else
+ {
+ cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
}
- root = 0;
- pruned_tree_idx = -1;
-}
+ cvWriteInt( fs, "max_depth", params.max_depth );
+ cvWriteInt( fs, "min_sample_count", params.min_sample_count );
+ cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
-CvDTree::~CvDTree()
-{
- clear();
-}
+ // pruning options only matter when cross-validation is enabled
+ if( params.cv_folds > 1 )
+ {
+ cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
+ cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
+ }
+ if( priors )
+ cvWrite( fs, "priors", priors );
-const CvDTreeNode* CvDTree::get_root() const
-{
- return root;
-}
+ cvEndWriteStruct( fs );
+ if( var_idx )
+ cvWrite( fs, "var_idx", var_idx );
-int CvDTree::get_pruned_tree_idx() const
-{
- return pruned_tree_idx;
-}
+ // one flag per used variable: 1 = categorical (var_type >= 0), 0 = ordered
+ cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
+ for( vi = 0; vi < vcount; vi++ )
+ cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
-CvDTreeTrainData* CvDTree::get_data()
-{
- return data;
+ cvEndWriteStruct( fs );
+
+ if( cat_count && (cat_var_count > 0 || is_classifier) )
+ {
+ // NOTE(review): this assert is redundant with the enclosing cat_count check
+ CV_ASSERT( cat_count != 0 );
+ cvWrite( fs, "cat_count", cat_count );
+ cvWrite( fs, "cat_map", cat_map );
+ }
+
+ __END__;
}
-bool CvDTree::train( const CvMat* _train_data, int _tflag,
- const CvMat* _responses, const CvMat* _var_idx,
- const CvMat* _sample_idx, const CvMat* _var_type,
- const CvMat* _missing_mask, CvDTreeParams _params )
+void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
- bool result = false;
-
- CV_FUNCNAME( "CvDTree::train" );
+ CV_FUNCNAME( "CvDTreeTrainData::read_params" );
__BEGIN__;
- clear();
- data = new CvDTreeTrainData( _train_data, _tflag, _responses,
- _var_idx, _sample_idx, _var_type,
- _missing_mask, _params, false );
- CV_CALL( result = do_train(0));
+ CvFileNode *tparams_node, *vartype_node;
+ CvSeqReader reader;
+ int vi, max_split_size, tree_block_size;
- __END__;
+ is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
+ var_all = cvReadIntByName( fs, node, "var_all" );
+ var_count = cvReadIntByName( fs, node, "var_count", var_all );
+ cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
+ ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
- return result;
-}
+ tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
+
+ if( tparams_node ) // training parameters are not necessary
+ {
+ params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
+
+ if( is_classifier )
+ {
+ params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
+ }
+ else
+ {
+ params.regression_accuracy =
+ (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
+ }
+
+ params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
+ params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
+ params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
+
+ if( params.cv_folds > 1 )
+ {
+ params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
+ params.truncate_pruned_tree =
+ cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
+ }
+
+ priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
+ if( priors )
+ {
+ if( !CV_IS_MAT(priors) )
+ CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
+ priors_mult = cvCloneMat( priors );
+ }
+ }
+
+ CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
+ if( var_idx )
+ {
+ if( !CV_IS_MAT(var_idx) ||
+ (var_idx->cols != 1 && var_idx->rows != 1) ||
+ var_idx->cols + var_idx->rows - 1 != var_count ||
+ CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
+ CV_ERROR( CV_StsParseError,
+ "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
+
+ for( vi = 0; vi < var_count; vi++ )
+ if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
+ CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
+ }
+
+ ////// read var type
+ CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
+
+ cat_var_count = 0;
+ ord_var_count = -1;
+ vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
+
+ if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
+ var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
+ else
+ {
+ if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
+ vartype_node->data.seq->total != var_count )
+ CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
+
+ cvStartReadSeq( vartype_node->data.seq, &reader );
+
+ for( vi = 0; vi < var_count; vi++ )
+ {
+ CvFileNode* n = (CvFileNode*)reader.ptr;
+ if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
+ CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
+ var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
+ CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
+ }
+ }
+ var_type->data.i[var_count] = cat_var_count;
+
+ ord_var_count = ~ord_var_count;
+ if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
+ CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
+ //////
+
+ if( cat_var_count > 0 || is_classifier )
+ {
+ int ccount, total_c_count = 0;
+ CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
+ CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
+
+ if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
+ (cat_count->cols != 1 && cat_count->rows != 1) ||
+ CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
+ cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
+ (cat_map->cols != 1 && cat_map->rows != 1) ||
+ CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
+ CV_ERROR( CV_StsParseError,
+ "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
+
+ ccount = cat_var_count + is_classifier;
+
+ CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
+ cat_ofs->data.i[0] = 0;
+ max_c_count = 1;
+
+ for( vi = 0; vi < ccount; vi++ )
+ {
+ int val = cat_count->data.i[vi];
+ if( val <= 0 )
+ CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
+ max_c_count = MAX( max_c_count, val );
+ cat_ofs->data.i[vi+1] = total_c_count += val;
+ }
+
+ if( cat_map->cols + cat_map->rows - 1 != total_c_count )
+ CV_ERROR( CV_StsBadSize,
+ "cat_map vector length is not equal to the total number of categories in all categorical vars" );
+ }
+
+ max_split_size = cvAlign(sizeof(CvDTreeSplit) +
+ (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
+
+ tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
+ tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
+ CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
+ CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
+ sizeof(CvDTreeNode), tree_storage ));
+ CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
+ max_split_size, tree_storage ));
+
+ __END__;
+}
+
+/////////////////////// Decision Tree /////////////////////////
+
+CvDTree::CvDTree()
+{
+ data = 0;
+ var_importance = 0;
+ default_model_name = "my_tree";
+
+ clear();
+}
+
+
+void CvDTree::clear()
+{
+ cvReleaseMat( &var_importance );
+ if( data )
+ {
+ if( !data->shared )
+ delete data;
+ else
+ free_tree();
+ data = 0;
+ }
+ root = 0;
+ pruned_tree_idx = -1;
+}
+
+
+CvDTree::~CvDTree()
+{
+ clear();
+}
+
+
+const CvDTreeNode* CvDTree::get_root() const
+{
+ return root;
+}
+
+
+int CvDTree::get_pruned_tree_idx() const
+{
+ return pruned_tree_idx;
+}
+
+
+CvDTreeTrainData* CvDTree::get_data()
+{
+ return data;
+}
+bool CvDTree::train( const CvMat* _train_data, int _tflag,
+ const CvMat* _responses, const CvMat* _var_idx,
+ const CvMat* _sample_idx, const CvMat* _var_type,
+ const CvMat* _missing_mask, CvDTreeParams _params )
+{
+ bool result = false;
+
+ CV_FUNCNAME( "CvDTree::train" );
+
+ __BEGIN__;
+
+ clear();
+ data = new CvDTreeTrainData( _train_data, _tflag, _responses,
+ _var_idx, _sample_idx, _var_type,
+ _missing_mask, _params, false );
+ CV_CALL( result = do_train(0) );
+
+ __END__;
+
+ return result;
+}
+
+bool CvDTree::train( const Mat& _train_data, int _tflag,
+ const Mat& _responses, const Mat& _var_idx,
+ const Mat& _sample_idx, const Mat& _var_type,
+ const Mat& _missing_mask, CvDTreeParams _params )
+{
+ CvMat tdata = _train_data, responses = _responses, vidx=_var_idx,
+ sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;
+ return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
+ vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
+}
+
+
+bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
+{
+ bool result = false;
+
+ CV_FUNCNAME( "CvDTree::train" );
+
+ __BEGIN__;
+
+ const CvMat* values = _data->get_values();
+ const CvMat* response = _data->get_responses();
+ const CvMat* missing = _data->get_missing();
+ const CvMat* var_types = _data->get_var_types();
+ const CvMat* train_sidx = _data->get_train_sample_idx();
+ const CvMat* var_idx = _data->get_var_idx();
+
+ CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
+ train_sidx, var_types, missing, _params ) );
+
+ __END__;
+
+ return result;
+}
+
bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
bool result = false;
root = data->subsample_data( _subsample_idx );
CV_CALL( try_split_node(root));
-
+
if( data->params.cv_folds > 0 )
CV_CALL( prune_cv());
}
-#define DTREE_CAT_DIR(idx,subset) \
- (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
-
void CvDTree::try_split_node( CvDTreeNode* node )
{
CvDTreeSplit* best_split = 0;
}
else if( can_split )
{
- const float* responses = data->get_ord_responses( node );
- float diff = responses[n-1] - responses[0];
- if( diff < data->params.regression_accuracy )
+ if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
can_split = false;
}
// TODO: check the split quality ...
node->split = best_split;
}
-
if( !can_split || !best_split )
{
data->free_node_data(node);
}
quality_scale = calc_node_dir( node );
-
if( data->params.use_surrogates )
{
// find all the surrogate splits
}
}
}
-
split_node_data( node );
try_split_node( node->left );
try_split_node( node->right );
// the function returns scale coefficients for surrogate split quality factors.
// the scale is applied to normalize surrogate split quality relatively to the
// best (primary) split quality. That is, if a surrogate split is absolutely
-// identical to the primary split, its quality will be set to the maximum value =
+// identical to the primary split, its quality will be set to the maximum value =
// quality of the primary split; otherwise, it will be lower.
// besides, the function compute node->maxlr,
// minimum possible quality (w/o considering the above mentioned scale)
if( data->get_var_type(vi) >= 0 ) // split on categorical var
{
- const int* labels = data->get_cat_var_data(node,vi);
+ cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
+ int* labels_buf = (int*)inn_buf;
+ const int* labels = data->get_cat_var_data( node, vi, labels_buf );
const int* subset = node->split->subset;
-
if( !data->have_priors )
{
int sum = 0, sum_abs = 0;
for( i = 0; i < n; i++ )
{
int idx = labels[i];
- int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
+ int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
+ CV_DTREE_CAT_DIR(idx,subset) : 0;
sum += d; sum_abs += d & 1;
dir[i] = (char)d;
}
}
else
{
- const int* responses = data->get_class_labels(node);
- const double* priors = data->priors->data.db;
+ const double* priors = data->priors_mult->data.db;
double sum = 0, sum_abs = 0;
+ int* responses_buf = labels_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
for( i = 0; i < n; i++ )
{
int idx = labels[i];
double w = priors[responses[i]];
- int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
+ int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
sum += d*w; sum_abs += (d & 1)*w;
dir[i] = (char)d;
}
}
else // split on ordered var
{
- const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
int split_point = node->split->ord.split_point;
int n1 = node->get_num_valid(vi);
-
+ cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
+ float* val_buf = (float*)(uchar*)inn_buf;
+ int* sorted_buf = (int*)(val_buf + n);
+ int* sample_idx_buf = sorted_buf + n;
+ const float* val = 0;
+ const int* sorted = 0;
+ data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
+
assert( 0 <= split_point && split_point < n1-1 );
if( !data->have_priors )
{
for( i = 0; i <= split_point; i++ )
- dir[sorted[i].i] = (char)-1;
+ dir[sorted[i]] = (char)-1;
for( ; i < n1; i++ )
- dir[sorted[i].i] = (char)1;
+ dir[sorted[i]] = (char)1;
for( ; i < n; i++ )
- dir[sorted[i].i] = (char)0;
+ dir[sorted[i]] = (char)0;
L = split_point-1;
R = n1 - split_point + 1;
}
else
{
- const int* responses = data->get_class_labels(node);
- const double* priors = data->priors->data.db;
+ const double* priors = data->priors_mult->data.db;
+ int* responses_buf = sample_idx_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
L = R = 0;
for( i = 0; i <= split_point; i++ )
{
- int idx = sorted[i].i;
+ int idx = sorted[i];
double w = priors[responses[idx]];
dir[idx] = (char)-1;
L += w;
for( ; i < n1; i++ )
{
- int idx = sorted[i].i;
+ int idx = sorted[i];
double w = priors[responses[idx]];
dir[idx] = (char)1;
R += w;
}
for( ; i < n; i++ )
- dir[sorted[i].i] = (char)0;
+ dir[sorted[i]] = (char)0;
}
}
-
node->maxlr = MAX( L, R );
return node->split->quality/(L + R);
}
-CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
+namespace cv
{
- int vi;
- CvDTreeSplit *best_split = 0, *split = 0, *t;
- for( vi = 0; vi < data->var_count; vi++ )
+DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
+{
+ tree = _tree;
+ node = _node;
+ splitSize = tree->get_data()->split_heap->elem_size;
+
+ bestSplit = (CvDTreeSplit*)(new char[splitSize]);
+ memset((CvDTreeSplit*)bestSplit, 0, splitSize);
+ bestSplit->quality = -1;
+ bestSplit->condensed_idx = INT_MIN;
+ split = (CvDTreeSplit*)(new char[splitSize]);
+ memset((CvDTreeSplit*)split, 0, splitSize);
+ //haveSplit = false;
+}
+
+DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
+{
+ tree = finder.tree;
+ node = finder.node;
+ splitSize = tree->get_data()->split_heap->elem_size;
+
+ bestSplit = (CvDTreeSplit*)(new char[splitSize]);
+ memcpy((CvDTreeSplit*)(bestSplit), (const CvDTreeSplit*)finder.bestSplit, splitSize);
+ split = (CvDTreeSplit*)(new char[splitSize]);
+ memset((CvDTreeSplit*)split, 0, splitSize);
+}
+
+void DTreeBestSplitFinder::operator()(const BlockedRange& range)
+{
+ int vi, vi1 = range.begin(), vi2 = range.end();
+ int n = node->sample_count;
+ CvDTreeTrainData* data = tree->get_data();
+ AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
+
+ for( vi = vi1; vi < vi2; vi++ )
{
+ CvDTreeSplit *res;
int ci = data->get_var_type(vi);
if( node->get_num_valid(vi) <= 1 )
continue;
if( data->is_classifier )
{
if( ci >= 0 )
- split = find_split_cat_class( node, vi );
+ res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
else
- split = find_split_ord_class( node, vi );
+ res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
}
else
{
if( ci >= 0 )
- split = find_split_cat_reg( node, vi );
+ res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
else
- split = find_split_ord_reg( node, vi );
+ res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
}
- if( split )
- {
- if( !best_split || best_split->quality < split->quality )
- CV_SWAP( best_split, split, t );
- if( split )
- cvSetRemoveByPtr( data->split_heap, split );
- }
+ if( res && bestSplit->quality < split->quality )
+ memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );
}
+}
- return best_split;
+void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
+{
+ if( bestSplit->quality < rhs.bestSplit->quality )
+ memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)rhs.bestSplit, splitSize );
}
+}
+
+
+CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
+{
+ DTreeBestSplitFinder finder( this, node );
+
+ cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
+ CvDTreeSplit *bestSplit = data->new_split_cat( 0, -1.0f );
+ memcpy( bestSplit, finder.bestSplit, finder.splitSize );
+
+ return bestSplit;
+}
-CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
+ float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
const float epsilon = FLT_EPSILON*2;
- const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
- const int* responses = data->get_class_labels(node);
int n = node->sample_count;
int n1 = node->get_num_valid(vi);
int m = data->get_num_classes();
+
+ int base_size = 2*m*sizeof(int);
+ cv::AutoBuffer<uchar> inn_buf(base_size);
+ if( !_ext_buf )
+ inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
+ uchar* base_buf = (uchar*)inn_buf;
+ uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
+ float* values_buf = (float*)ext_buf;
+ int* sorted_indices_buf = (int*)(values_buf + n);
+ int* sample_indices_buf = sorted_indices_buf + n;
+ const float* values = 0;
+ const int* sorted_indices = 0;
+ data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
+ &sorted_indices, sample_indices_buf );
+ int* responses_buf = sample_indices_buf + n;
+ const int* responses = data->get_class_labels( node, responses_buf );
+
const int* rc0 = data->counts->data.i;
- int* lc = (int*)(rc0 + m);
+ int* lc = (int*)base_buf;
int* rc = lc + m;
int i, best_i = -1;
- double lsum2 = 0, rsum2 = 0, best_val = 0;
- const double* priors = data->have_priors ? data->priors->data.db : 0;
+ double lsum2 = 0, rsum2 = 0, best_val = init_quality;
+ const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
// init arrays of class instance counters on both sides of the split
for( i = 0; i < m; i++ )
// compensate for missing values
for( i = n1; i < n; i++ )
- rc[responses[sorted[i].i]]--;
+ {
+ rc[responses[sorted_indices[i]]]--;
+ }
if( !priors )
{
for( i = 0; i < n1 - 1; i++ )
{
- int idx = responses[sorted[i].i];
+ int idx = responses[sorted_indices[i]];
int lv, rv;
L++; R--;
lv = lc[idx]; rv = rc[idx];
rsum2 -= rv*2 - 1;
lc[idx] = lv + 1; rc[idx] = rv - 1;
- if( sorted[i].val + epsilon < sorted[i+1].val )
+ if( values[i] + epsilon < values[i+1] )
{
- double val = lsum2/L + rsum2/R;
+ double val = (lsum2*R + rsum2*L)/((double)L*R);
if( best_val < val )
{
best_val = val;
for( i = 0; i < n1 - 1; i++ )
{
- int idx = responses[sorted[i].i];
+ int idx = responses[sorted_indices[i]];
int lv, rv;
double p = priors[idx], p2 = p*p;
L += p; R -= p;
rsum2 -= p2*(rv*2 - 1);
lc[idx] = lv + 1; rc[idx] = rv - 1;
- if( sorted[i].val + epsilon < sorted[i+1].val )
+ if( values[i] + epsilon < values[i+1] )
{
- double val = lsum2/L + rsum2/R;
+ double val = (lsum2*R + rsum2*L)/((double)L*R);
if( best_val < val )
{
best_val = val;
}
}
- return best_i >= 0 ? data->new_split_ord( vi,
- (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
- 0, (float)best_val ) : 0;
+ CvDTreeSplit* split = 0;
+ if( best_i >= 0 )
+ {
+ split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
+ split->var_idx = vi;
+ split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
+ split->ord.split_point = best_i;
+ split->inversed = 0;
+ split->quality = (float)best_val;
+ }
+ return split;
}
int iters = 0, max_iters = 100;
int i, j, idx;
double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
- double *v_weights = buf, *c_weights = buf + k;
+ double *v_weights = buf, *c_weights = buf + n;
bool modified = true;
CvRNG* r = &data->rng;
// assign labels randomly
- for( i = idx = 0; i < n; i++ )
+ for( i = 0; i < n; i++ )
{
int sum = 0;
const int* v = vectors + i*m;
- labels[i] = idx++;
- idx &= idx < k ? -1 : 0;
+ labels[i] = i < k ? i : (cvRandInt(r) % k);
// compute weight of each vector
for( j = 0; j < m; j++ )
}
-CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
+ CvDTreeSplit* _split, uchar* _ext_buf )
{
- CvDTreeSplit* split;
- const int* labels = data->get_cat_var_data(node, vi);
- const int* responses = data->get_class_labels(node);
int ci = data->get_var_type(vi);
int n = node->sample_count;
int m = data->get_num_classes();
int _mi = data->cat_count->data.i[ci], mi = _mi;
- const int* rc0 = data->counts->data.i;
- int* lc = (int*)(rc0 + m);
+
+ int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
+ if( m > 2 && mi > data->params.max_categories )
+ base_size += (m*min(data->params.max_categories, n) + mi)*sizeof(int);
+ else
+ base_size += mi*sizeof(int*);
+ cv::AutoBuffer<uchar> inn_buf(base_size);
+ if( !_ext_buf )
+ inn_buf.allocate(base_size + 2*n*sizeof(int));
+ uchar* base_buf = (uchar*)inn_buf;
+ uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
+
+ int* lc = (int*)base_buf;
int* rc = lc + m;
int* _cjk = rc + m*2, *cjk = _cjk;
- double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
+ double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
+
+ int* labels_buf = (int*)ext_buf;
+ const int* labels = data->get_cat_var_data(node, vi, labels_buf);
+ int* responses_buf = labels_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
+
int* cluster_labels = 0;
int** int_ptr = 0;
int i, j, k, idx;
double L = 0, R = 0;
- double best_val = 0;
+ double best_val = init_quality;
int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
- const double* priors = data->priors->data.db;
+ const double* priors = data->priors_mult->data.db;
// init array of counters:
// c_{jk} - number of samples that have vi-th input variable = j and response = k.
for( i = 0; i < n; i++ )
{
- j = labels[i];
- k = responses[i];
- cjk[j*m + k]++;
+ j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
+ k = responses[i];
+ cjk[j*m + k]++;
}
if( m > 2 )
if( mi > data->params.max_categories )
{
mi = MIN(data->params.max_categories, n);
- cjk += _mi*m;
- cluster_labels = cjk + mi*m;
+ cjk = (int*)(c_weights + _mi);
+ cluster_labels = cjk + m*mi;
cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
}
subset_i = 1;
else
{
assert( m == 2 );
- int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
+ int_ptr = (int**)(c_weights + _mi);
for( j = 0; j < mi; j++ )
int_ptr[j] = cjk + j*2 + 1;
icvSortIntPtr( int_ptr, mi, 0 );
if( L > FLT_EPSILON && R > FLT_EPSILON )
{
- double val = lsum2/L + rsum2/R;
+ double val = (lsum2*R + rsum2*L)/((double)L*R);
if( best_val < val )
{
best_val = val;
}
}
- if( best_subset < 0 )
- return 0;
-
- split = data->new_split_cat( vi, (float)best_val );
-
- if( m == 2 )
+ CvDTreeSplit* split = 0;
+ if( best_subset >= 0 )
{
- for( i = 0; i <= best_subset; i++ )
+ split = _split ? _split : data->new_split_cat( 0, -1.0f );
+ split->var_idx = vi;
+ split->quality = (float)best_val;
+ memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
+ if( m == 2 )
{
- idx = (int)(int_ptr[i] - cjk) >> 1;
- split->subset[idx >> 5] |= 1 << (idx & 31);
+ for( i = 0; i <= best_subset; i++ )
+ {
+ idx = (int)(int_ptr[i] - cjk) >> 1;
+ split->subset[idx >> 5] |= 1 << (idx & 31);
+ }
}
- }
- else
- {
- for( i = 0; i < _mi; i++ )
+ else
{
- idx = cluster_labels ? cluster_labels[i] : i;
- if( best_subset & (1 << idx) )
- split->subset[i >> 5] |= 1 << (i & 31);
+ for( i = 0; i < _mi; i++ )
+ {
+ idx = cluster_labels ? cluster_labels[i] : i;
+ if( best_subset & (1 << idx) )
+ split->subset[i >> 5] |= 1 << (i & 31);
+ }
}
}
-
return split;
}
-CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
const float epsilon = FLT_EPSILON*2;
- const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
- const float* responses = data->get_ord_responses(node);
int n = node->sample_count;
int n1 = node->get_num_valid(vi);
+
+ cv::AutoBuffer<uchar> inn_buf;
+ if( !_ext_buf )
+ inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
+ uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
+ float* values_buf = (float*)ext_buf;
+ int* sorted_indices_buf = (int*)(values_buf + n);
+ int* sample_indices_buf = sorted_indices_buf + n;
+ const float* values = 0;
+ const int* sorted_indices = 0;
+ data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
+ float* responses_buf = (float*)(sample_indices_buf + n);
+ const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
+
int i, best_i = -1;
- double best_val = 0, lsum = 0, rsum = node->value*n;
+ double best_val = init_quality, lsum = 0, rsum = node->value*n;
int L = 0, R = n1;
// compensate for missing values
for( i = n1; i < n; i++ )
- rsum -= responses[sorted[i].i];
+ rsum -= responses[sorted_indices[i]];
// find the optimal split
for( i = 0; i < n1 - 1; i++ )
{
- float t = responses[sorted[i].i];
+ float t = responses[sorted_indices[i]];
L++; R--;
lsum += t;
rsum -= t;
- if( sorted[i].val + epsilon < sorted[i+1].val )
+ if( values[i] + epsilon < values[i+1] )
{
- double val = lsum*lsum/L + rsum*rsum/R;
+ double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
if( best_val < val )
{
best_val = val;
}
}
- return best_i >= 0 ? data->new_split_ord( vi,
- (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
- 0, (float)best_val ) : 0;
+ CvDTreeSplit* split = 0;
+ if( best_i >= 0 )
+ {
+ split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
+ split->var_idx = vi;
+ split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
+ split->ord.split_point = best_i;
+ split->inversed = 0;
+ split->quality = (float)best_val;
+ }
+ return split;
}
-
-CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
- CvDTreeSplit* split;
- const int* labels = data->get_cat_var_data(node, vi);
- const float* responses = data->get_ord_responses(node);
int ci = data->get_var_type(vi);
int n = node->sample_count;
int mi = data->cat_count->data.i[ci];
- double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
- int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
- double** sum_ptr = 0;
+
+ int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
+ cv::AutoBuffer<uchar> inn_buf(base_size);
+ if( !_ext_buf )
+ inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
+ uchar* base_buf = (uchar*)inn_buf;
+ uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
+ int* labels_buf = (int*)ext_buf;
+ const int* labels = data->get_cat_var_data(node, vi, labels_buf);
+ float* responses_buf = (float*)(labels_buf + n);
+ int* sample_indices_buf = (int*)(responses_buf + n);
+ const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
+
+ double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
+ int* counts = (int*)(sum + mi) + 1;
+ double** sum_ptr = (double**)(counts + mi);
int i, L = 0, R = 0;
- double best_val = 0, lsum = 0, rsum = 0;
+ double best_val = init_quality, lsum = 0, rsum = 0;
int best_subset = -1, subset_i;
for( i = -1; i < mi; i++ )
// calculate sum response and weight of each category of the input var
for( i = 0; i < n; i++ )
{
- int idx = labels[i];
+ int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
double s = sum[idx] + responses[i];
int nc = counts[idx] + 1;
sum[idx] = s;
icvSortDblPtr( sum_ptr, mi, 0 );
- // revert back to unnormalized sum
+ // revert back to unnormalized sums
// (there should be a very little loss of accuracy)
for( i = 0; i < mi; i++ )
sum[i] *= counts[i];
double s = sum[idx];
lsum += s; L += ni;
rsum -= s; R -= ni;
-
+
if( L && R )
{
- double val = lsum*lsum/L + rsum*rsum/R;
+ double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
if( best_val < val )
{
best_val = val;
}
}
- if( best_subset < 0 )
- return 0;
-
- split = data->new_split_cat( vi, (float)best_val );
- for( i = 0; i <= best_subset; i++ )
+ CvDTreeSplit* split = 0;
+ if( best_subset >= 0 )
{
- int idx = (int)(sum_ptr[i] - sum);
- split->subset[idx >> 5] |= 1 << (idx & 31);
+ split = _split ? _split : data->new_split_cat( 0, -1.0f);
+ split->var_idx = vi;
+ split->quality = (float)best_val;
+ memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
+ for( i = 0; i <= best_subset; i++ )
+ {
+ int idx = (int)(sum_ptr[i] - sum);
+ split->subset[idx >> 5] |= 1 << (idx & 31);
+ }
}
-
return split;
}
-
-CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
const float epsilon = FLT_EPSILON*2;
- const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
const char* dir = (char*)data->direction->data.ptr;
- int n1 = node->get_num_valid(vi);
+ int n = node->sample_count, n1 = node->get_num_valid(vi);
+ cv::AutoBuffer<uchar> inn_buf;
+ if( !_ext_buf )
+ inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
+ uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
+ float* values_buf = (float*)ext_buf;
+ int* sorted_indices_buf = (int*)(values_buf + n);
+ int* sample_indices_buf = sorted_indices_buf + n;
+ const float* values = 0;
+ const int* sorted_indices = 0;
+ data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
// LL - number of samples that both the primary and the surrogate splits send to the left
// LR - ... primary split sends to the left and the surrogate split sends to the right
// RL - ... primary split sends to the right and the surrogate split sends to the left
// RR - ... both send to the right
int i, best_i = -1, best_inversed = 0;
- double best_val;
+ double best_val;
if( !data->have_priors )
{
int LL = 0, RL = 0, LR, RR;
int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
int sum = 0, sum_abs = 0;
-
+
for( i = 0; i < n1; i++ )
{
- int d = dir[sorted[i].i];
+ int d = dir[sorted_indices[i]];
sum += d; sum_abs += d & 1;
}
// now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
for( i = 0; i < n1 - 1; i++ )
{
- int d = dir[sorted[i].i];
+ int d = dir[sorted_indices[i]];
if( d < 0 )
{
LL++; LR--;
- if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
+ if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
{
best_val = LL + RR;
best_i = i; best_inversed = 0;
else if( d > 0 )
{
RL++; RR--;
- if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
+ if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
{
best_val = RL + LR;
best_i = i; best_inversed = 1;
double LL = 0, RL = 0, LR, RR;
double worst_val = node->maxlr;
double sum = 0, sum_abs = 0;
- const double* priors = data->priors->data.db;
- const int* responses = data->get_class_labels(node);
+ const double* priors = data->priors_mult->data.db;
+ int* responses_buf = sample_indices_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
best_val = worst_val;
-
+
for( i = 0; i < n1; i++ )
{
- int idx = sorted[i].i;
+ int idx = sorted_indices[i];
double w = priors[responses[idx]];
int d = dir[idx];
sum += d*w; sum_abs += (d & 1)*w;
// now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
for( i = 0; i < n1 - 1; i++ )
{
- int idx = sorted[i].i;
+ int idx = sorted_indices[i];
double w = priors[responses[idx]];
int d = dir[idx];
if( d < 0 )
{
LL += w; LR -= w;
- if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
+ if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
{
best_val = LL + RR;
best_i = i; best_inversed = 0;
else if( d > 0 )
{
RL += w; RR -= w;
- if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
+ if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
{
best_val = RL + LR;
best_i = i; best_inversed = 1;
}
}
}
-
- return best_i >= 0 ? data->new_split_ord( vi,
- (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
- best_inversed, (float)best_val ) : 0;
+ return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
+ (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
}
-CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
+CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
- const int* labels = data->get_cat_var_data(node, vi);
const char* dir = (char*)data->direction->data.ptr;
int n = node->sample_count;
+ int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
+
+ int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
+ cv::AutoBuffer<uchar> inn_buf(base_size);
+ if( !_ext_buf )
+ inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
+ uchar* base_buf = (uchar*)inn_buf;
+ uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
+
+ int* labels_buf = (int*)ext_buf;
+ const int* labels = data->get_cat_var_data(node, vi, labels_buf);
// LL - number of samples that both the primary and the surrogate splits send to the left
// LR - ... primary split sends to the left and the surrogate split sends to the right
// RL - ... primary split sends to the right and the surrogate split sends to the left
// RR - ... both send to the right
CvDTreeSplit* split = data->new_split_cat( vi, 0 );
- int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
double best_val = 0;
- double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
+ double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
double* rc = lc + mi + 1;
-
+
for( i = -1; i < mi; i++ )
lc[i] = rc[i] = 0;
// sent to the left (lc) and to the right (rc) by the primary split
if( !data->have_priors )
{
- int* _lc = data->counts->data.i + 1;
+ int* _lc = (int*)rc + 1;
int* _rc = _lc + mi + 1;
for( i = -1; i < mi; i++ )
for( i = 0; i < n; i++ )
{
- int idx = labels[i];
+ int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
int d = dir[i];
int sum = _lc[idx] + d;
int sum_abs = _rc[idx] + (d & 1);
}
else
{
- const double* priors = data->priors->data.db;
- const int* responses = data->get_class_labels(node);
+ const double* priors = data->priors_mult->data.db;
+ int* responses_buf = labels_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
for( i = 0; i < n; i++ )
{
- int idx = labels[i];
+ int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
double w = priors[responses[i]];
int d = dir[i];
double sum = lc[idx] + d*w;
{
split->subset[i >> 5] |= 1 << (i & 31);
best_val += lval;
+ l_win++;
}
else
best_val += rval;
}
split->quality = (float)best_val;
- if( split->quality <= node->maxlr )
+ if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
cvSetRemoveByPtr( data->split_heap, split ), split = 0;
return split;
void CvDTree::calc_node_value( CvDTreeNode* node )
{
int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
- const int* cv_labels = data->get_cv_labels(node);
+ int m = data->get_num_classes();
+
+ int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
+ int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
+ cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
+ uchar* base_buf = (uchar*)inn_buf;
+ uchar* ext_buf = base_buf + base_size;
+
+ int* cv_labels_buf = (int*)ext_buf;
+ const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
if( data->is_classifier )
{
// compute the number of instances of each class
int* cls_count = data->counts->data.i;
- const int* responses = data->get_class_labels(node);
- int m = data->get_num_classes();
- int* cv_cls_count = cls_count + m;
+ int* responses_buf = cv_labels_buf + n;
+ const int* responses = data->get_class_labels(node, responses_buf);
+ int* cv_cls_count = (int*)base_buf;
double max_val = -1, total_weight = 0;
int max_k = -1;
- double* priors = data->priors->data.db;
+ double* priors = data->priors_mult->data.db;
for( k = 0; k < m; k++ )
cls_count[k] = 0;
cls_count[k] += cv_cls_count[j*m + k];
}
+ if( data->have_priors && node->parent == 0 )
+ {
+ // compute priors_mult from priors, take the sample ratio into account.
+ double sum = 0;
+ for( k = 0; k < m; k++ )
+ {
+ int n_k = cls_count[k];
+ priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
+ sum += priors[k];
+ }
+ sum = 1./sum;
+ for( k = 0; k < m; k++ )
+ priors[k] *= sum;
+ }
+
for( k = 0; k < m; k++ )
{
double val = cls_count[k]*priors[k];
// over the samples with cv_labels(*)==j.
double sum = 0, sum2 = 0;
- const float* values = data->get_ord_responses(node);
+ float* values_buf = (float*)(cv_labels_buf + n);
+ int* sample_indices_buf = (int*)(values_buf + n);
+ const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
double *cv_sum = 0, *cv_sum2 = 0;
int* cv_count = 0;
-
+
if( cv_n == 0 )
{
- // if cross-validation is not used, we even do not compute node_risk
- // (so the tree sequence T1>...>{root} may not be built).
for( i = 0; i < n; i++ )
- sum += values[i];
+ {
+ double t = values[i];
+ sum += t;
+ sum2 += t*t;
+ }
}
else
{
- cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
- cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
- cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
+ cv_sum = (double*)base_buf;
+ cv_sum2 = cv_sum + cv_n;
+ cv_count = (int*)(cv_sum2 + cv_n);
for( j = 0; j < cv_n; j++ )
{
sum += cv_sum[j];
sum2 += cv_sum2[j];
}
-
- node->node_risk = sum2 - (sum/n)*sum;
}
+ node->node_risk = sum2 - (sum/n)*sum;
node->value = sum/n;
for( j = 0; j < cv_n; j++ )
// try to complete direction using surrogate splits
if( nz && data->params.use_surrogates )
{
+ cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
CvDTreeSplit* split = node->split->next;
for( ; split != 0 && nz; split = split->next )
{
if( data->get_var_type(vi) >= 0 ) // split on categorical var
{
- const int* labels = data->get_cat_var_data(node, vi);
+ int* labels_buf = (int*)(uchar*)inn_buf;
+ const int* labels = data->get_cat_var_data(node, vi, labels_buf);
const int* subset = split->subset;
for( i = 0; i < n; i++ )
{
- int idx;
- if( !dir[i] && (idx = labels[i]) >= 0 )
+ int idx = labels[i];
+ if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
+
{
- int d = DTREE_CAT_DIR(idx,subset);
+ int d = CV_DTREE_CAT_DIR(idx,subset);
dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
if( --nz )
break;
}
else // split on ordered var
{
- const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
+ float* values_buf = (float*)(uchar*)inn_buf;
+ int* sorted_indices_buf = (int*)(values_buf + n);
+ int* sample_indices_buf = sorted_indices_buf + n;
+ const float* values = 0;
+ const int* sorted_indices = 0;
+ data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
int split_point = split->ord.split_point;
int n1 = node->get_num_valid(vi);
for( i = 0; i < n1; i++ )
{
- int idx = sorted[i].i;
+ int idx = sorted_indices[i];
if( !dir[idx] )
{
int d = i <= split_point ? -1 : 1;
void CvDTree::split_node_data( CvDTreeNode* node )
{
- int vi, i, n = node->sample_count, nl, nr;
+ int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
char* dir = (char*)data->direction->data.ptr;
CvDTreeNode *left = 0, *right = 0;
int* new_idx = data->split_buf->data.i;
int new_buf_idx = data->get_child_buf_idx( node );
+ int work_var_count = data->get_work_var_count();
+ CvMat* buf = data->buf;
+ cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
+ int* temp_buf = (int*)(uchar*)inn_buf;
complete_node_dir(node);
nl += d^1;
}
+ bool split_input_data;
node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
- node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
- (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
+ node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
+
+ split_input_data = node->depth + 1 < data->params.max_depth &&
+ (node->left->sample_count > data->params.min_sample_count ||
+ node->right->sample_count > data->params.min_sample_count);
// split ordered variables, keep both halves sorted.
for( vi = 0; vi < data->var_count; vi++ )
{
int ci = data->get_var_type(vi);
- int n1 = node->get_num_valid(vi);
- CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
- CvPair32s32f tl, tr;
- if( ci >= 0 )
+ if( ci >= 0 || !split_input_data )
continue;
- src = data->get_ord_var_data(node, vi);
- ldst0 = ldst = data->get_ord_var_data(left, vi);
- rdst0 = rdst = data->get_ord_var_data(right, vi);
- tl = ldst0[nl]; tr = rdst0[nr];
+ int n1 = node->get_num_valid(vi);
+ float* src_val_buf = (float*)(uchar*)(temp_buf + n);
+ int* src_sorted_idx_buf = (int*)(src_val_buf + n);
+ int* src_sample_idx_buf = src_sorted_idx_buf + n;
+ const float* src_val = 0;
+ const int* src_sorted_idx = 0;
+ data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
+
+ for(i = 0; i < n; i++)
+ temp_buf[i] = src_sorted_idx[i];
+
+ if (data->is_buf_16u)
+ {
+ unsigned short *ldst, *rdst, *ldst0, *rdst0;
+ //unsigned short tl, tr;
+ ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
+ vi*scount + left->offset);
+ rdst0 = rdst = (unsigned short*)(ldst + nl);
+
+ // split sorted
+ for( i = 0; i < n1; i++ )
+ {
+ int idx = temp_buf[i];
+ int d = dir[idx];
+ idx = new_idx[idx];
+ if (d)
+ {
+ *rdst = (unsigned short)idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = (unsigned short)idx;
+ ldst++;
+ }
+ }
- // split sorted
- for( i = 0; i < n1; i++ )
- {
- int idx = src[i].i;
- float val = src[i].val;
- int d = dir[idx];
- idx = new_idx[idx];
- ldst->i = rdst->i = idx;
- ldst->val = rdst->val = val;
- ldst += d^1;
- rdst += d;
+ left->set_num_valid(vi, (int)(ldst - ldst0));
+ right->set_num_valid(vi, (int)(rdst - rdst0));
+
+ // split missing
+ for( ; i < n; i++ )
+ {
+ int idx = temp_buf[i];
+ int d = dir[idx];
+ idx = new_idx[idx];
+ if (d)
+ {
+ *rdst = (unsigned short)idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = (unsigned short)idx;
+ ldst++;
+ }
+ }
}
+ else
+ {
+ int *ldst0, *ldst, *rdst0, *rdst;
+ ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols +
+ vi*scount + left->offset;
+ rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols +
+ vi*scount + right->offset;
- left->set_num_valid(vi, (int)(ldst - ldst0));
- right->set_num_valid(vi, (int)(rdst - rdst0));
+ // split sorted
+ for( i = 0; i < n1; i++ )
+ {
+ int idx = temp_buf[i];
+ int d = dir[idx];
+ idx = new_idx[idx];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
- // split missing
- for( ; i < n; i++ )
- {
- int idx = src[i].i;
- int d = dir[idx];
- idx = new_idx[idx];
- ldst->i = rdst->i = idx;
- ldst->val = rdst->val = ord_nan;
- ldst += d^1;
- rdst += d;
- }
+ left->set_num_valid(vi, (int)(ldst - ldst0));
+ right->set_num_valid(vi, (int)(rdst - rdst0));
- ldst0[nl] = tl; rdst0[nr] = tr;
+ // split missing
+ for( ; i < n; i++ )
+ {
+ int idx = temp_buf[i];
+ int d = dir[idx];
+ idx = new_idx[idx];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
+ }
}
// split categorical vars, responses and cv_labels using new_idx relocation table
- for( vi = 0; vi <= data->var_count + data->have_cv_labels + data->have_weights; vi++ )
+ for( vi = 0; vi < work_var_count; vi++ )
{
int ci = data->get_var_type(vi);
int n1 = node->get_num_valid(vi), nr1 = 0;
- int *src, *ldst0, *rdst0, *ldst, *rdst;
- int tl, tr;
-
- if( ci < 0 )
+
+ if( ci < 0 || (vi < data->var_count && !split_input_data) )
continue;
- src = data->get_cat_var_data(node, vi);
- ldst0 = ldst = data->get_cat_var_data(left, vi);
- rdst0 = rdst = data->get_cat_var_data(right, vi);
- tl = ldst0[nl]; tr = rdst0[nr];
+ int *src_lbls_buf = temp_buf + n;
+ const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
- for( i = 0; i < n; i++ )
- {
- int d = dir[i];
- int val = src[i];
- *ldst = *rdst = val;
- ldst += d^1;
- rdst += d;
- nr1 += (val >= 0)&d;
- }
+ for(i = 0; i < n; i++)
+ temp_buf[i] = src_lbls[i];
- if( vi < data->var_count )
+ if (data->is_buf_16u)
{
- left->set_num_valid(vi, n1 - nr1);
- right->set_num_valid(vi, nr1);
+ unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols +
+ vi*scount + left->offset);
+ unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols +
+ vi*scount + right->offset);
+
+ for( i = 0; i < n; i++ )
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = (unsigned short)idx;
+ rdst++;
+ nr1 += (idx != 65535 )&d;
+ }
+ else
+ {
+ *ldst = (unsigned short)idx;
+ ldst++;
+ }
+ }
+
+ if( vi < data->var_count )
+ {
+ left->set_num_valid(vi, n1 - nr1);
+ right->set_num_valid(vi, nr1);
+ }
}
+ else
+ {
+ int *ldst = buf->data.i + left->buf_idx*buf->cols +
+ vi*scount + left->offset;
+ int *rdst = buf->data.i + right->buf_idx*buf->cols +
+ vi*scount + right->offset;
+
+ for( i = 0; i < n; i++ )
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ nr1 += (idx >= 0)&d;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+
+ }
- ldst0[nl] = tl; rdst0[nr] = tr;
+ if( vi < data->var_count )
+ {
+ left->set_num_valid(vi, n1 - nr1);
+ right->set_num_valid(vi, nr1);
+ }
+ }
}
+
+ // split sample indices
+ int *sample_idx_src_buf = temp_buf + n;
+ const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
+
+ for(i = 0; i < n; i++)
+ temp_buf[i] = sample_idx_src[i];
+
+ int pos = data->get_work_var_count();
+ if (data->is_buf_16u)
+ {
+ unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols +
+ pos*scount + left->offset);
+ unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols +
+ pos*scount + right->offset);
+ for (i = 0; i < n; i++)
+ {
+ int d = dir[i];
+ unsigned short idx = (unsigned short)temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
+ }
+ else
+ {
+ int* ldst = buf->data.i + left->buf_idx*buf->cols +
+ pos*scount + left->offset;
+ int* rdst = buf->data.i + right->buf_idx*buf->cols +
+ pos*scount + right->offset;
+ for (i = 0; i < n; i++)
+ {
+ int d = dir[i];
+ int idx = temp_buf[i];
+ if (d)
+ {
+ *rdst = idx;
+ rdst++;
+ }
+ else
+ {
+ *ldst = idx;
+ ldst++;
+ }
+ }
+ }
+
// deallocate the parent node data that is not needed anymore
- data->free_node_data(node);
+ data->free_node_data(node);
}
+// Evaluates the tree on the samples stored in _data.
+// type == CV_TEST_ERROR evaluates the test subset; any other value uses the
+// train subset (for CV_TRAIN_ERROR with no train index, all rows of the data).
+// Classification trees return the percentage of misclassified samples,
+// regression trees return the mean squared error. If resp is non-null it is
+// resized to the number of evaluated samples and receives each prediction.
+// Returns -FLT_MAX when there is nothing to evaluate.
+float CvDTree::calc_error( CvMLData* _data, int type, vector<float> *resp )
+{
+    const CvMat* values = _data->get_values();
+    const CvMat* response = _data->get_responses();
+    const CvMat* missing = _data->get_missing();
+    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
+    const CvMat* var_types = _data->get_var_types();
+    int* sidx = sample_idx ? sample_idx->data.i : 0;
+    int r_step = CV_IS_MAT_CONT(response->type) ?
+                1 : response->step / CV_ELEM_SIZE(response->type);
+    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
+    // the sample index is a 1d vector; accept both row and column layouts
+    int sample_count = sample_idx ? sample_idx->rows + sample_idx->cols - 1 : 0;
+    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
+    float* pred_resp = 0;
+    if( resp && (sample_count > 0) )
+    {
+        resp->resize( sample_count );
+        pred_resp = &((*resp)[0]);
+    }
+
+    if( sample_count <= 0 )
+        return -FLT_MAX;
+
+    // accumulate in double: summing many squared residuals in float loses precision
+    double err = 0;
+    for( int i = 0; i < sample_count; i++ )
+    {
+        CvMat sample, miss;
+        int si = sidx ? sidx[i] : i;
+        cvGetRow( values, &sample, si );
+        if( missing )
+            cvGetRow( missing, &miss, si );
+        float r = (float)predict( &sample, missing ? &miss : 0 )->value;
+        if( pred_resp )
+            pred_resp[i] = r;
+        double d = (double)r - response->data.fl[si*r_step];
+        if( is_classifier )
+            err += (fabs(d) <= FLT_EPSILON ? 0 : 1); // count mismatched labels
+        else
+            err += d*d;                              // squared residual
+    }
+
+    return is_classifier ? (float)(err / sample_count * 100) : (float)(err / sample_count);
+}
void CvDTree::prune_cv()
{
CvMat* ab = 0;
CvMat* temp = 0;
CvMat* err_jk = 0;
-
+
// 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
// 2. choose the best tree index (if need, apply 1SE rule).
// 3. store the best index and cut the branches.
double* err;
double min_err = 0, min_err_se = 0;
int min_idx = -1;
-
+
CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
// build the main tree sequence, calculate alpha's
}
ab->data.db[0] = 0.;
- for( ti = 1; ti < tree_count-1; ti++ )
- ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
- ab->data.db[tree_count-1] = DBL_MAX*0.5;
-
- CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
- err = err_jk->data.db;
- for( j = 0; j < cv_n; j++ )
+ if( tree_count > 0 )
{
- int tj = 0, tk = 0;
- for( ; tk < tree_count; tj++ )
- {
- double min_alpha = update_tree_rnc(tj, j);
- if( cut_tree(tj, j, min_alpha) )
- min_alpha = DBL_MAX;
+ for( ti = 1; ti < tree_count-1; ti++ )
+ ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
+ ab->data.db[tree_count-1] = DBL_MAX*0.5;
- for( ; tk < tree_count; tk++ )
+ CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
+ err = err_jk->data.db;
+
+ for( j = 0; j < cv_n; j++ )
+ {
+ int tj = 0, tk = 0;
+ for( ; tk < tree_count; tj++ )
{
- if( ab->data.db[tk] > min_alpha )
- break;
- err[j*tree_count + tk] = root->tree_error;
+ double min_alpha = update_tree_rnc(tj, j);
+ if( cut_tree(tj, j, min_alpha) )
+ min_alpha = DBL_MAX;
+
+ for( ; tk < tree_count; tk++ )
+ {
+ if( ab->data.db[tk] > min_alpha )
+ break;
+ err[j*tree_count + tk] = root->tree_error;
+ }
}
}
- }
- for( ti = 0; ti < tree_count; ti++ )
- {
- double sum_err = 0;
- for( j = 0; j < cv_n; j++ )
- sum_err += err[j*tree_count + ti];
- if( ti == 0 || sum_err < min_err )
+ for( ti = 0; ti < tree_count; ti++ )
{
- min_err = sum_err;
- min_idx = ti;
- if( use_1se )
- min_err_se = sqrt( sum_err*(n - sum_err) );
+ double sum_err = 0;
+ for( j = 0; j < cv_n; j++ )
+ sum_err += err[j*tree_count + ti];
+ if( ti == 0 || sum_err < min_err )
+ {
+ min_err = sum_err;
+ min_idx = ti;
+ if( use_1se )
+ min_err_se = sqrt( sum_err*(n - sum_err) );
+ }
+ else if( sum_err < min_err + min_err_se )
+ min_idx = ti;
}
- else if( sum_err < min_err + min_err_se )
- min_idx = ti;
}
pruned_tree_idx = min_idx;
{
CvDTreeNode* node = root;
double min_alpha = DBL_MAX;
-
+
for(;;)
{
CvDTreeNode* parent;
}
node = node->left;
}
-
+
for( parent = node->parent; parent && parent->right == node;
node = parent, parent = parent->parent )
{
}
node = node->left;
}
-
+
for( parent = node->parent; parent && parent->right == node;
node = parent, parent = parent->parent )
;
void CvDTree::free_prune_data(bool cut_tree)
{
CvDTreeNode* node = root;
-
+
for(;;)
{
CvDTreeNode* parent;
break;
node = node->left;
}
-
+
for( parent = node->parent; parent && parent->right == node;
node = parent, parent = parent->parent )
{
}
}
-
CvDTreeNode* CvDTree::predict( const CvMat* _sample,
const CvMat* _missing, bool preprocessed_input ) const
{
CV_ERROR( CV_StsError, "The tree has not been trained yet" );
if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
- _sample->cols != 1 && _sample->rows != 1 ||
- _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
- _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
+ (_sample->cols != 1 && _sample->rows != 1) ||
+ (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
+ (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
CV_ERROR( CV_StsBadArg,
"the input sample must be 1d floating-point vector with the same "
"number of elements as the total number of variables used for training" );
sample = _sample->data.fl;
- step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
+ step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
if( data->cat_count && !preprocessed_input ) // cache for categorical variables
{
vtype = data->var_type->data.i;
vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
- cmap = data->cat_map->data.i;
- cofs = data->cat_ofs->data.i;
+ cmap = data->cat_map ? data->cat_map->data.i : 0;
+ cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
while( node->Tn > pruned_tree_idx && node->left )
{
if( c < 0 )
{
int a = c = cofs[ci];
- int b = cofs[ci+1];
+ int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
+
int ival = cvRound(val);
if( ival != val )
CV_ERROR( CV_StsBadArg,
"one of input categorical variable is not an integer" );
-
+
+ int sh = 0;
while( a < b )
{
+ sh++;
c = (a + b) >> 1;
if( ival < cmap[c] )
b = c;
catbuf[ci] = c -= cofs[ci];
}
}
- dir = DTREE_CAT_DIR(c, split->subset);
+ c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
+ dir = CV_DTREE_CAT_DIR(c, split->subset);
}
if( split->inversed )
}
+// C++ (cv::Mat) front-end for the CvMat-based predict(): builds CvMat headers
+// for the sample and the optional missing-value mask and forwards the call.
+// An empty _missing (null data pointer) is passed on as "no mask".
+CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
+{
+    CvMat c_sample = _sample;
+    CvMat c_missing = _missing;
+    const CvMat* missing_mask = 0;
+    if( c_missing.data.ptr != 0 )
+        missing_mask = &c_missing;
+    return predict( &c_sample, missing_mask, preprocessed_input );
+}
+
+
const CvMat* CvDTree::get_var_importance()
{
if( !var_importance )
for( ;; node = node->left )
{
CvDTreeSplit* split = node->split;
-
+
if( !node->left || node->Tn <= pruned_tree_idx )
break;
-
+
for( ; split != 0; split = split->next )
importance[split->var_idx] += split->quality;
}
}
-void CvDTree::write_train_data_params( CvFileStorage* fs )
-{
- CV_FUNCNAME( "CvDTree::write_train_data_params" );
-
- __BEGIN__;
-
- int vi, vcount = data->var_count;
-
- cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
- cvWriteInt( fs, "var_all", data->var_all );
- cvWriteInt( fs, "var_count", data->var_count );
- cvWriteInt( fs, "ord_var_count", data->ord_var_count );
- cvWriteInt( fs, "cat_var_count", data->cat_var_count );
-
- cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
- cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
-
- if( data->is_classifier )
- {
- cvWriteInt( fs, "max_categories", data->params.max_categories );
- }
- else
- {
- cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
- }
-
- cvWriteInt( fs, "max_depth", data->params.max_depth );
- cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
- cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
-
- if( data->params.cv_folds > 1 )
- {
- cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
- cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
- }
-
- if( data->priors )
- cvWrite( fs, "priors", data->priors );
-
- cvEndWriteStruct( fs );
-
- if( data->var_idx )
- cvWrite( fs, "var_idx", data->var_idx );
-
- cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
-
- for( vi = 0; vi < vcount; vi++ )
- cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
-
- cvEndWriteStruct( fs );
-
- if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
- {
- CV_ASSERT( data->cat_count != 0 );
- cvWrite( fs, "cat_count", data->cat_count );
- cvWrite( fs, "cat_map", data->cat_map );
- }
-
- __END__;
-}
-
-
-void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
+void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
{
int ci;
-
+
cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
cvWriteInt( fs, "var", split->var_idx );
cvWriteReal( fs, "quality", split->quality );
{
int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
for( i = 0; i < n; i++ )
- to_right += DTREE_CAT_DIR(i,split->subset) > 0;
+ to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
// ad-hoc rule when to use inverse categorical split notation
// to achieve more compact and clear representation
default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
-
+
cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
"in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
+
for( i = 0; i < n; i++ )
{
- int dir = DTREE_CAT_DIR(i,split->subset);
+ int dir = CV_DTREE_CAT_DIR(i,split->subset);
if( dir*default_dir < 0 )
cvWriteInt( fs, 0, i );
}
}
-void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
+void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
{
CvDTreeSplit* split;
-
+
cvStartWriteStruct( fs, 0, CV_NODE_MAP );
cvWriteInt( fs, "depth", node->depth );
cvWriteInt( fs, "sample_count", node->sample_count );
cvWriteReal( fs, "value", node->value );
-
+
if( data->is_classifier )
cvWriteInt( fs, "norm_class_idx", node->class_idx );
}
-void CvDTree::write_tree_nodes( CvFileStorage* fs )
+void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
{
//CV_FUNCNAME( "CvDTree::write_tree_nodes" );
break;
node = node->left;
}
-
+
for( parent = node->parent; parent && parent->right == node;
node = parent, parent = parent->parent )
;
}
-void CvDTree::write( CvFileStorage* fs, const char* name )
+// Writes the tree as a named map tagged CV_TYPE_NAME_ML_TREE: first the
+// training parameters (data->write_params), then the tree body (write(fs)).
+void CvDTree::write( CvFileStorage* fs, const char* name ) const
{
//CV_FUNCNAME( "CvDTree::write" );
cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
- write_train_data_params( fs );
-
- cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
- get_var_importance();
- cvWrite( fs, "var_importance", var_importance );
-
- cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
- write_tree_nodes( fs );
- cvEndWriteStruct( fs );
+ // NOTE(review): var_importance is no longer serialized. get_var_importance()
+ // is non-const while this method is now const, so it cannot be called here;
+ // read(fs, fnode) recomputes it after loading. Confirm this is intended.
+ //get_var_importance();
+ data->write_params( fs );
+ //if( var_importance )
+ //cvWrite( fs, "var_importance", var_importance );
+ write( fs );
cvEndWriteStruct( fs );
}
-void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
+void CvDTree::write( CvFileStorage* fs ) const
{
- CV_FUNCNAME( "CvDTree::read_train_data_params" );
+ //CV_FUNCNAME( "CvDTree::write" );
__BEGIN__;
-
- CvDTreeParams params;
- CvFileNode *tparams_node, *vartype_node;
- CvSeqReader reader;
- int is_classifier, vi, cat_var_count, ord_var_count;
- int max_split_size, tree_block_size;
-
- data = new CvDTreeTrainData;
-
- is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
- data->is_classifier = is_classifier;
- data->var_all = cvReadIntByName( fs, node, "var_all" );
- data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
- data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
- data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
-
- tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
-
- if( tparams_node ) // training parameters are not necessary
- {
- data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
-
- if( is_classifier )
- {
- data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
- }
- else
- {
- data->params.regression_accuracy =
- (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
- }
-
- data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
- data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
- data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
-
- if( data->params.cv_folds > 1 )
- {
- data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
- data->params.truncate_pruned_tree =
- cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
- }
-
- data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
- if( data->priors && !CV_IS_MAT(data->priors) )
- CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
- }
-
- CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
- if( data->var_idx )
- {
- if( !CV_IS_MAT(data->var_idx) ||
- data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
- data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
- CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
- CV_ERROR( CV_StsParseError,
- "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
-
- for( vi = 0; vi < data->var_count; vi++ )
- if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
- CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
- }
-
- ////// read var type
- CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
-
- vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
- if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
- vartype_node->data.seq->total != data->var_count )
- CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
-
- cvStartReadSeq( vartype_node->data.seq, &reader );
-
- cat_var_count = 0;
- ord_var_count = -1;
- for( vi = 0; vi < data->var_count; vi++ )
- {
- CvFileNode* n = (CvFileNode*)reader.ptr;
- if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
- CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
- data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
- CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
- }
-
- ord_var_count = ~ord_var_count;
- if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
- CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
- //////
-
- if( data->cat_var_count > 0 || is_classifier )
- {
- int ccount, max_c_count = 0, total_c_count = 0;
- CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
- CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
-
- if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
- data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
- CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
- data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
- data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
- CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
- CV_ERROR( CV_StsParseError,
- "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
-
- ccount = cat_var_count + is_classifier;
-
- CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
- data->cat_ofs->data.i[0] = 0;
- for( vi = 0; vi < ccount; vi++ )
- {
- int val = data->cat_count->data.i[vi];
- if( val <= 0 )
- CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
- max_c_count = MAX( max_c_count, val );
- data->cat_ofs->data.i[vi+1] = total_c_count += val;
- }
-
- if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
- CV_ERROR( CV_StsBadSize,
- "cat_map vector length is not equal to the total number of categories in all categorical vars" );
-
- data->max_c_count = max_c_count;
- }
-
- max_split_size = cvAlign(sizeof(CvDTreeSplit) +
- (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
+ cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
- tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
- tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
- CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
- CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
- sizeof(CvDTreeNode), data->tree_storage ));
- CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
- max_split_size, data->tree_storage ));
+ cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
+ write_tree_nodes( fs );
+ cvEndWriteStruct( fs );
__END__;
}
CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
{
CvDTreeSplit* split = 0;
-
+
CV_FUNCNAME( "CvDTree::read_split" );
__BEGIN__;
int vi, ci;
-
+
if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
ci = data->get_var_type(vi);
if( ci >= 0 ) // split on categorical var
{
- int i, n = data->cat_count->data.i[ci], inversed = 0;
+ int i, n = data->cat_count->data.i[ci], inversed = 0, val;
CvSeqReader reader;
CvFileNode* inseq;
split = data->new_split_cat( vi, 0 );
inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
inversed = 1;
}
- if( !inseq || CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ )
+ if( !inseq ||
+ (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
CV_ERROR( CV_StsParseError,
"Either 'in' or 'not_in' tags should be inside a categorical split data" );
- cvStartReadSeq( inseq->data.seq, &reader );
-
- for( i = 0; i < reader.seq->total; i++ )
+ if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
{
- CvFileNode* inode = (CvFileNode*)reader.ptr;
- int val = inode->data.i;
- if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
+ val = inseq->data.i;
+ if( (unsigned)val >= (unsigned)n )
CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
split->subset[val >> 5] |= 1 << (val & 31);
- CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
+ }
+ else
+ {
+ cvStartReadSeq( inseq->data.seq, &reader );
+
+ for( i = 0; i < reader.seq->total; i++ )
+ {
+ CvFileNode* inode = (CvFileNode*)reader.ptr;
+ val = inode->data.i;
+ if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
+ CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
+
+ split->subset[val >> 5] |= 1 << (val & 31);
+ CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
+ }
}
// for categorical splits we do not use inversed splits,
split->ord.c = (float)cvReadReal( cmp_node );
}
-
+
split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
__END__;
-
+
return split;
}
CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
{
CvDTreeNode* node = 0;
-
+
CV_FUNCNAME( "CvDTree::read_node" );
__BEGIN__;
}
__END__;
-
+
return node;
}
for( i = 0; i < reader.seq->total; i++ )
{
CvDTreeNode* node;
-
+
CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
if( !parent->left )
parent->left = node;
+// Loads a standalone tree: reads the training parameters into a fresh
+// CvDTreeTrainData, then the tree nodes via read(fs, node, data), and finally
+// recomputes the variable importance.
void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
+{
+    CvDTreeTrainData* _data = new CvDTreeTrainData();
+    // Until _data is handed to read(fs, node, _data), which stores it in `data`,
+    // nothing else references it; release it if reading the parameters fails
+    // (presumably via CV_ERROR on malformed input) so it does not leak.
+    try
+    {
+        _data->read_params( fs, fnode );
+    }
+    catch( ... )
+    {
+        delete _data;
+        throw;
+    }
+
+    read( fs, fnode, _data );
+    get_var_importance();
+}
+
+
+// a special entry point for reading weak decision trees from the tree ensembles
+void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
{
CV_FUNCNAME( "CvDTree::read" );
CvFileNode* tree_nodes;
clear();
- read_train_data_params( fs, fnode );
+ data = _data;
- tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
+ tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
CV_ERROR( CV_StsParseError, "nodes tag is missing" );
- pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
-
+ pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
read_tree_nodes( fs, tree_nodes );
- get_var_importance(); // recompute variable importance
__END__;
}