]> rtime.felk.cvut.cz Git - opencv.git/commitdiff
git-svn-id: https://code.ros.org/svn/opencv/trunk@1585 73c94f0f-984f-4a5f-82bc-2d8db8...
authormdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 10 Feb 2009 15:02:28 +0000 (15:02 +0000)
committermdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 10 Feb 2009 15:02:28 +0000 (15:02 +0000)
opencv/include/opencv/ml.h
opencv/src/ml/mlrtrees.cpp
opencv/src/ml/mltree.cpp

index 35467bda175f41e8014809e98514727fe669244e..d5018ae316675c07ad40314d68566124157b40c9 100644 (file)
@@ -1058,9 +1058,12 @@ struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
     virtual void write_params( CvFileStorage* fs );
     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
 
-    CvMat* pred;
+    virtual int get_ord_var_idx(int ci) {return 1-ci;};
+
+    CvMat* ord_pred;
+    CvMat* cat_pred;
     CvMat* resp;
-    CvMat* class_lables;
+    CvMat** class_lables;
 
     // TODO add support _var_idx, _sample_idx, _missing_mask, _add_labels,
     // categorical variables, priors, pruning
@@ -1068,10 +1071,14 @@ struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
 
 class CV_EXPORTS CvForestERTree : public CvForestTree
 {
+public:
+    virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
+        bool preprocessed_input=false ) const;
 protected:
     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
     virtual void try_split_node( CvDTreeNode* n );
     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
     virtual void calc_node_value( CvDTreeNode* node );
    virtual void split_node_data( CvDTreeNode* n );
 };
@@ -1079,7 +1086,7 @@ protected:
 class CV_EXPORTS CvERTrees : public CvRTrees
 {
 public:
-    bool train( const CvMat* _train_data, int _tflag,
+    virtual bool train( const CvMat* _train_data, int _tflag,
         const CvMat* _responses, const CvMat* _var_idx,
         const CvMat* _sample_idx, const CvMat* _var_type,
         const CvMat* _missing_mask, CvRTParams params );
index ff7c11f9008e18214164df4638d92e312099e6d2..5cfb02c8038e6eb3f841840633b011faa85e0705 100644 (file)
@@ -307,19 +307,22 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
     const int dims = data->var_count;\r
     float maximal_response = 0;\r
 \r
+#define RF_OOB\r
+#ifdef RF_OOB\r
     // oob_predictions_sum[i] = sum of predicted values for the i-th sample\r
     // oob_num_of_predictions[i] = number of summands\r
     //                            (number of predictions for the i-th sample)\r
     // initialize these variable to avoid warning C4701\r
     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );\r
     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );\r
-\r
+#endif\r
     nsamples = data->sample_count;\r
     nclasses = data->get_num_classes();\r
 \r
     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
     memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
 \r
+#ifdef RF_OOB\r
     if( data->is_classifier )\r
     {\r
         CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));\r
@@ -336,8 +339,10 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
         cvGetRow( oob_responses, &oob_predictions_sum, 0 );\r
         cvGetRow( oob_responses, &oob_num_of_predictions, 1 );\r
     }\r
+#endif\r
     CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));\r
     CV_CALL(sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 ));\r
+#ifdef RF_OOB\r
     CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
     CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));\r
     CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));\r
@@ -350,6 +355,7 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
         cvMinMaxLoc( &responses, &minval, &maxval );\r
         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );\r
     }\r
+#endif\r
 \r
     ntrees = 0;\r
     while( ntrees < max_ntrees )\r
@@ -369,7 +375,7 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
         trees[ntrees] = new CvForestTree();\r
         tree = trees[ntrees];\r
         CV_CALL(tree->train( data, sample_idx_for_tree, this ));\r
-#define RF_OOB\r
+\r
 #ifdef RF_OOB\r
         CvMat sample, missing;\r
         // form array of OOB samples indices and get these samples\r
@@ -493,6 +499,8 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
 \r
     cvReleaseMat( &sample_idx_mask_for_tree );\r
     cvReleaseMat( &sample_idx_for_tree );\r
+\r
+#ifdef RF_OOB\r
     cvReleaseMat( &oob_sample_votes );\r
     cvReleaseMat( &oob_responses );\r
 \r
@@ -500,6 +508,7 @@ bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
     cvFree( &samples_ptr );\r
     cvFree( &missing_ptr );\r
     cvFree( &true_resp_ptr );\r
+#endif\r
 \r
     return result;\r
 }\r
@@ -790,9 +799,9 @@ CvDTreeSplit* CvForestERTree::find_best_split( CvDTreeNode* node )
     int* best_l_sample_idx = (int*)cvAlloc(n*sizeof(best_l_sample_idx[0]));\r
     int best_ln = 0;\r
     int* best_r_sample_idx = (int*)cvAlloc(n*sizeof(best_r_sample_idx[0]));\r
