const int dims = data->var_count;\r
float maximal_response = 0;\r
\r
+#define RF_OOB\r
+#ifdef RF_OOB\r
// oob_predictions_sum[i] = sum of predicted values for the i-th sample\r
// oob_num_of_predictions[i] = number of summands\r
// (number of predictions for the i-th sample)\r
// initialize these variables to avoid warning C4701
CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );\r
CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );\r
-\r
+#endif\r
nsamples = data->sample_count;\r
nclasses = data->get_num_classes();\r
\r
trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
\r
+#ifdef RF_OOB\r
if( data->is_classifier )\r
{\r
CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));\r
cvGetRow( oob_responses, &oob_predictions_sum, 0 );\r
cvGetRow( oob_responses, &oob_num_of_predictions, 1 );\r
}\r
+#endif\r
CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));\r
CV_CALL(sample_idx_for_tree = cvCreateMat( 1, nsamples, CV_32SC1 ));\r
+#ifdef RF_OOB\r
CV_CALL(oob_samples_perm_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
CV_CALL(samples_ptr = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
CV_CALL(missing_ptr = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));\r
cvMinMaxLoc( &responses, &minval, &maxval );\r
maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );\r
}\r
+#endif\r
\r
ntrees = 0;\r
while( ntrees < max_ntrees )\r
trees[ntrees] = new CvForestTree();\r
tree = trees[ntrees];\r
CV_CALL(tree->train( data, sample_idx_for_tree, this ));\r
-#define RF_OOB\r
+\r
#ifdef RF_OOB\r
CvMat sample, missing;\r
// form array of OOB samples indices and get these samples\r
\r
cvReleaseMat( &sample_idx_mask_for_tree );\r
cvReleaseMat( &sample_idx_for_tree );\r
+\r
+#ifdef RF_OOB\r
cvReleaseMat( &oob_sample_votes );\r
cvReleaseMat( &oob_responses );\r
\r
cvFree( &samples_ptr );\r
cvFree( &missing_ptr );\r
cvFree( &true_resp_ptr );\r
+#endif\r
\r
return result;\r
}\r
int* best_l_sample_idx = (int*)cvAlloc(n*sizeof(best_l_sample_idx[0]));\r
int best_ln = 0;\r
int* best_r_sample_idx = (int*)cvAlloc(n*sizeof(best_r_sample_idx[0]));\r
- int best_rn = 0;
+ int best_rn = 0;\r
int* tsidx = 0;\r
-
+\r
if( forest )\r
{\r
int var_count;\r
{\r
if( ci >= 0)\r
{\r
- CV_ERROR( CV_StsBadArg, "ERTrees do not support categorical variables" );\r
+ split = find_split_cat_class( node, vi); \r
}\r
else\r
{ \r
const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
double best_val = 0;\r
\r
- const float* pred = ((CvERTreeTrainData*)data)->pred->data.fl;\r
+ const float* ord_pred = ((CvERTreeTrainData*)data)->ord_pred->data.fl;\r
+ int ci = data->get_var_type(vi);\r
+ const int idx = ((CvERTreeTrainData*)data)->get_ord_var_idx(ci);\r
const int* sidx = ((CvERTreeNode*)node)->sample_idx;\r
const int* resp = ((CvERTreeTrainData*)data)->resp->data.i;\r
- int pstep = ((CvERTreeTrainData*)data)->pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->pred->type);\r
+ int pstep = ((CvERTreeTrainData*)data)->ord_pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->ord_pred->type);\r
int rstep = ((CvERTreeTrainData*)data)->resp->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->resp->type);\r
bool is_find_split = false;\r
\r
if( !priors )\r
{\r
float pmin, pmax;\r
- pmin = pred[sidx[0]*pstep + vi];\r
+ pmin = ord_pred[sidx[0]*pstep + idx];\r
pmax = pmin;\r
for (int si = 1; si < n; si++)\r
{\r
- float ptemp = pred[sidx[si]*pstep + vi];\r
+ float ptemp = ord_pred[sidx[si]*pstep + idx];\r
if ( ptemp < pmin)\r
pmin = ptemp;\r
if ( ptemp > pmax)\r
for( int si = 0; si < n; si++ )\r
{\r
int r = resp[sidx[si]*rstep];\r
- if ((pred[sidx[si]*pstep + vi]) < split_val)\r
+ if ((ord_pred[sidx[si]*pstep + idx]) < split_val)\r
{\r
lc[r]++;\r
l_sind[L] = sidx[si];\r
return is_find_split ? data->new_split_ord( vi, (float)split_val, -1, 0, (float)best_val ) : 0;\r
}\r
\r
+// Finds a randomized split on categorical variable vi for an Extremely
+// Randomized Tree node: the categories observed at this node are partitioned
+// into two groups at random, and the split is scored with a Gini-style sum
+// of squared per-class counts on each side.
+// Returns the new split (score in split->quality), or 0 when the variable
+// has fewer than two categories, fewer than two categories occur at this
+// node, or priors are in use.
+CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi )
+{
+ int ci = data->get_var_type(vi);
+ int n = node->sample_count;
+ // cm: number of response classes; vm: number of categories of variable vi
+ int cm = ((CvERTreeTrainData*)data)->get_num_classes(); 
+ int vm = ((CvERTreeTrainData*)data)->cat_count->data.i[ci];
+ double best_val = 0;
+ CvDTreeSplit *split = 0;
+
+ // a split only makes sense when the variable takes at least two values
+ if ( vm > 1 )
+ {
+ const int* cat_pred = ((CvERTreeTrainData*)data)->cat_pred->data.i;
+ const int* resp = ((CvERTreeTrainData*)data)->resp->data.i;
+ const int* sidx = ((CvERTreeNode*)node)->sample_idx;
+ int pstep = ((CvERTreeTrainData*)data)->cat_pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->cat_pred->type);
+ int rstep = ((CvERTreeTrainData*)data)->resp->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->resp->type);
+
+ // per-class sample counters for the left/right sides of the candidate split
+ int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));
+ int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));
+ 
+ const double* priors = data->have_priors ? data->priors_mult->data.db : 0; 
+
+ // NOTE(review): when priors are in use no split is produced at all —
+ // confirm this fallthrough (return 0) is intended
+ if( !priors )
+ {
+ // valid_cidx maps a raw category value to a compact index among the
+ // categories actually present at this node; absent categories stay -1
+ int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));
+ for (int i = 0; i < vm; i++)
+ {
+ valid_cidx[i] = -1;
+ }
+ for (int si = 0; si < n; si++)
+ {
+ int c = cat_pred[sidx[si]*pstep + ci];
+ valid_cidx[c]++;
+ }
+
+ int valid_ccount = 0;
+ for (int i = 0; i < vm; i++)
+ if (valid_cidx[i] >= 0)
+ {
+ valid_cidx[i] = valid_ccount;
+ valid_ccount++;
+ }
+ if (valid_ccount > 1)
+ {
+ // choose how many present categories go left: 1..valid_ccount-1
+ CvRNG* rng = forest->get_rng();
+ int lcv_count = 1 + cvRandInt(rng) % (valid_ccount-1);
+
+ // build a 0/1 mask over the present categories: the first lcv_count
+ // entries are set to 1, then the mask is shuffled by random swaps
+ CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );
+ CvMat submask;
+ memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));
+ cvGetCols( var_class_mask, &submask, 0, lcv_count );
+ cvSet( &submask, cvScalar(1) );
+ for (int i = 0; i < valid_ccount; i++)
+ {
+ uchar temp;
+ int i1 = cvRandInt( rng ) % valid_ccount;
+ int i2 = cvRandInt( rng ) % valid_ccount;
+ CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );
+ }
+
+ // init arrays of class instance counters on both sides of the split
+ for(int i = 0; i < cm; i++ )
+ {
+ lc[i] = 0;
+ rc[i] = 0;
+ }
+
+ // lazily allocate per-node left/right sample-index buffers
+ // (size n covers the worst case of all samples on one side)
+ if (!((CvERTreeNode*)node)->l_sample_idx)
+ {
+ ((CvERTreeNode*)node)->l_sample_idx = (int*)cvAlloc( sizeof(((CvERTreeNode*)node)->l_sample_idx[0])*n );
+ ((CvERTreeNode*)node)->ln = 0;
+ }
+ if (!((CvERTreeNode*)node)->r_sample_idx)
+ {
+ ((CvERTreeNode*)node)->r_sample_idx = (int*)cvAlloc( sizeof(((CvERTreeNode*)node)->r_sample_idx[0])*n );
+ ((CvERTreeNode*)node)->rn = 0;
+ }
+
+ int* l_sind = ((CvERTreeNode*)node)->l_sample_idx;
+ int* r_sind = ((CvERTreeNode*)node)->r_sample_idx;
+
+ split = data->new_split_cat( vi, -1. );
+
+ // calculate Gini index
+ double lbest_val = 0, rbest_val = 0;
+ int L = 0, R = 0;
+ 
+ // route every sample left or right per the random category mask,
+ // recording per-class counts and the subset bits of the split
+ for( int si = 0; si < n; si++ )
+ {
+ int r = resp[sidx[si]*rstep];
+ int var_class_idx = cat_pred[sidx[si]*pstep + ci];
+ int mask_class_idx = valid_cidx[var_class_idx];
+ if (var_class_mask->data.ptr[mask_class_idx])
+ {
+ lc[r]++;
+ l_sind[L] = sidx[si];
+ L++; 
+ split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);
+ }
+ else
+ {
+ rc[r]++;
+ r_sind[R] = sidx[si];
+ R++;
+ }
+ }
+ // Gini-style score: sum_i lc[i]^2 / L + sum_i rc[i]^2 / R
+ // (higher is better; every present category has >= 1 sample, so
+ // L >= 1 and R >= 1 here)
+ for (int i = 0; i < cm; i++)
+ {
+ lbest_val += lc[i]*lc[i];
+ rbest_val += rc[i]*rc[i];
+ }
+ lbest_val = lbest_val/L;
+ rbest_val = rbest_val/R;
+ best_val = lbest_val + rbest_val;
+
+ split->quality = (float)best_val;
+
+ ((CvERTreeNode*)node)->ln = L;
+ ((CvERTreeNode*)node)->rn = R;
+
+ cvReleaseMat(&var_class_mask);
+ } 
+ }
+ } 
+ 
+ return split;
+}
+\r
void CvForestERTree::calc_node_value( CvDTreeNode* node )\r
{\r
CV_FUNCNAME("CvForestERTree::calc_node_value");\r
}\r
\r
node->class_idx = max_k;\r
- const int* ldata = ((CvERTreeTrainData*)data)->class_lables->data.i;\r
+ const int* ldata = ((CvERTreeTrainData*)data)->class_lables[data->cat_var_count]->data.i;\r
node->value = ldata[max_k]; \r
\r
node->node_risk = total_weight - max_val;\r
data->free_node_data(node);\r
}\r
\r
+// Propagates one sample down the trained tree and returns the leaf node
+// reached. Mirrors CvDTree::predict, except that raw categorical values are
+// translated through class_lables[ci] into training-time category indices;
+// a category value never seen during training sends the sample to a random
+// side of the split.
+// _sample: 1d CV_32FC1 vector (var_all elements, or var_count when
+// preprocessed_input is true); _missing: optional 8-bit mask of same size.
+CvDTreeNode* CvForestERTree::predict( const CvMat* _sample,
+ const CvMat* _missing, bool preprocessed_input ) const
+{
+ CvDTreeNode* result = 0;
+
+ CV_FUNCNAME( "CvForestERTree::predict" );
+
+ __BEGIN__;
+
+ int i, step, mstep = 0;
+ const float* sample;
+ const uchar* m = 0;
+ CvDTreeNode* node = root;
+ const int* vtype;
+ const int* vidx;
+
+ if( !node )
+ CV_ERROR( CV_StsError, "The tree has not been trained yet" );
+
+ if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
+ (_sample->cols != 1 && _sample->rows != 1) ||
+ (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
+ (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
+ CV_ERROR( CV_StsBadArg,
+ "the input sample must be 1d floating-point vector with the same "
+ "number of elements as the total number of variables used for training" );
+
+ sample = _sample->data.fl;
+ step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
+
+ if( _missing )
+ {
+ if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
+ !CV_ARE_SIZES_EQ(_missing, _sample) )
+ CV_ERROR( CV_StsBadArg,
+ "the missing data mask must be 8-bit vector of the same size as input sample" );
+ m = _missing->data.ptr;
+ mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
+ }
+
+ vtype = data->var_type->data.i;
+ vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
+
+ // descend while the current node is an unpruned split (has a left child)
+ while( node->Tn > pruned_tree_idx && node->left )
+ {
+ CvDTreeSplit* split = node->split;
+ int dir = 0;
+ // walk the split chain until one split's variable is not missing
+ for( ; !dir && split != 0; split = split->next )
+ {
+ int vi = split->var_idx;
+ int ci = vtype[vi];
+ i = vidx ? vidx[vi] : vi;
+ float val = sample[i*step];
+ if( m && m[i*mstep] )
+ continue;
+ if( ci < 0 ) // ordered
+ dir = val <= split->ord.c ? -1 : 1;
+ else // categorical
+ {
+ int c;
+ int c_count = data->cat_count->data.i[ci];
+ if( preprocessed_input )
+ c = cvRound(val);
+ else
+ {
+ // map the raw category value to its training-time index by
+ // linear search through class_lables[ci]
+ int ival = cvRound(val);
+ int i = 0;
+ int* labls = ((CvERTreeTrainData*)data)->class_lables[ci]->data.i;
+ if( ival != val )
+ CV_ERROR( CV_StsBadArg,
+ "one of input categorical variable is not an integer" );
+ for (i = 0; i < c_count; i++)
+ if (ival == labls[i]) break;
+ c = i;
+ }
+ if (c == c_count)
+ {
+ // category unseen during training: pick a direction at random
+ CvRNG* rng = &data->rng;
+ dir = 2*(cvRandInt(rng)%2)-1;
+ }
+ else
+ dir = CV_DTREE_CAT_DIR(c, split->subset);
+ }
+
+ if( split->inversed )
+ dir = -dir;
+ }
+
+ if( !dir )
+ {
+ // every split variable was missing: follow the larger branch
+ double diff = node->right->sample_count - node->left->sample_count;
+ dir = diff < 0 ? -1 : 1;
+ }
+ node = dir < 0 ? node->left : node->right;
+ }
+
+ result = node;
+
+ __END__;
+
+ return result;
+}
+\r
+\r
bool CvERTrees::train( const CvMat* _train_data, int _tflag,\r
const CvMat* _responses, const CvMat* _var_idx,\r
const CvMat* _sample_idx, const CvMat* _var_type,\r
{\r
bool result = false;\r
\r
+ CV_FUNCNAME("CvERTrees::grow_forest");\r
+ __BEGIN__;\r
+\r
+ const int max_ntrees = term_crit.max_iter;\r
+ const double max_oob_err = term_crit.epsilon;\r
+\r
+ nsamples = data->sample_count;\r
+ nclasses = ((CvERTreeTrainData*)data)->get_num_classes();\r
+\r
+ trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
+\r
+ memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
+\r
+//#define ET_OOB\r
+#ifdef ET_OOB\r
CvMat* oob_sample_votes = 0;\r
CvMat* oob_responses = 0;\r
\r
uchar* missing_ptr = 0;\r
float* true_resp_ptr = 0;\r
\r
- CV_FUNCNAME("CvERTrees::grow_forest");\r
- __BEGIN__;\r
-\r
- const int max_ntrees = term_crit.max_iter;\r
- const double max_oob_err = term_crit.epsilon;\r
-\r
const int dims = data->var_count;\r
- float maximal_response = 0;\r
+ float maximal_response = 0; \r
\r
CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );\r
CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );\r
\r
- nsamples = data->sample_count;\r
- nclasses = ((CvERTreeTrainData*)data)->get_num_classes();\r
-\r
- trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
-\r
- memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
-\r
if( data->is_classifier )\r
{\r
CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));\r
cvMinMaxLoc( &responses, &minval, &maxval );\r
maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );\r
}\r
-\r
+#endif\r
ntrees = 0;\r
while( ntrees < max_ntrees )\r
{\r
- int oob_samples_count = 0;\r
- double ncorrect_responses = 0;\r
CvForestERTree* tree = 0;\r
\r
trees[ntrees] = new CvForestERTree();\r
tree = (CvForestERTree*)trees[ntrees];\r
CV_CALL(tree->train( data, 0, this ));\r
-#define ET_OOB\r
+\r
#ifdef ET_OOB\r
- int i;\r
+ int i, oob_samples_count = 0;\r
+ double ncorrect_responses = 0;\r
CvMat sample, missing;\r
sample = cvMat( 1, dims, CV_32FC1, samples_ptr );\r
missing = cvMat( 1, dims, CV_8UC1, missing_ptr );\r
// compute oob error\r
cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );\r
\r
- prdct_resp = ((CvERTreeTrainData*)data)->class_lables->data.i[max_loc.x];\r
+ prdct_resp = ((CvERTreeTrainData*)data)->class_lables[data->cat_var_count]->data.i[max_loc.x];\r
oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;\r
\r
ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;\r
\r
__END__;\r
\r
+#ifdef ET_OOB\r
cvReleaseMat( &oob_sample_votes );\r
cvReleaseMat( &oob_responses );\r
\r
cvFree( &samples_ptr );\r
cvFree( &missing_ptr );\r
cvFree( &true_resp_ptr );\r
-\r
+#endif\r
return result;\r
}\r
\r
ntrees = 0; \r
}\r
\r
-void CvERTrees::read( CvFileStorage* fs, CvFileNode* node )
-{
+// Serialization is not implemented for ERTrees; always raises CV_StsBadArg.
+void CvERTrees::read( CvFileStorage* fs, CvFileNode* node )
+{
    CV_FUNCNAME("CvERTrees::read");
-    __BEGIN__;
-    fs = 0; node = 0;
-    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
-    __END__;
-};
-
-void CvERTrees::write( CvFileStorage* fs, const char* name )
-{
+    __BEGIN__;
+    fs = 0; node = 0;
+    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
+    __END__;
+};
+
+// Serialization is not implemented for ERTrees; always raises CV_StsBadArg.
+void CvERTrees::write( CvFileStorage* fs, const char* name )
+{
    CV_FUNCNAME("CvERTrees::write");
-    __BEGIN__;
-    fs = 0; name = 0;
-    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
-    __END__;
+    __BEGIN__;
+    fs = 0; name = 0;
+    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
+    __END__;
};
// End of file.\r
\r
CvERTreeTrainData::CvERTreeTrainData()
{
-    pred = resp = class_lables = 0;
+    // no matrices owned yet: ord_pred/cat_pred/resp and the class_lables
+    // array are created later in set_data()
+    ord_pred = cat_pred = resp = 0; class_lables = 0;
}
\r
\r
__END__;\r
}\r
\r
- \r
+\r
void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
\r
int sample_all = 0, r_type = 0, cv_n;\r
int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
- int vi;\r
+ int vi, _pstep, _rstep, postep, pcstep, rstep;\r
const int *sidx = 0;\r
time_t _time;\r
+ int* _idata, *idata;\r
+ float* _fdata, *fdata;\r
+ int cat_count_size, cl_size;\r
\r
if (_var_idx || _sample_idx || _missing_mask)\r
CV_ERROR(CV_StsBadArg, "arguments _var_idx, _sample_idx, _missing_mask are not supported");\r
// compare new and old train data\r
if( !(data->var_count == var_count &&\r
cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
- cvNorm( data->resp, resp, CV_C ) < FLT_EPSILON &&\r
- cvNorm( data->pred, pred, CV_C ) < FLT_EPSILON) )\r
+ cvNorm( data->ord_pred, ord_pred, CV_C ) < FLT_EPSILON &&\r
+ cvNorm( data->cat_pred, cat_pred, CV_C ) < FLT_EPSILON &&\r
+ cvNorm( data->resp, resp, CV_C ) < FLT_EPSILON) )\r
CV_ERROR( CV_StsBadArg,\r
"The new training data must have the same types and the input and output variables "\r
"and the same categories for categorical variables" );\r
if( !CV_IS_MAT(_responses) ||\r
(CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
- _responses->rows != 1 && _responses->cols != 1 ||\r
+ (_responses->rows != 1 && _responses->cols != 1) ||\r
_responses->rows + _responses->cols - 1 != sample_all )\r
CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
"floating-point vector containing as many elements as "\r
cat_var_count++ : ord_var_count--;\r
}\r
ord_var_count = ~ord_var_count;\r
- if (cat_var_count)\r
- CV_ERROR( CV_StsBadArg, "ERTrees support categorical variables only");\r
\r
cv_n = params.cv_folds;\r
- if( cv_n ) \r
+ if( cv_n )\r
CV_ERROR( CV_StsBadArg, "pruning is not supported ERTrees, params.cv_folds must be equel 0" );\r
\r
var_type->data.i[var_count] = cat_var_count;\r
max_c_count = 1;\r
\r
shared = true;\r
- \r
- if( _tflag == CV_ROW_SAMPLE )\r
+\r
+ if (ord_var_count)\r
+ CV_CALL( ord_pred = cvCreateMat( sample_count, ord_var_count, CV_32FC1 ));\r
+ if (cat_var_count)\r
+ CV_CALL( cat_pred = cvCreateMat( sample_count, cat_var_count, CV_32SC1 ));\r
+ CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
+ cat_count_size = is_classifier ? cat_var_count+1 : cat_var_count;\r
+ CV_CALL( cat_count = cvCreateMat( 1, cat_count_size, CV_32SC1 ));\r
+\r
+ CvMat *train_data;\r
+ CV_CALL( train_data = cvCreateMat( sample_count, var_count, _train_data->type ));\r
+ if( _tflag == CV_COL_SAMPLE )\r
{\r
- CV_CALL( pred = cvCreateMat( _train_data->rows, _train_data->cols, _train_data->type ));\r
- CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
- cvCopy( _train_data, pred );\r
+ cvTranspose( _train_data, train_data );\r
cvConvertScale( _responses, resp );\r
+ cvTranspose( resp, resp );\r
}\r
else\r
{\r
- CV_CALL( pred = cvCreateMat( _train_data->cols, _train_data->rows, _train_data->type ));\r
- CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
- cvTranspose( _train_data, pred );\r
- cvConvertScale( _responses, resp );\r
- cvTranspose( resp, resp );\r
+ cvCopy(_train_data, train_data);\r
+ cvConvertScale(_responses, resp);\r
}\r
\r
+ _pstep = train_data->step / CV_ELEM_SIZE(train_data->type);\r
+ _rstep = _responses->step / CV_ELEM_SIZE(_responses->type);\r
+ postep = ord_pred ? ord_pred->step / CV_ELEM_SIZE(ord_pred->type) : 0;\r
+ pcstep = cat_pred ? cat_pred->step / CV_ELEM_SIZE(cat_pred->type) : 0;\r
+ rstep = resp->type ? resp->step / CV_ELEM_SIZE(resp->type) : 0;\r
+ \r
+ _idata = 0; _fdata = 0;\r
+ if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
+ _idata = train_data->data.i;\r
+ else\r
+ _fdata = train_data->data.fl;\r
+\r
+ idata = cat_pred ? cat_pred->data.i : 0;\r
+ fdata = ord_pred ? ord_pred->data.fl : 0;\r
+ for (int vi = 0; vi < var_count; vi++)\r
+ {\r
+ int ci = get_var_type(vi);\r
+ if (ci >= 0)\r
+ {\r
+ for (int si = 0; si < sample_count; si++)\r
+ idata[si*pcstep + ci] = _idata ? _idata[si*_pstep + vi] : cvRound(_fdata[si*_pstep + vi]);\r
+ }\r
+ else\r
+ {\r
+ int idx = get_ord_var_idx(ci); \r
+ for (int si = 0; si < sample_count; si++)\r
+ fdata[si*postep + idx] = _idata ? (float)_idata[si*_pstep + vi] : _fdata[si*_pstep + vi];\r
+ }\r
+ }\r
+ \r
+ cl_size = is_classifier ? (cat_var_count + 1) : cat_var_count;\r
+ if (cl_size)\r
+ class_lables = (CvMat**)cvAlloc( cl_size * sizeof(class_lables[0]));\r
+\r
+ for (int vi = 0; vi < var_count; vi++)\r
+ {\r
+ int ci = get_var_type(vi);\r
+ if (ci >= 0)\r
+ {\r
+ int c_count;\r
+ // calculate count of categories\r
+ for( int i = 0; i < sample_count; i++ )\r
+ {\r
+ int_ptr[i] = &idata[i*pcstep + ci];\r
+ }\r
+\r
+ icvSortIntPtr( int_ptr, sample_count, 0 );\r
+\r
+ c_count = 1;\r
+ for( int i = 1; i < sample_count; i++ )\r
+ c_count += *int_ptr[i] != *int_ptr[i-1];\r
+ cat_count->data.i[ci] = c_count;\r
+\r
+ class_lables[ci] = cvCreateMat( 1, c_count, _train_data->type);\r
+\r
+ int *lbs = class_lables[ci]->data.i;\r
+ int c_idx = 0;\r
+ lbs[c_idx] = *int_ptr[0];\r
+ for( int si = 1; si < sample_count; si++ )\r
+ if (*int_ptr[si] != *int_ptr[si-1])\r
+ {\r
+ c_idx++;\r
+ lbs[c_idx] = *int_ptr[si];\r
+ }\r
+\r
+ for( int si = 0; si < sample_count; si++ )\r
+ {\r
+ for(int j = 0; j < c_count; j++ )\r
+ if ( abs(lbs[j] - idata[si*pcstep + ci]) < FLT_EPSILON )\r
+ {\r
+ idata[si*pcstep + ci] = j;\r
+ break;\r
+ }\r
+ }\r
+ }\r
+ }\r
+ cvReleaseMat( &train_data );\r
if( is_classifier ) \r
{\r
+ int c_count;\r
// calculate count of categories\r
- int *idata = resp->data.i;\r
- int step = CV_IS_MAT_CONT(resp->type) ?\r
- 1 : resp->step / CV_ELEM_SIZE(resp->type);\r
+ idata = resp->data.i;\r
for( int i = 0; i < sample_count; i++ )\r
{\r
int si = sidx ? sidx[i] : i;\r
- int_ptr[i] = &idata[si*step];\r
+ int_ptr[i] = &idata[si*rstep];\r
}\r
\r
icvSortIntPtr( int_ptr, sample_count, 0 );\r
\r
- int c_count = 1;\r
+ c_count = 1;\r
for( int i = 1; i < sample_count; i++ )\r
c_count += *int_ptr[i] != *int_ptr[i-1];\r
\r
- class_lables = cvCreateMat( 1, c_count, resp->type);\r
+ cat_count->data.i[cat_var_count] = c_count;\r
+\r
+ class_lables[cat_var_count] = cvCreateMat( 1, c_count, _responses->type);\r
\r
- int *lbs = class_lables->data.i;\r
+ int *lbs = class_lables[cat_var_count]->data.i;\r
int c_idx = 0;\r
lbs[c_idx] = *int_ptr[0];\r
for( int i = 1; i < sample_count; i++ )\r
lbs[c_idx] = *int_ptr[i];\r
}\r
\r
- for( int i = 0; i < sample_count; i++ )\r
+ for( int si = 0; si < sample_count; si++ )\r
{\r
- int si = sidx ? sidx[i] : i;\r
- si = si * step;\r
- for(int j = 0; j < c_count; j++ )\r
- if ( abs(lbs[j] - idata[si*step]) < FLT_EPSILON )\r
+ for(int j = 0; j < c_count; j++ )\r
+ if ( abs(lbs[j] - idata[si*rstep]) < FLT_EPSILON )\r
{\r
- idata[si*step] = j;\r
+ idata[si*rstep] = j;\r
break;\r
}\r
}\r
int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
int* sidx = 0;\r
int* co = 0;\r
-\r
+ int postep, pcstep, rstep;\r
+ \r
if( _subsample_idx )\r
{\r
CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
}\r
if( missing )\r
memset( missing, 1, count*var_count );\r
+\r
+ postep = ord_pred ? ord_pred->step/CV_ELEM_SIZE(ord_pred->type) : 0;\r
+ pcstep = cat_pred ? cat_pred->step/CV_ELEM_SIZE(cat_pred->type) : 0;\r
+ rstep = resp->step / CV_ELEM_SIZE(resp->type);\r
+ \r
for( vi = 0; vi < var_count; vi++ )\r
{\r
int ci = get_var_type(vi);\r
+ float* dst = values + vi;\r
+ int count1 = data_root->get_num_valid(vi);\r
+ uchar* m = missing ? missing + vi : 0; \r
if( ci >= 0 ) // categorical\r
{\r
- CV_ERROR( CV_StsBadArg, "ERTrees do not supprot categorical variables" );\r
+ for( i = 0; i < count1; i++ ) // count1 == sample_count\r
+ { \r
+ int c = cat_pred->data.i[i*pcstep + ci];\r
+ int val = get_class_idx ? c : class_lables[ci]->data.i[c];\r
+ cur_ofs = i*var_count;\r
+ dst[cur_ofs] = (float) val;\r
+ if( m )\r
+ m[cur_ofs] = 0;\r
+ }\r
}\r
else // ordered\r
{\r
- float* dst = values + vi;\r
- uchar* m = missing ? missing + vi : 0;\r
- int count1 = data_root->get_num_valid(vi);\r
-\r
for( i = 0; i < count1; i++ ) // count1 == sample_count\r
{\r
+ int idx = get_ord_var_idx(ci);\r
cur_ofs = i*var_count; \r
- int step = pred->step / CV_ELEM_SIZE(pred->type);\r
- dst[cur_ofs] = pred->data.fl[i*step + vi];\r
+ dst[cur_ofs] = ord_pred->data.fl[i*postep + idx];\r
if( m )\r
m[cur_ofs] = 0;\r
}\r
{\r
if( is_classifier )\r
for( i = 0; i < count; i++ )// count == sample_count\r
- {\r
- int idx = sidx ? sidx[i] : i;\r
-\r
- int rstep = resp->step / CV_ELEM_SIZE(resp->type);\r
- int cstep = resp->step / CV_ELEM_SIZE(resp->type);\r
- int r = resp->data.i[idx*rstep];\r
- int val = get_class_idx ? r : class_lables->data.i[r*cstep];\r
- responses[i] = (float)val; \r
+ { \r
+ int c = resp->data.i[i*rstep];\r
+ int val = get_class_idx ? c : class_lables[cat_var_count]->data.i[c];\r
responses[i] = (float)val;\r
}\r
else\r
\r
void CvERTreeTrainData::free_train_data()
{
-    cvReleaseMat( &pred );
+    cvReleaseMat( &ord_pred );
+    cvReleaseMat( &cat_pred );
    cvReleaseMat( &resp );
-    cvReleaseMat( &class_lables );
+    // release the per-variable category label tables, then the response
+    // label table (stored at class_lables[cat_var_count] for classifiers),
+    // then the pointer array itself
+    // NOTE(review): assumes class_lables is non-null whenever
+    // cat_var_count > 0 or is_classifier is set — confirm against the
+    // set_data()/clear() paths
+    for (int i = 0; i < cat_var_count; i++)
+        cvReleaseMat( &class_lables[i] );
+    if (is_classifier)
+        cvReleaseMat( &class_lables[cat_var_count] );
+    cvFree(&class_lables);
    CvDTreeTrainData :: free_train_data();
}
\r
\r
cvReleaseMat( &var_idx );\r
cvReleaseMat( &var_type );\r
- cvReleaseMat( &pred );\r
+ cvReleaseMat( &ord_pred );\r
+ cvReleaseMat( &cat_pred );\r
+ cvReleaseMat( &cat_count );\r
cvReleaseMat( &resp );\r
- cvReleaseMat( &class_lables );\r
+ for (int i = 0; i < cat_var_count; i++)\r
+ cvReleaseMat( &class_lables[i] );\r
+ if (is_classifier)\r
+ cvReleaseMat( &class_lables[cat_var_count] );\r
+ cvFree(&class_lables);\r
cvReleaseMat( &priors );\r
cvReleaseMat( &priors_mult );\r
\r
\r
int CvERTreeTrainData::get_num_classes() const
{
-    return is_classifier ? class_lables->cols : 0;
+    // the response labels live in the last slot of class_lables
+    // (index cat_var_count); regression has no classes
+    return is_classifier ? class_lables[cat_var_count]->cols : 0;
}
\r
void CvERTreeTrainData::write_params( CvFileStorage* fs )\r