-    int best_rn = 0;
+    int best_rn = 0;\r
     int* tsidx = 0;\r
-
+\r
     if( forest )\r
     {\r
         int var_count;\r
@@ -824,7 +833,7 @@ CvDTreeSplit* CvForestERTree::find_best_split( CvDTreeNode* node )
         {\r
             if( ci >= 0)\r
             {\r
-                CV_ERROR( CV_StsBadArg, "ERTrees do not support categorical variables" );\r
+                split = find_split_cat_class( node, vi); \r
             }\r
             else\r
             {               \r
@@ -934,10 +943,12 @@ CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi )
     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;\r
     double best_val = 0;\r
 \r
-    const float* pred = ((CvERTreeTrainData*)data)->pred->data.fl;\r
+    const float* ord_pred = ((CvERTreeTrainData*)data)->ord_pred->data.fl;\r
+    int ci = data->get_var_type(vi);\r
+    const int idx = ((CvERTreeTrainData*)data)->get_ord_var_idx(ci);\r
     const int* sidx = ((CvERTreeNode*)node)->sample_idx;\r
     const int* resp = ((CvERTreeTrainData*)data)->resp->data.i;\r
-    int pstep = ((CvERTreeTrainData*)data)->pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->pred->type);\r
+    int pstep = ((CvERTreeTrainData*)data)->ord_pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->ord_pred->type);\r
     int rstep = ((CvERTreeTrainData*)data)->resp->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->resp->type);\r
     bool is_find_split = false;\r
 \r
@@ -947,11 +958,11 @@ CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi )
     if( !priors )\r
     {\r
         float pmin, pmax;\r
-        pmin = pred[sidx[0]*pstep + vi];\r
+        pmin = ord_pred[sidx[0]*pstep + idx];\r
         pmax = pmin;\r
         for (int si = 1; si < n; si++)\r
         {\r
-            float ptemp = pred[sidx[si]*pstep + vi];\r
+            float ptemp = ord_pred[sidx[si]*pstep + idx];\r
             if ( ptemp < pmin)\r
                 pmin = ptemp;\r
             if ( ptemp > pmax)\r
@@ -995,7 +1006,7 @@ CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi )
             for( int si = 0; si < n; si++ )\r
             {\r
                 int r = resp[sidx[si]*rstep];\r
-                if ((pred[sidx[si]*pstep + vi]) < split_val)\r
+                if ((ord_pred[sidx[si]*pstep + idx]) < split_val)\r
                 {\r
                     lc[r]++;\r
                     l_sind[L] = sidx[si];\r
@@ -1024,6 +1035,134 @@ CvDTreeSplit* CvForestERTree::find_split_ord_class( CvDTreeNode* node, int vi )
     return is_find_split ? data->new_split_ord( vi, (float)split_val, -1, 0, (float)best_val ) : 0;\r
 }\r
 \r
+CvDTreeSplit* CvForestERTree::find_split_cat_class( CvDTreeNode* node, int vi )\r
+{\r
+    int ci = data->get_var_type(vi);\r
+    int n = node->sample_count;\r
+    int cm = ((CvERTreeTrainData*)data)->get_num_classes(); \r
+    int vm = ((CvERTreeTrainData*)data)->cat_count->data.i[ci];\r
+    double best_val = 0;\r
+    CvDTreeSplit *split = 0;\r
+\r
+    if ( vm > 1 )\r
+    {\r
+        const int* cat_pred = ((CvERTreeTrainData*)data)->cat_pred->data.i;\r
+        const int* resp = ((CvERTreeTrainData*)data)->resp->data.i;\r
+        const int* sidx = ((CvERTreeNode*)node)->sample_idx;\r
+        int pstep = ((CvERTreeTrainData*)data)->cat_pred->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->cat_pred->type);\r
+        int rstep = ((CvERTreeTrainData*)data)->resp->step / CV_ELEM_SIZE(((CvERTreeTrainData*)data)->resp->type);\r
+\r
+        int* lc = (int*)cvStackAlloc(cm*sizeof(lc[0]));\r
+        int* rc = (int*)cvStackAlloc(cm*sizeof(rc[0]));\r
+        \r
+        const double* priors = data->have_priors ? data->priors_mult->data.db : 0;       \r
+\r
+        if( !priors )\r
+        {\r
+            int *valid_cidx = (int*)cvStackAlloc(vm*sizeof(valid_cidx[0]));\r
+            for (int i = 0; i < vm; i++)\r
+            {\r
+                valid_cidx[i] = -1;\r
+            }\r
+            for (int si = 0; si < n; si++)\r
+            {\r
+                int c = cat_pred[sidx[si]*pstep + ci];\r
+                valid_cidx[c]++;\r
+            }\r
+\r
+            int valid_ccount = 0;\r
+            for (int i = 0; i < vm; i++)\r
+                if (valid_cidx[i] >= 0)\r
+                {\r
+                    valid_cidx[i] = valid_ccount;\r
+                    valid_ccount++;\r
+                }\r
+            if (valid_ccount > 1)\r
+            {\r
+                CvRNG* rng = forest->get_rng();\r
+                int lcv_count = 1 + cvRandInt(rng) % (valid_ccount-1);\r
+\r
+                CvMat* var_class_mask = cvCreateMat( 1, valid_ccount, CV_8UC1 );\r
+                CvMat submask;\r
+                memset(var_class_mask->data.ptr, 0, valid_ccount*CV_ELEM_SIZE(var_class_mask->type));\r
+                cvGetCols( var_class_mask, &submask, 0, lcv_count );\r
+                cvSet( &submask, cvScalar(1) );\r
+                for (int i = 0; i < valid_ccount; i++)\r
+                {\r
+                    uchar temp;\r
+                    int i1 = cvRandInt( rng ) % valid_ccount;\r
+                    int i2 = cvRandInt( rng ) % valid_ccount;\r
+                    CV_SWAP( var_class_mask->data.ptr[i1], var_class_mask->data.ptr[i2], temp );\r
+                }\r
+\r
+                // init arrays of class instance counters on both sides of the split\r
+                for(int i = 0; i < cm; i++ )\r
+                {\r
+                    lc[i] = 0;\r
+                    rc[i] = 0;\r
+                }\r
+\r
+                if (!((CvERTreeNode*)node)->l_sample_idx)\r
+                {\r
+                    ((CvERTreeNode*)node)->l_sample_idx = (int*)cvAlloc( sizeof(((CvERTreeNode*)node)->l_sample_idx[0])*n );\r
+                    ((CvERTreeNode*)node)->ln = 0;\r
+                }\r
+                if (!((CvERTreeNode*)node)->r_sample_idx)\r
+                {\r
+                    ((CvERTreeNode*)node)->r_sample_idx = (int*)cvAlloc( sizeof(((CvERTreeNode*)node)->r_sample_idx[0])*n );\r
+                    ((CvERTreeNode*)node)->rn = 0;\r
+                }\r
+\r
+                int* l_sind = ((CvERTreeNode*)node)->l_sample_idx;\r
+                int* r_sind = ((CvERTreeNode*)node)->r_sample_idx;\r
+\r
+                split = data->new_split_cat( vi, -1. );\r
+\r
+                // calculate Gini index\r
+                double lbest_val = 0, rbest_val = 0;\r
+                int L = 0, R = 0;\r
+                                \r
+                for( int si = 0; si < n; si++ )\r
+                {\r
+                    int r = resp[sidx[si]*rstep];\r
+                    int var_class_idx = cat_pred[sidx[si]*pstep + ci];\r
+                    int mask_class_idx = valid_cidx[var_class_idx];\r
+                    if (var_class_mask->data.ptr[mask_class_idx])\r
+                    {\r
+                        lc[r]++;\r
+                        l_sind[L] = sidx[si];\r
+                        L++;                 \r
+                        split->subset[var_class_idx >> 5] |= 1 << (var_class_idx & 31);\r
+                    }\r
+                    else\r
+                    {\r
+                        rc[r]++;\r
+                        r_sind[R] = sidx[si];\r
+                        R++;\r
+                    }\r
+                }\r
+                for (int i = 0; i < cm; i++)\r
+                {\r
+                    lbest_val += lc[i]*lc[i];\r
+                    rbest_val += rc[i]*rc[i];\r
+                }\r
+                lbest_val = lbest_val/L;\r
+                rbest_val = rbest_val/R;\r
+                best_val = lbest_val + rbest_val;\r
+\r
+                split->quality = (float)best_val;\r
+\r
+                ((CvERTreeNode*)node)->ln = L;\r
+                ((CvERTreeNode*)node)->rn = R;\r
+\r
+                cvReleaseMat(&var_class_mask);\r
+            }             \r
+        }\r
+    }        \r
+  \r
+    return split;\r
+}\r
+\r
 void CvForestERTree::calc_node_value( CvDTreeNode* node )\r
 {\r
     CV_FUNCNAME("CvForestERTree::calc_node_value");\r
@@ -1076,7 +1215,7 @@ void CvForestERTree::calc_node_value( CvDTreeNode* node )
         }\r
 \r
         node->class_idx = max_k;\r
-        const int* ldata = ((CvERTreeTrainData*)data)->class_lables->data.i;\r
+        const int* ldata = ((CvERTreeTrainData*)data)->class_lables[data->cat_var_count]->data.i;\r
         node->value = ldata[max_k]; \r
 \r
         node->node_risk = total_weight - max_val;\r
@@ -1122,6 +1261,110 @@ void CvForestERTree::split_node_data( CvDTreeNode* node )
     data->free_node_data(node);\r
 }\r
 \r
+CvDTreeNode* CvForestERTree::predict( const CvMat* _sample,\r
+                                     const CvMat* _missing, bool preprocessed_input ) const\r
+{\r
+    CvDTreeNode* result = 0;\r
+\r
+    CV_FUNCNAME( "CvForestERTree::predict" );\r
+\r
+    __BEGIN__;\r
+\r
+    int i, step, mstep = 0;\r
+    const float* sample;\r
+    const uchar* m = 0;\r
+    CvDTreeNode* node = root;\r
+    const int* vtype;\r
+    const int* vidx;\r
+\r
+    if( !node )\r
+        CV_ERROR( CV_StsError, "The tree has not been trained yet" );\r
+\r
+    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||\r
+        (_sample->cols != 1 && _sample->rows != 1) ||\r
+        (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||\r
+        (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )\r
+        CV_ERROR( CV_StsBadArg,\r
+        "the input sample must be 1d floating-point vector with the same "\r
+        "number of elements as the total number of variables used for training" );\r
+\r
+    sample = _sample->data.fl;\r
+    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);\r
+\r
+    if( _missing )\r
+    {\r
+        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||\r
+            !CV_ARE_SIZES_EQ(_missing, _sample) )\r
+            CV_ERROR( CV_StsBadArg,\r
+            "the missing data mask must be 8-bit vector of the same size as input sample" );\r
+        m = _missing->data.ptr;\r
+        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);\r
+    }\r
+\r
+    vtype = data->var_type->data.i;\r
+    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;\r
+\r
+    while( node->Tn > pruned_tree_idx && node->left )\r
+    {\r
+        CvDTreeSplit* split = node->split;\r
+        int dir = 0;\r
+        for( ; !dir && split != 0; split = split->next )\r
+        {\r
+            int vi = split->var_idx;\r
+            int ci = vtype[vi];\r
+            i = vidx ? vidx[vi] : vi;\r
+            float val = sample[i*step];\r
+            if( m && m[i*mstep] )\r
+                continue;\r
+            if( ci < 0 ) // ordered\r
+                dir = val <= split->ord.c ? -1 : 1;\r
+            else // categorical\r
+            {\r
+                int c;\r
+                int c_count = data->cat_count->data.i[ci];\r
+                if( preprocessed_input )\r
+                    c = cvRound(val);\r
+                else\r
+                {\r
+                    int ival = cvRound(val);\r
+                    int i = 0;\r
+                    int* labls = ((CvERTreeTrainData*)data)->class_lables[ci]->data.i;\r
+                    if( ival != val )\r
+                        CV_ERROR( CV_StsBadArg,\r
+                            "one of input categorical variable is not an integer" );\r
+                    for (i = 0; i < c_count; i++)\r
+                        if (ival == labls[i]) break;\r
+                    c = i;\r
+                }\r
+                if (c == c_count)\r
+                {\r
+                    CvRNG* rng = &data->rng;\r
+                    dir = 2*(cvRandInt(rng)%2)-1;\r
+                }\r
+                else\r
+                    dir = CV_DTREE_CAT_DIR(c, split->subset);\r
+            }\r
+\r
+            if( split->inversed )\r
+                dir = -dir;\r
+        }\r
+\r
+        if( !dir )\r
+        {\r
+            double diff = node->right->sample_count - node->left->sample_count;\r
+            dir = diff < 0 ? -1 : 1;\r
+        }\r
+        node = dir < 0 ? node->left : node->right;\r
+    }\r
+\r
+    result = node;\r
+\r
+    __END__;\r
+\r
+    return result;\r
+}\r
+\r
+\r
 bool CvERTrees::train( const CvMat* _train_data, int _tflag,\r
                       const CvMat* _responses, const CvMat* _var_idx,\r
                       const CvMat* _sample_idx, const CvMat* _var_type,\r
@@ -1195,6 +1438,21 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
 {\r
     bool result = false;\r
 \r
+    CV_FUNCNAME("CvERTrees::grow_forest");\r
+    __BEGIN__;\r
+\r
+    const int max_ntrees = term_crit.max_iter;\r
+    const double max_oob_err = term_crit.epsilon;\r
+\r
+    nsamples = data->sample_count;\r
+    nclasses = ((CvERTreeTrainData*)data)->get_num_classes();\r
+\r
+    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
+\r
+    memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
+\r
+//#define ET_OOB\r
+#ifdef ET_OOB\r
     CvMat* oob_sample_votes       = 0;\r
     CvMat* oob_responses       = 0;\r
 \r
@@ -1204,25 +1462,12 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
     uchar* missing_ptr     = 0;\r
     float* true_resp_ptr   = 0;\r
 \r
-    CV_FUNCNAME("CvERTrees::grow_forest");\r
-    __BEGIN__;\r
-\r
-    const int max_ntrees = term_crit.max_iter;\r
-    const double max_oob_err = term_crit.epsilon;\r
-\r
     const int dims = data->var_count;\r
-    float maximal_response = 0;\r
+    float maximal_response = 0;   \r
 \r
     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );\r
     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );\r
 \r
-    nsamples = data->sample_count;\r
-    nclasses = ((CvERTreeTrainData*)data)->get_num_classes();\r
-\r
-    trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );\r
-\r
-    memset( trees, 0, sizeof(trees[0])*max_ntrees );\r
-\r
     if( data->is_classifier )\r
     {\r
         CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));\r
@@ -1247,20 +1492,19 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
         cvMinMaxLoc( &responses, &minval, &maxval );\r
         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );\r
     }\r
-\r
+#endif\r
     ntrees = 0;\r
     while( ntrees < max_ntrees )\r
     {\r
-        int oob_samples_count = 0;\r
-        double ncorrect_responses = 0;\r
         CvForestERTree* tree = 0;\r
 \r
         trees[ntrees] = new CvForestERTree();\r
         tree = (CvForestERTree*)trees[ntrees];\r
         CV_CALL(tree->train( data, 0, this ));\r
-#define ET_OOB\r
+\r
 #ifdef ET_OOB\r
-        int i;\r
+        int i, oob_samples_count = 0;\r
+        double ncorrect_responses = 0;\r
         CvMat sample, missing;\r
         sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );\r
         missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );\r
@@ -1290,7 +1534,7 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
                 // compute oob error\r
                 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );\r
 \r
-                prdct_resp = ((CvERTreeTrainData*)data)->class_lables->data.i[max_loc.x];\r
+                prdct_resp = ((CvERTreeTrainData*)data)->class_lables[data->cat_var_count]->data.i[max_loc.x];\r
                 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;\r
 \r
                 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;\r
@@ -1368,6 +1612,7 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
 \r
     __END__;\r
 \r
+#ifdef ET_OOB\r
     cvReleaseMat( &oob_sample_votes );\r
     cvReleaseMat( &oob_responses );\r
 \r
@@ -1375,7 +1620,7 @@ bool CvERTrees::grow_forest( const CvTermCriteria term_crit )
     cvFree( &samples_ptr );\r
     cvFree( &missing_ptr );\r
     cvFree( &true_resp_ptr );\r
-\r
+#endif\r
     return result;\r
 }\r
 \r
@@ -1395,21 +1640,21 @@ void CvERTrees::clear()
     ntrees = 0;  \r
 }\r
 \r
-void CvERTrees::read( CvFileStorage* fs, CvFileNode* node )
-{
+void CvERTrees::read( CvFileStorage* fs, CvFileNode* node )\r
+{\r
     CV_FUNCNAME("CvERTrees::read");\r
-    __BEGIN__;
-    fs = 0; node = 0;
-    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
-    __END__;
-};
-
-void CvERTrees::write( CvFileStorage* fs, const char* name )
-{
+    __BEGIN__;\r
+    fs = 0; node = 0;\r
+    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );\r
+    __END__;\r
+};\r
+\r
+void CvERTrees::write( CvFileStorage* fs, const char* name )\r
+{\r
     CV_FUNCNAME("CvERTrees::write");\r
-    __BEGIN__;
-    fs = 0; name = 0;
-    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );
-    __END__;
+    __BEGIN__;\r
+    fs = 0; name = 0;\r
+    CV_ERROR( CV_StsBadArg, "ERTrees do not support this method" );\r
+    __END__;\r
 };\r
 // End of file.\r
index ee4804c6be81c67725f77c62feefd6f8ac6503f7..ce3aad7e63ac8d1ee31f5754238bf375db3ac2be 100644 (file)
@@ -1216,7 +1216,7 @@ void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
 \r
 CvERTreeTrainData::CvERTreeTrainData()\r
 {\r
-    pred = resp = class_lables = 0;\r
+    ord_pred = cat_pred = resp = 0; class_lables = 0;\r
 }\r
 \r
 \r
@@ -1241,7 +1241,7 @@ CvERTreeTrainData::CvERTreeTrainData( const CvMat* _train_data, int _tflag,
     __END__;\r
 }\r
 \r
-                 \r
+\r
 void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,\r
     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,\r
     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,\r
@@ -1258,9 +1258,12 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
 \r
     int sample_all = 0, r_type = 0, cv_n;\r
     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;\r
-    int vi;\r
+    int vi, _pstep, _rstep, postep, pcstep, rstep;\r
     const int *sidx = 0;\r
     time_t _time;\r
+       int* _idata, *idata;\r
+    float* _fdata, *fdata;\r
+    int cat_count_size, cl_size;\r
 \r
     if (_var_idx || _sample_idx || _missing_mask)\r
         CV_ERROR(CV_StsBadArg, "arguments _var_idx, _sample_idx, _missing_mask are not supported");\r
@@ -1273,8 +1276,9 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
         // compare new and old train data\r
         if( !(data->var_count == var_count &&\r
             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&\r
-            cvNorm( data->resp, resp, CV_C ) < FLT_EPSILON &&\r
-            cvNorm( data->pred, pred, CV_C ) < FLT_EPSILON) )\r
+            cvNorm( data->ord_pred, ord_pred, CV_C ) < FLT_EPSILON &&\r
+            cvNorm( data->cat_pred, cat_pred, CV_C ) < FLT_EPSILON &&\r
+            cvNorm( data->resp, resp, CV_C ) < FLT_EPSILON) )\r
             CV_ERROR( CV_StsBadArg,\r
                 "The new training data must have the same types and the input and output variables "\r
                 "and the same categories for categorical variables" );\r
@@ -1310,7 +1314,7 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
     if( !CV_IS_MAT(_responses) ||\r
         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&\r
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||\r
-        _responses->rows != 1 && _responses->cols != 1 ||\r
+        (_responses->rows != 1 && _responses->cols != 1) ||\r
         _responses->rows + _responses->cols - 1 != sample_all )\r
         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "\r
         "floating-point vector containing as many elements as "\r
@@ -1333,11 +1337,9 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
             cat_var_count++ : ord_var_count--;\r
     }\r
     ord_var_count = ~ord_var_count;\r
-    if (cat_var_count)\r
-        CV_ERROR( CV_StsBadArg, "ERTrees support categorical variables only");\r
 \r
     cv_n = params.cv_folds;\r
-    if( cv_n )    \r
+    if( cv_n )\r
         CV_ERROR( CV_StsBadArg, "pruning is not supported ERTrees, params.cv_folds must be equel 0" );\r
 \r
     var_type->data.i[var_count] = cat_var_count;\r
@@ -1368,44 +1370,128 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
     max_c_count = 1;\r
 \r
     shared = true;\r
-    \r
-    if( _tflag == CV_ROW_SAMPLE )\r
+\r
+    if (ord_var_count)\r
+        CV_CALL( ord_pred = cvCreateMat( sample_count, ord_var_count, CV_32FC1 ));\r
+    if (cat_var_count)\r
+        CV_CALL( cat_pred = cvCreateMat( sample_count, cat_var_count, CV_32SC1 ));\r
+    CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
+    cat_count_size = is_classifier ? cat_var_count+1 : cat_var_count;\r
+    CV_CALL( cat_count = cvCreateMat( 1, cat_count_size, CV_32SC1 ));\r
+\r
+    CvMat *train_data;\r
+    CV_CALL( train_data = cvCreateMat( sample_count, var_count, _train_data->type ));\r
+    if( _tflag == CV_COL_SAMPLE )\r
     {\r
-        CV_CALL( pred = cvCreateMat( _train_data->rows, _train_data->cols, _train_data->type ));\r
-        CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
-        cvCopy( _train_data, pred );\r
+        cvTranspose( _train_data, train_data );\r
         cvConvertScale( _responses, resp );\r
+        cvTranspose( resp, resp );\r
     }\r
     else\r
     {\r
-        CV_CALL( pred = cvCreateMat( _train_data->cols, _train_data->rows, _train_data->type ));\r
-        CV_CALL( resp = cvCreateMat( _responses->rows, _responses->cols, CV_32SC1 ));\r
-        cvTranspose( _train_data, pred );\r
-        cvConvertScale( _responses, resp );\r
-        cvTranspose( resp, resp );\r
+        cvCopy(_train_data, train_data);\r
+        cvConvertScale(_responses, resp);\r
     }\r
 \r
+    _pstep = train_data->step / CV_ELEM_SIZE(train_data->type);\r
+    _rstep = _responses->step / CV_ELEM_SIZE(_responses->type);\r
+    postep = ord_pred ? ord_pred->step / CV_ELEM_SIZE(ord_pred->type) : 0;\r
+    pcstep = cat_pred ? cat_pred->step / CV_ELEM_SIZE(cat_pred->type) : 0;\r
+    rstep =  resp->type ? resp->step / CV_ELEM_SIZE(resp->type) : 0;\r
+        \r
+    _idata = 0; _fdata = 0;\r
+    if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )\r
+        _idata = train_data->data.i;\r
+    else\r
+        _fdata = train_data->data.fl;\r
+\r
+    idata = cat_pred ? cat_pred->data.i : 0;\r
+    fdata = ord_pred ? ord_pred->data.fl : 0;\r
+    for (int vi = 0; vi < var_count; vi++)\r
+    {\r
+        int ci = get_var_type(vi);\r
+        if (ci >= 0)\r
+        {\r
+            for (int si = 0; si < sample_count; si++)\r
+                idata[si*pcstep + ci] = _idata ? _idata[si*_pstep + vi] : cvRound(_fdata[si*_pstep + vi]);\r
+        }\r
+        else\r
+        {\r
+            int idx = get_ord_var_idx(ci);            \r
+            for (int si = 0; si < sample_count; si++)\r
+                fdata[si*postep + idx] = _idata ? (float)_idata[si*_pstep + vi] : _fdata[si*_pstep + vi];\r
+        }\r
+    }\r
+    \r
+    cl_size = is_classifier ? (cat_var_count + 1) : cat_var_count;\r
+    if (cl_size)\r
+        class_lables = (CvMat**)cvAlloc( cl_size * sizeof(class_lables[0]));\r
+\r
+    for (int vi = 0; vi < var_count; vi++)\r
+    {\r
+        int ci = get_var_type(vi);\r
+        if (ci >= 0)\r
+        {\r
+            int c_count;\r
+            // calculate count of categories\r
+            for( int i = 0; i < sample_count; i++ )\r
+            {\r
+                int_ptr[i] = &idata[i*pcstep + ci];\r
+            }\r
+\r
+            icvSortIntPtr( int_ptr, sample_count, 0 );\r
+\r
+            c_count = 1;\r
+            for( int i = 1; i < sample_count; i++ )\r
+                c_count += *int_ptr[i] != *int_ptr[i-1];\r
+            cat_count->data.i[ci] = c_count;\r
+\r
+            class_lables[ci] = cvCreateMat( 1, c_count, _train_data->type);\r
+\r
+            int *lbs = class_lables[ci]->data.i;\r
+            int c_idx = 0;\r
+            lbs[c_idx] = *int_ptr[0];\r
+            for( int si = 1; si < sample_count; si++ )\r
+                if (*int_ptr[si] != *int_ptr[si-1])\r
+                {\r
+                    c_idx++;\r
+                    lbs[c_idx] = *int_ptr[si];\r
+                }\r
+\r
+            for( int si = 0; si < sample_count; si++ )\r
+            {\r
+                for(int j = 0; j < c_count; j++ )\r
+                    if ( abs(lbs[j] - idata[si*pcstep + ci]) < FLT_EPSILON )\r
+                    {\r
+                        idata[si*pcstep + ci] = j;\r
+                        break;\r
+                    }\r
+            }\r
+        }\r
+    }\r
+    cvReleaseMat( &train_data );\r
     if( is_classifier ) \r
     {\r
+        int c_count;\r
         // calculate count of categories\r
-        int *idata = resp->data.i;\r
-        int step = CV_IS_MAT_CONT(resp->type) ?\r
-            1 : resp->step / CV_ELEM_SIZE(resp->type);\r
+        idata = resp->data.i;\r
         for( int i = 0; i < sample_count; i++ )\r
         {\r
             int si = sidx ? sidx[i] : i;\r
-            int_ptr[i] = &idata[si*step];\r
+            int_ptr[i] = &idata[si*rstep];\r
         }\r
 \r
         icvSortIntPtr( int_ptr, sample_count, 0 );\r
 \r
-        int c_count = 1;\r
+        c_count = 1;\r
         for( int i = 1; i < sample_count; i++ )\r
             c_count += *int_ptr[i] != *int_ptr[i-1];\r
 \r
-        class_lables = cvCreateMat( 1, c_count, resp->type);\r
+        cat_count->data.i[cat_var_count] = c_count;\r
+\r
+        class_lables[cat_var_count] = cvCreateMat( 1, c_count, _responses->type);\r
 \r
-        int *lbs = class_lables->data.i;\r
+        int *lbs = class_lables[cat_var_count]->data.i;\r
         int c_idx = 0;\r
         lbs[c_idx] = *int_ptr[0];\r
         for( int i = 1; i < sample_count; i++ )\r
@@ -1415,14 +1501,12 @@ void CvERTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
                 lbs[c_idx] = *int_ptr[i];\r
             }\r
 \r
-            for( int i = 0; i < sample_count; i++ )\r
+            for( int si = 0; si < sample_count; si++ )\r
             {\r
-                int si = sidx ? sidx[i] : i;\r
-                si = si * step;\r
-                for(int j = 0; j < c_count; j++ )\r
-                    if ( abs(lbs[j] - idata[si*step]) < FLT_EPSILON )\r
+               for(int j = 0; j < c_count; j++ )\r
+                    if ( abs(lbs[j] - idata[si*rstep]) < FLT_EPSILON )\r
                     {\r
-                        idata[si*step] = j;\r
+                        idata[si*rstep] = j;\r
                         break;\r
                     }\r
             }\r
@@ -1495,7 +1579,8 @@ void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
     int i, vi, total = sample_count, count = total, cur_ofs = 0;\r
     int* sidx = 0;\r
     int* co = 0;\r
-\r
+    int postep, pcstep, rstep;\r
+    \r
     if( _subsample_idx )\r
     {\r
         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));\r
@@ -1518,24 +1603,36 @@ void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
     }\r
     if( missing )\r
         memset( missing, 1, count*var_count );\r
+\r
+    postep = ord_pred ? ord_pred->step/CV_ELEM_SIZE(ord_pred->type) : 0;\r
+    pcstep = cat_pred ? cat_pred->step/CV_ELEM_SIZE(cat_pred->type) : 0;\r
+    rstep = resp->step / CV_ELEM_SIZE(resp->type);\r
+    \r
     for( vi = 0; vi < var_count; vi++ )\r
     {\r
         int ci = get_var_type(vi);\r
+        float* dst = values + vi;\r
+        int count1 = data_root->get_num_valid(vi);\r
+        uchar* m = missing ? missing + vi : 0;  \r
         if( ci >= 0 ) // categorical\r
         {\r
-            CV_ERROR( CV_StsBadArg, "ERTrees do not supprot categorical variables" );\r
+            for( i = 0; i < count1; i++ ) // count1 == sample_count\r
+            {        \r
+                int c = cat_pred->data.i[i*pcstep + ci];\r
+                int val = get_class_idx ? c : class_lables[ci]->data.i[c];\r
+                cur_ofs = i*var_count;\r
+                dst[cur_ofs] = (float) val;\r
+                if( m )\r
+                    m[cur_ofs] = 0;\r
+            }\r
         }\r
         else // ordered\r
         {\r
-            float* dst = values + vi;\r
-            uchar* m = missing ? missing + vi : 0;\r
-            int count1 = data_root->get_num_valid(vi);\r
-\r
             for( i = 0; i < count1; i++ ) // count1 == sample_count\r
             {\r
+                int idx = get_ord_var_idx(ci);\r
                 cur_ofs = i*var_count;  \r
-                int step = pred->step / CV_ELEM_SIZE(pred->type);\r
-                dst[cur_ofs] = pred->data.fl[i*step + vi];\r
+                dst[cur_ofs] = ord_pred->data.fl[i*postep + idx];\r
                 if( m )\r
                     m[cur_ofs] = 0;\r
             }\r
@@ -1547,14 +1644,9 @@ void CvERTreeTrainData::get_vectors( const CvMat* _subsample_idx,
     {\r
         if( is_classifier )\r
           for( i = 0; i < count; i++ )// count == sample_count\r
-            {\r
-                int idx = sidx ? sidx[i] : i;\r
-\r
-                int rstep = resp->step / CV_ELEM_SIZE(resp->type);\r
-                int cstep = resp->step / CV_ELEM_SIZE(resp->type);\r
-                int r = resp->data.i[idx*rstep];\r
-                int val = get_class_idx ? r : class_lables->data.i[r*cstep];\r
-                    responses[i] = (float)val; \r
+            {   \r
+                int c = resp->data.i[i*rstep];\r
+                int val = get_class_idx ? c : class_lables[cat_var_count]->data.i[c];\r
                 responses[i] = (float)val;\r
             }\r
         else\r
@@ -1627,9 +1719,14 @@ void CvERTreeTrainData::free_node_data( CvDTreeNode* node )
 \r
 void CvERTreeTrainData::free_train_data()\r
 {\r
-    cvReleaseMat( &pred );\r
+    cvReleaseMat( &ord_pred );\r
+    cvReleaseMat( &cat_pred );\r
     cvReleaseMat( &resp );\r
-    cvReleaseMat( &class_lables );\r
+    for (int i = 0; i < cat_var_count; i++)\r
+        cvReleaseMat( &class_lables[i] );\r
+    if (is_classifier)\r
+        cvReleaseMat( &class_lables[cat_var_count] );\r
+    cvFree(&class_lables);\r
     CvDTreeTrainData :: free_train_data();\r
 }\r
 \r
@@ -1642,9 +1739,15 @@ void CvERTreeTrainData::clear()
 \r
     cvReleaseMat( &var_idx );\r
     cvReleaseMat( &var_type );\r
-    cvReleaseMat( &pred );\r
+    cvReleaseMat( &ord_pred );\r
+    cvReleaseMat( &cat_pred );\r
+    cvReleaseMat( &cat_count );\r
     cvReleaseMat( &resp );\r
-    cvReleaseMat( &class_lables );\r
+    for (int i = 0; i < cat_var_count; i++)\r
+        cvReleaseMat( &class_lables[i] );\r
+    if (is_classifier)\r
+        cvReleaseMat( &class_lables[cat_var_count] );\r
+    cvFree(&class_lables);\r
     cvReleaseMat( &priors );\r
     cvReleaseMat( &priors_mult );\r
 \r
@@ -1664,7 +1767,7 @@ void CvERTreeTrainData::clear()
 \r
 int CvERTreeTrainData::get_num_classes() const\r
 {\r
-    return is_classifier ? class_lables->cols : 0;\r
+    return is_classifier ? class_lables[cat_var_count]->cols : 0;\r
 }\r
 \r
 void CvERTreeTrainData::write_params( CvFileStorage* fs )\r