]> rtime.felk.cvut.cz Git - opencv.git/commitdiff
Boosting added,
authorvp153 <vp153@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Wed, 30 Aug 2006 17:37:19 +0000 (17:37 +0000)
committervp153 <vp153@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Wed, 30 Aug 2006 17:37:19 +0000 (17:37 +0000)
multiple improvements in standalone decision trees and random trees.

git-svn-id: https://code.ros.org/svn/opencv/trunk@772 73c94f0f-984f-4a5f-82bc-2d8db8d8ee08

opencv/include/opencv/ml.h
opencv/src/ml/mlboost.cpp
opencv/src/ml/mlrtrees.cpp
opencv/src/ml/mltree.cpp

index ad1e22468f6a7d52435c46eb2c07e1a1b53e184b..fafc66269756d1ee6d97c50c0efe112843736c7e 100644 (file)
@@ -115,38 +115,16 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void )
 #define CV_VAR_ORDERED      0
 #define CV_VAR_CATEGORICAL  1
 
-/* flag values for classifier consturctor <flags> parameter */
-#define CV_SVM_MAGIC_VAL            0x0000FF01
 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
-
-#define CV_KNN_MAGIC_VAL            0x0000FF02
 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
-
-#define CV_NBAYES_MAGIC_VAL         0x0000FF03
 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
-
-#define CV_EM_MAGIC_VAL             0x0000FF04
 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
-
-#define CV_BOOST_TREE_MAGIC_VAL     0x0000FF05
 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
-
-#define CV_TREE_MAGIC_VAL           0x0000FF06
 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
-
-#define CV_ANN_MLP_MAGIC_VAL        0x0000FF07
 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
-
-#define CV_CNN_MAGIC_VAL            0x0000FF08
 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
-
-#define CV_RTREES_MAGIC_VAL         0x0000FF09
 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
 
-#define CV_CROSSVAL_MAGIC_VAL       0x0000FF10
-#define CV_TYPE_NAME_ML_CROSSVAL    "opencv-ml-cross-validation"
-
-
 class CV_EXPORTS CvStatModel
 {
 public:
@@ -487,92 +465,6 @@ cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
             const CvParamLattice* nu_lattice     CV_DEFAULT(0),
             const CvParamLattice* p_lattice      CV_DEFAULT(0) );*/
 
-/****************************************************************************************\
-*                                   Boosted trees models                                 *
-\****************************************************************************************/
-
-#if 0
-
-/* Boosted trees training parameters */
-struct CV_EXPORTS CvBoostTrainParams
-{
-    int boost_type;
-    int weak_count;
-    double infl_trim_rate;
-    int weak_tree_splits;
-
-    CvBoostTrainParams();
-    CvBoostTrainParams( int boost_type, int weak_count,
-        double infl_trim_rate, int weak_tree_splits );
-};
-
-
-struct CV_EXPORTS CvBoostWeakTree
-{
-    /* number of internal tree nodes (splits) */
-    int count;
-
-    /* internal nodes (each is array of <count> elements) */
-    int* var_idx;
-    float* threshold;
-    int* left;
-    int* right;
-
-    /* leaves (array of <count>+1 elements) */
-    float* val;
-};
-
-
-class CV_EXPORTS CvBoost : public CvStatModel
-{
-public:
-    // Type of return value in ::predict
-    enum { VALUE=0, INDEX=1 };
-
-    // Boosting type
-    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
-
-    CvBoost();
-    virtual ~CvBoost();
-
-    CvBoost( const CvMat* _train_data, int _tflag,
-             const CvMat* _responses, const CvMat* _var_idx=0,
-             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
-             const CvMat* _missing_mask=0,
-             CvBoostTrainParams params=CvBoostTrainParams() );
-        
-    virtual bool train( const CvMat* _train_data, int _tflag,
-             const CvMat* _responses, const CvMat* _var_idx=0,
-             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
-             const CvMat* _missing_mask=0,
-             CvBoostTrainParams params=CvBoostTrainParams(),
-             bool update=false );
-
-    virtual float predict( const CvMat* _sample,
-                           CvMat* weak_responses,
-                           CvSlice slice=CV_WHOLE_SEQ,
-                           int eval_type=VALUE) const;
-
-    virtual void prune( CvSlice slice );
-
-    virtual void clear();
-
-    virtual void write( CvFileStorage* storage, const char* name );
-    virtual void read( CvFileStorage* storage, CvFileNode* node );
-
-    CvSeq* get_weak_predictors();
-
-protected:
-    CvBoostTrainParams params;    
-    CvMat* class_labels;
-    int total_features;
-    CvSeq* weak; /* weak classifiers (CvTreeBoostClassifier) pointers */
-    CvMat* var_idx;
-    void* ts;
-};
-
-#endif
-
 /****************************************************************************************\
 *                              Expectation - Maximization                                *
 \****************************************************************************************/
@@ -663,6 +555,10 @@ struct CvPair32s32f
     float val;
 };
 
+
+#define CV_DTREE_CAT_DIR(idx,subset) \
+    (2*(((subset)[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) /* -1 if bit <idx> of the <subset> bit mask is set, +1 otherwise; <subset> may now be any pointer expression */
+
 struct CvDTreeSplit
 {
     int var_idx;
@@ -755,35 +651,39 @@ struct CV_EXPORTS CvDTreeTrainData
                       const CvMat* _responses, const CvMat* _var_idx=0,
                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                       const CvMat* _missing_mask=0,
-                      CvDTreeParams _params=CvDTreeParams(),
-                      bool _shared=false, bool _add_weights=false );
+                      const CvDTreeParams& _params=CvDTreeParams(),
+                      bool _shared=false, bool _add_labels=false );
     virtual ~CvDTreeTrainData();
 
     virtual void set_data( const CvMat* _train_data, int _tflag,
                           const CvMat* _responses, const CvMat* _var_idx=0,
                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                           const CvMat* _missing_mask=0,
-                          CvDTreeParams _params=CvDTreeParams(),
-                          bool _shared=false, bool _add_weights=false );
+                          const CvDTreeParams& _params=CvDTreeParams(),
+                          bool _shared=false, bool _add_labels=false,
+                          bool _update_data=false );
 
     virtual void get_vectors( const CvMat* _subsample_idx,
          float* values, uchar* missing, float* responses, bool get_class_idx=false );
 
     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
 
+    virtual void write_params( CvFileStorage* fs );
+    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
+
     // release all the data
     virtual void clear();
 
     int get_num_classes() const;
     int get_var_type(int vi) const;
+    int get_work_var_count() const;
 
     virtual int* get_class_labels( CvDTreeNode* n );
     virtual float* get_ord_responses( CvDTreeNode* n );
-    virtual int* get_cv_labels( CvDTreeNode* n );
+    virtual int* get_labels( CvDTreeNode* n );
     virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
     virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
     virtual int get_child_buf_idx( CvDTreeNode* n );
-    virtual float* get_weights( CvDTreeNode* n );
 
     ////////////////////////////////////
 
@@ -800,7 +700,7 @@ struct CV_EXPORTS CvDTreeTrainData
 
     int sample_count, var_all, var_count, max_c_count;
     int ord_var_count, cat_var_count;
-    bool have_cv_labels, have_priors, have_weights;
+    bool have_labels, have_priors;
     bool is_classifier;
 
     int buf_count, buf_size;
@@ -859,6 +759,11 @@ public:
     virtual void read( CvFileStorage* fs, CvFileNode* node );
     virtual void write( CvFileStorage* fs, const char* name );
     
+    // special read & write methods for trees in the tree ensembles
+    virtual void read( CvFileStorage* fs, CvFileNode* node,
+                       CvDTreeTrainData* data );
+    virtual void write( CvFileStorage* fs );
+    
     const CvDTreeNode* get_root() const;
     int get_pruned_tree_idx() const;
     CvDTreeTrainData* get_data();
@@ -893,8 +798,6 @@ protected:
     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
-    virtual void write_train_data_params( CvFileStorage* fs );
-    virtual void read_train_data_params( CvFileStorage* fs, CvFileNode* node );
     virtual void write_tree_nodes( CvFileStorage* fs );
     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
 
@@ -921,11 +824,7 @@ public:
 
     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
     virtual int get_var_count() const {return data ? data->var_count : 0;}
-    virtual void share_data( bool share ) { data->shared = share; }
-    // if _data == 0, then it will be read from file,
-    // otherwise this->data will be assigned to _data
-    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData** _data = 0 );
-    virtual void write( CvFileStorage* fs, const char* name, bool write_train_data_params = false );
+    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
 
 protected:
     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
@@ -993,10 +892,11 @@ public:
 
 protected:
 
-    bool grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria term_crit );
+    bool grow_forest( const CvTermCriteria term_crit );
 
     // array of the trees of the forest
     CvForestTree** trees;
+    CvDTreeTrainData* data;
     int ntrees;
     int nclasses;
     double oob_error;
@@ -1008,6 +908,120 @@ protected:
     CvMat* active_var_mask;
 };
 
+
+/****************************************************************************************\
+*                                   Boosted trees models                                 *
+\****************************************************************************************/
+
+struct CV_EXPORTS CvBoostParams : public CvDTreeParams
+{
+    int boost_type; // one of CvBoost::{DISCRETE, REAL, LOGIT, GENTLE}; the default constructor picks CvBoost::REAL
+    int weak_count; // 100 by default; presumably the number of weak trees in the ensemble - TODO confirm
+    int split_criteria; // one of CvBoost::{DEFAULT, GINI, MISCLASS, SQERR}
+    double weight_trim_rate; // 0.95 by default
+
+    CvBoostParams();
+    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
+                   int max_depth, bool use_surrogates, const float* priors );
+};
+
+
+class CvBoost;
+
+class CV_EXPORTS CvBoostTree: public CvDTree
+{
+public:
+    CvBoostTree();
+    virtual ~CvBoostTree();
+
+    virtual bool train( CvDTreeTrainData* _train_data,
+                        const CvMat* subsample_idx, CvBoost* ensemble ); // clears the tree, marks _train_data as shared and runs do_train()
+    virtual void scale( double s ); // multiplies the value of every node of the tree by s
+    virtual void read( CvFileStorage* fs, CvFileNode* node,
+                       CvBoost* ensemble, CvDTreeTrainData* _data );
+    virtual void clear(); // CvDTree::clear() + resets the ensemble pointer
+
+protected:
+    
+    virtual void try_split_node( CvDTreeNode* n );
+    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
+    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
+    virtual void calc_node_value( CvDTreeNode* n );
+    virtual double calc_node_dir( CvDTreeNode* n );
+
+    CvBoost* ensemble; // set by train(), zeroed by clear(); never deleted by the tree
+};
+
+
+class CV_EXPORTS CvBoost : public CvStatModel
+{
+public:
+    // Boosting type (the values stored in CvBoostParams::boost_type)
+    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
+    
+    // Splitting criteria (the values stored in CvBoostParams::split_criteria)
+    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
+
+    CvBoost();
+    virtual ~CvBoost();
+
+    CvBoost( const CvMat* _train_data, int _tflag,
+             const CvMat* _responses, const CvMat* _var_idx=0,
+             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
+             const CvMat* _missing_mask=0,
+             CvBoostParams params=CvBoostParams() );
+        
+    virtual bool train( const CvMat* _train_data, int _tflag,
+             const CvMat* _responses, const CvMat* _var_idx=0,
+             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
+             const CvMat* _missing_mask=0,
+             CvBoostParams params=CvBoostParams(),
+             bool update=false );
+
+    virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
+                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
+                           bool raw_mode=false ) const;
+
+    virtual void prune( CvSlice slice );
+
+    virtual void clear();
+
+    virtual void write( CvFileStorage* storage, const char* name );
+    virtual void read( CvFileStorage* storage, CvFileNode* node );
+
+    CvSeq* get_weak_predictors();
+
+    CvMat* get_weights(); // doubles (read via data.db); CvBoostTree::calc_node_value indexes them by the sample labels
+    CvMat* get_subtree_weights(); // doubles; CvBoostTree::calc_node_value copies the weights of the node samples here
+    CvMat* get_weak_response(); // doubles; CvBoostTree::try_split_node stores the leaf values of the samples here
+    const CvBoostParams& get_params() const;
+
+protected:
+
+    virtual bool set_params( const CvBoostParams& _params );
+    virtual void update_weights( CvBoostTree* tree );
+    virtual void trim_weights();
+    virtual void write_params( CvFileStorage* fs );
+    virtual void read_params( CvFileStorage* fs, CvFileNode* node );
+
+    CvDTreeTrainData* data;
+    CvBoostParams params;    
+    CvSeq* weak;
+    
+    CvMat* orig_response;
+    CvMat* sum_response;
+    CvMat* weak_eval;
+    CvMat* subsample_mask;
+    CvMat* weights;
+    CvMat* subtree_weights;
+    bool have_subsample;
+};
+
+
 /****************************************************************************************\
 *                              Artificial Neural Networks (ANN)                          *
 \****************************************************************************************/
index 28d26a8e69f3035174a25c6d4e4d331d4bf5ccce..311fba75fdfaefe4886e20a3cd557aaa866f6f1b 100644 (file)
 
 #include "_ml.h"
 
-/* End of file. */
+// Logit of a probability: the argument is clamped to [eps, 1 - eps]
+// (eps = 1e-5) and log(p/(1 - p)) of the clamped value is returned.
+static inline double
+log_ratio( double val )
+{
+    const double eps = 1e-5;
+    const double upper = 1. - eps;
+
+    // keep the argument strictly inside (0,1) so that the logarithm stays finite
+    if( val < eps )
+        val = eps;
+    if( val > upper )
+        val = upper;
+
+    return log( val/(1. - val) );
+}
+
+
+// Default boosting parameters: boost_type = CvBoost::REAL, 100 weak trees,
+// weight_trim_rate = 0.95, the default splitting criterion,
+// max_depth = 1 and cv_folds = 0.
+CvBoostParams::CvBoostParams()
+{
+    boost_type = CvBoost::REAL;
+    weak_count = 100;
+    weight_trim_rate = 0.95;
+    // split_criteria used to be left uninitialized here, although
+    // CvBoostTree::find_split_*_class read it; use the same value
+    // as the other constructor does.
+    split_criteria = CvBoost::DEFAULT;
+    cv_folds = 0;
+    max_depth = 1;
+}
+
+
+// Builds the parameter set from explicit values; the splitting criterion
+// is always CvBoost::DEFAULT and cv_folds is always 0.
+CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
+                              double _weight_trim_rate, int _max_depth,
+                              bool _use_surrogates, const float* _priors )
+{
+    // boosting-specific settings
+    boost_type       = _boost_type;
+    weak_count       = _weak_count;
+    weight_trim_rate = _weight_trim_rate;
+    split_criteria   = CvBoost::DEFAULT;
+
+    // settings inherited from CvDTreeParams
+    max_depth      = _max_depth;
+    use_surrogates = _use_surrogates;
+    priors         = _priors;
+    cv_folds       = 0;
+}
+
+
+
+///////////////////////////////// CvBoostTree ///////////////////////////////////
+
+// A freshly constructed tree is not attached to any ensemble.
+CvBoostTree::CvBoostTree()
+    : ensemble(0)
+{
+}
+
+
+// Releases the tree through clear(), which also resets the ensemble pointer.
+CvBoostTree::~CvBoostTree()
+{
+    clear();
+}
+
+
+// Clears the underlying decision tree and detaches it from the ensemble.
+// The ensemble itself is not deleted, only the pointer is zeroed.
+void
+CvBoostTree::clear()
+{
+    CvDTree::clear();
+    ensemble = 0;
+}
+
+
+// Trains the tree as a member of <_ensemble> on <_train_data>.
+// Returns whatever do_train() reports for the given subsample.
+bool
+CvBoostTree::train( CvDTreeTrainData* _train_data,
+                    const CvMat* _subsample_idx, CvBoost* _ensemble )
+{
+    // drop the previous state first (this also zeroes the ensemble pointer)
+    clear();
+
+    // flag the training data as shared before the tree starts using it
+    _train_data->shared = true;
+    data = _train_data;
+    ensemble = _ensemble;
+
+    return do_train( _subsample_idx );
+}
+
+
+// Multiplies the value of every node of the tree by <scale>.
+// The tree is walked without recursion: go down along the left links,
+// then climb up until a node is reached from its left child
+// and continue with the right sibling.
+void
+CvBoostTree::scale( double scale )
+{
+    CvDTreeNode* node = root;
+
+    // an empty tree (root == 0) has nothing to scale;
+    // without this check the loop below dereferenced a null pointer
+    if( !node )
+        return;
+
+    // traverse the tree and scale all the node values
+    for(;;)
+    {
+        CvDTreeNode* parent;
+        for(;;)
+        {
+            node->value *= scale;
+            if( !node->left )
+                break;
+            node = node->left;
+        }
+
+        // climb up while we are coming from a right child
+        for( parent = node->parent; parent && parent->right == node;
+            node = parent, parent = parent->parent )
+            ;
+
+        // came back to the root from its right subtree: everything is done
+        if( !parent )
+            break;
+
+        node = parent->right;
+    }
+}
+
+
+// Lets CvDTree try to split the node; if the node stays a leaf,
+// its value is written into the weak response array of the ensemble
+// for every training sample that fell into the node.
+void
+CvBoostTree::try_split_node( CvDTreeNode* node )
+{
+    CvDTree::try_split_node( node );
+
+    // only the nodes that have not been split contribute to the weak response
+    if( node->left )
+        return;
+
+    double* weak_eval = ensemble->get_weak_response()->data.db;
+    const int* sample_labels = data->get_labels( node );
+    const int total = node->sample_count;
+    const double leaf_value = node->value;
+
+    // the labels are used as indices into the weak response array
+    for( int k = 0; k < total; k++ )
+        weak_eval[sample_labels[k]] = leaf_value;
+}
+
+
+// Applies the primary split of <node> to its samples:
+//  * writes the direction of every sample into data->direction
+//    (-1 - left, +1 - right, 0 - the split variable is missing);
+//  * accumulates the total weight sent to the left (L) and to the right (R)
+//    and stores max(L,R) in node->maxlr;
+//  * returns the split quality divided by L + R.
+double
+CvBoostTree::calc_node_dir( CvDTreeNode* node )
+{
+    char* dir = (char*)data->direction->data.ptr;
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    int i, n = node->sample_count, vi = node->split->var_idx;
+    double L, R;
+
+    assert( !node->split->inversed );
+
+    if( data->get_var_type(vi) >= 0 ) // split on categorical var
+    {
+        const int* cat_labels = data->get_cat_var_data( node, vi );
+        const int* subset = node->split->subset;
+        double sum = 0, sum_abs = 0;
+
+        for( i = 0; i < n; i++ )
+        {
+            int idx = cat_labels[i];
+            double w = weights[i];
+            // a negative category index gets direction 0
+            int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
+            // (d & 1) is 1 for d = -1 or +1 and 0 for d = 0,
+            // so sum = R - L and sum_abs = R + L
+            sum += d*w; sum_abs += (d & 1)*w;
+            dir[i] = (char)d;
+        }
+
+        R = (sum_abs + sum) * 0.5;
+        L = (sum_abs - sum) * 0.5;
+    }
+    else // split on ordered var
+    {
+        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
+        int split_point = node->split->ord.split_point;
+        int n1 = node->get_num_valid(vi);
+
+        assert( 0 <= split_point && split_point < n1-1 );
+        L = R = 0;
+
+        // sorted[0..split_point] go to the left
+        for( i = 0; i <= split_point; i++ )
+        {
+            int idx = sorted[i].i;
+            double w = weights[idx];
+            dir[idx] = (char)-1;
+            L += w;
+        }
+
+        // the rest of the first n1 entries go to the right
+        for( ; i < n1; i++ )
+        {
+            int idx = sorted[i].i;
+            double w = weights[idx];
+            dir[idx] = (char)1;
+            R += w;
+        }
+
+        // entries past n1 (= node->get_num_valid(vi)) get direction 0
+        for( ; i < n; i++ )
+            dir[sorted[i].i] = (char)0;
+    }
+
+    node->maxlr = MAX( L, R );
+    // NOTE(review): there is no guard against L + R == 0 here
+    return node->split->quality/(L + R);
+}
+
+
+// Looks for the best threshold on the ordered variable <vi> for a two-class
+// problem with weighted samples. Returns a new ordered split, or 0 if no
+// candidate threshold improved on the initial quality of 0.
+CvDTreeSplit*
+CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
+{
+    const float epsilon = FLT_EPSILON*2;
+    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
+    const int* responses = data->get_class_labels(node);
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    int n = node->sample_count;
+    int n1 = node->get_num_valid(vi);
+    // the two values right after the n per-sample weights are read as the
+    // per-class weight totals of the node
+    // (presumably stored there by calc_node_value - TODO confirm)
+    const double* rcw0 = weights + n;
+    double lcw[2] = {0,0}, rcw[2];
+    int i, best_i = -1;
+    double best_val = 0;
+    int boost_type = ensemble->get_params().boost_type;
+    int split_criteria = ensemble->get_params().split_criteria;
+
+    // start with everything on the right, then exclude
+    // the entries past n1 (= node->get_num_valid(vi))
+    rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
+    for( i = n1; i < n; i++ )
+    {
+        int idx = sorted[i].i;
+        double w = weights[idx];
+        rcw[responses[idx]] -= w;
+    }
+
+    // any criterion other than GINI or MISCLASS is replaced with
+    // MISCLASS for discrete boosting and with GINI otherwise
+    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
+        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
+
+    if( split_criteria == CvBoost::GINI )
+    {
+        double L = 0, R = rcw[0] + rcw[1];
+        // lsum2 and rsum2 are the sums of squared class weights on each side
+        double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
+
+        for( i = 0; i < n1 - 1; i++ )
+        {
+            int idx = sorted[i].i;
+            double w = weights[idx], w2 = w*w;
+            double lv, rv;
+            idx = responses[idx];
+            L += w; R -= w;
+            lv = lcw[idx]; rv = rcw[idx];
+            // incremental update: (lv + w)^2 - lv^2 and (rv - w)^2 - rv^2
+            lsum2 += 2*lv*w + w2;
+            rsum2 -= 2*rv*w - w2;
+            lcw[idx] = lv + w; rcw[idx] = rv - w;
+
+            // a threshold can only be put between two distinct values
+            if( sorted[i].val + epsilon < sorted[i+1].val )
+            {
+                // equals lsum2/L + rsum2/R
+                double val = (lsum2*R + rsum2*L)/(L*R);
+                if( best_val < val )
+                {
+                    best_val = val;
+                    best_i = i;
+                }
+            }
+        }
+    }
+    else
+    {
+        for( i = 0; i < n1 - 1; i++ )
+        {
+            int idx = sorted[i].i;
+            double w = weights[idx];
+            idx = responses[idx];
+            lcw[idx] += w;
+            rcw[idx] -= w;
+
+            if( sorted[i].val + epsilon < sorted[i+1].val )
+            {
+                // the better of the two ways to assign one class to each side
+                double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
+                val = MAX(val, val2);
+                if( best_val < val )
+                {
+                    best_val = val;
+                    best_i = i;
+                }
+            }
+        }
+    }
+
+    // the threshold is the midpoint between the two neighbour values
+    return best_i >= 0 ? data->new_split_ord( vi,
+        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
+        0, (float)best_val ) : 0;
+}
+
+
+// icvSortDblPtr sorts an array of double* using "*(a) < *(b)" as the
+// comparison, i.e. by the values the pointers refer to.
+#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
+static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
+
+// Looks for the best split of the categorical variable <vi> for a two-class
+// problem with weighted samples. The categories are sorted by the weight of
+// their class-1 samples and every prefix of the sorted sequence is tried as
+// the subset whose bits are set in the split mask.
+// Returns a new categorical split, or 0 if no candidate was found.
+CvDTreeSplit*
+CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
+{
+    CvDTreeSplit* split;
+    const int* cat_labels = data->get_cat_var_data(node, vi);
+    const int* responses = data->get_class_labels(node);
+    int ci = data->get_var_type(vi);
+    int n = node->sample_count;
+    int mi = data->cat_count->data.i[ci];
+    double lcw[2]={0,0}, rcw[2]={0,0};
+    // the table is shifted by one row (2 elements) so that
+    // the category index -1 is a valid row as well
+    double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
+    int i, j, k, idx;
+    double L = 0, R;
+    double best_val = 0;
+    int best_subset = -1, subset_i;
+    int boost_type = ensemble->get_params().boost_type;
+    int split_criteria = ensemble->get_params().split_criteria;
+
+    // init array of counters:
+    // c_{jk} - total weight of the samples that have vi-th input variable = j and response = k.
+    for( j = -1; j < mi; j++ )
+        cjk[j*2] = cjk[j*2+1] = 0;
+
+    for( i = 0; i < n; i++ )
+    {
+        double w = weights[i];
+        j = cat_labels[i];
+        k = responses[i];
+        cjk[j*2 + k] += w;
+    }
+
+    // row -1 is left out of the totals; dbl_ptr[j] points to c_{j,1}
+    for( j = 0; j < mi; j++ )
+    {
+        rcw[0] += cjk[j*2];
+        rcw[1] += cjk[j*2+1];
+        dbl_ptr[j] = cjk + j*2 + 1;
+    }
+
+    R = rcw[0] + rcw[1];
+
+    // any criterion other than GINI or MISCLASS is replaced with
+    // MISCLASS for discrete boosting and with GINI otherwise
+    if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
+        split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
+
+    // sort rows of c_jk by increasing c_j,1
+    // (i.e. by the weight of samples in j-th category that belong to class 1)
+    icvSortDblPtr( dbl_ptr, mi, 0 );
+
+    for( subset_i = 0; subset_i < mi-1; subset_i++ )
+    {
+        // the pointer offset is 2*j + 1, so the integer division gives j back
+        idx = (int)(dbl_ptr[subset_i] - cjk)/2;
+        const double* crow = cjk + idx*2;
+        double w0 = crow[0], w1 = crow[1];
+        double weight = w0 + w1;
+
+        // categories with (almost) no weight do not change the partition
+        if( weight < FLT_EPSILON )
+            continue;
+
+        lcw[0] += w0; rcw[0] -= w0;
+        lcw[1] += w1; rcw[1] -= w1;
+
+        if( split_criteria == CvBoost::GINI )
+        {
+            double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
+            double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
+        
+            L += weight;
+            R -= weight;
+
+            if( L > FLT_EPSILON && R > FLT_EPSILON )
+            {
+                // equals lsum2/L + rsum2/R
+                double val = (lsum2*R + rsum2*L)/(L*R);
+                if( best_val < val )
+                {
+                    best_val = val;
+                    best_subset = subset_i;
+                }
+            }
+        }
+        else
+        {
+            // the better of the two ways to assign one class to each side
+            double val = lcw[0] + rcw[1];
+            double val2 = lcw[1] + rcw[0];
+
+            val = MAX(val, val2);
+            if( best_val < val )
+            {
+                best_val = val;
+                best_subset = subset_i;
+            }
+        }
+    }
+
+    if( best_subset < 0 )
+        return 0;
+
+    split = data->new_split_cat( vi, (float)best_val );
+
+    // set the bit of every category in the best prefix of the sorted sequence
+    for( i = 0; i <= best_subset; i++ )
+    {
+        idx = (int)(dbl_ptr[i] - cjk) >> 1;
+        split->subset[idx >> 5] |= 1 << (idx & 31);
+    }
+
+    return split;
+}
+
+
+// Looks for the best threshold on the ordered variable <vi> for a regression
+// tree with weighted samples. The criterion that is maximized is
+// lsum^2/L + rsum^2/R, where lsum, rsum are the weighted sums of responses
+// and L, R are the total weights on each side of the threshold.
+// Returns a new ordered split, or 0 if no candidate threshold was found.
+CvDTreeSplit*
+CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
+{
+    const float epsilon = FLT_EPSILON*2;
+    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
+    const float* responses = data->get_ord_responses(node);
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    int n = node->sample_count;
+    int n1 = node->get_num_valid(vi);
+    int i, best_i = -1;
+    // weights[n] is read as the total weight of the node samples
+    double L = 0, R = weights[n];
+    // rsum has to be a *weighted* sum of responses, because weighted terms
+    // (responses[idx]*w) are subtracted from it below. It used to be
+    // initialized as node->value*n, which mixed the sample count with
+    // weighted sums and is only right when all the weights are equal to 1.
+    // Assumes node->value is the weighted mean response of the node -
+    // TODO confirm against calc_node_value.
+    double best_val = 0, lsum = 0, rsum = node->value*R;
+
+    // compensate for missing values
+    for( i = n1; i < n; i++ )
+    {
+        int idx = sorted[i].i;
+        double w = weights[idx];
+        rsum -= responses[idx]*w;
+        R -= w;
+    }
+
+    // find the optimal split
+    for( i = 0; i < n1 - 1; i++ )
+    {
+        int idx = sorted[i].i;
+        double w = weights[idx];
+        double t = responses[idx]*w;
+        L += w; R -= w;
+        lsum += t; rsum -= t;
+
+        // a threshold can only be put between two distinct values
+        if( sorted[i].val + epsilon < sorted[i+1].val )
+        {
+            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
+            if( best_val < val )
+            {
+                best_val = val;
+                best_i = i;
+            }
+        }
+    }
+
+    // the threshold is the midpoint between the two neighbour values
+    return best_i >= 0 ? data->new_split_ord( vi,
+        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
+        0, (float)best_val ) : 0;
+}
+
+
+// Looks for the best split of the categorical variable <vi> for a regression
+// tree with weighted samples. The categories are sorted by their weighted
+// average response and every prefix of the sorted sequence is tried as the
+// subset whose bits are set in the split mask; the criterion that is
+// maximized is lsum^2/L + rsum^2/R.
+// Returns a new categorical split, or 0 if no candidate was found.
+CvDTreeSplit*
+CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
+{
+    CvDTreeSplit* split;
+    const int* cat_labels = data->get_cat_var_data(node, vi);
+    const float* responses = data->get_ord_responses(node);
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    int ci = data->get_var_type(vi);
+    int n = node->sample_count;
+    int mi = data->cat_count->data.i[ci];
+    // both arrays are shifted by one element so that
+    // the category index -1 is a valid slot as well
+    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
+    double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
+    double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
+    double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
+    int i, best_subset = -1, subset_i;
+
+    for( i = -1; i < mi; i++ )
+        sum[i] = counts[i] = 0;
+
+    // calculate sum response and weight of each category of the input var
+    for( i = 0; i < n; i++ )
+    {
+        int idx = cat_labels[i];
+        double w = weights[i];
+        double s = sum[idx] + responses[i]*w;
+        double nc = counts[idx] + w;
+        sum[idx] = s;
+        counts[idx] = nc;
+    }
+
+    // calculate average response in each category
+    // (slot -1 is left out of the totals)
+    for( i = 0; i < mi; i++ )
+    {
+        R += counts[i];
+        rsum += sum[i];
+        // a category without any weight used to produce 0/0 = NaN here,
+        // and the NaNs were then passed to the comparison-based sort below;
+        // use 0 as the average of such a category instead
+        sum[i] = counts[i] != 0 ? sum[i]/counts[i] : 0.;
+        sum_ptr[i] = sum + i;
+    }
+
+    icvSortDblPtr( sum_ptr, mi, 0 );
+
+    // revert back to unnormalized sums
+    // (there should be a very little loss in accuracy)
+    for( i = 0; i < mi; i++ )
+        sum[i] *= counts[i];
+
+    for( subset_i = 0; subset_i < mi-1; subset_i++ )
+    {
+        int idx = (int)(sum_ptr[subset_i] - sum);
+        double ni = counts[idx];
+
+        // categories with (almost) no weight do not change the partition
+        if( ni > FLT_EPSILON )
+        {
+            double s = sum[idx];
+            lsum += s; L += ni;
+            rsum -= s; R -= ni;
+
+            if( L > FLT_EPSILON && R > FLT_EPSILON )
+            {
+                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
+                if( best_val < val )
+                {
+                    best_val = val;
+                    best_subset = subset_i;
+                }
+            }
+        }
+    }
+
+    if( best_subset < 0 )
+        return 0;
+
+    split = data->new_split_cat( vi, (float)best_val );
+    // set the bit of every category in the best prefix of the sorted sequence
+    for( i = 0; i <= best_subset; i++ )
+    {
+        int idx = (int)(sum_ptr[i] - sum);
+        split->subset[idx >> 5] |= 1 << (idx & 31);
+    }
+
+    return split;
+}
+
+
+// Looks for a surrogate split on the ordered variable <vi>, i.e. a threshold
+// that reproduces the directions of the primary split (read from
+// data->direction) as well as possible in terms of the sample weights.
+// Returns a new ordered split (possibly inversed), or 0 if no threshold
+// gives more agreeing weight than node->maxlr.
+CvDTreeSplit*
+CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
+{
+    const float epsilon = FLT_EPSILON*2;
+    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    const char* dir = (char*)data->direction->data.ptr;
+    int n1 = node->get_num_valid(vi);
+    // LL - total weight of samples that both the primary and the surrogate splits send to the left
+    // LR - ... primary split sends to the left and the surrogate split sends to the right
+    // RL - ... primary split sends to the right and the surrogate split sends to the left
+    // RR - ... both send to the right
+    int i, best_i = -1, best_inversed = 0;
+    double best_val; 
+    double LL = 0, RL = 0, LR, RR;
+    // the surrogate split has to beat the value stored in node->maxlr
+    double worst_val = node->maxlr;
+    double sum = 0, sum_abs = 0;
+    best_val = worst_val;
+    
+    // (d & 1) is 1 for d = -1 or +1 and 0 for d = 0
+    for( i = 0; i < n1; i++ )
+    {
+        int idx = sorted[i].i;
+        double w = weights[idx];
+        int d = dir[idx];
+        sum += d*w; sum_abs += (d & 1)*w;
+    }
+
+    // sum_abs = R + L; sum = R - L
+    RR = (sum_abs + sum)*0.5;
+    LR = (sum_abs - sum)*0.5;
+
+    // initially all the samples are sent to the right by the surrogate split,
+    // LR of them are sent to the left by primary split, and RR - to the right.
+    // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
+    for( i = 0; i < n1 - 1; i++ )
+    {
+        int idx = sorted[i].i;
+        double w = weights[idx];
+        int d = dir[idx];
+
+        if( d < 0 )
+        {
+            LL += w; LR -= w;
+            // LL + RR is the weight on which a non-inversed surrogate agrees with the primary split
+            if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
+            {
+                best_val = LL + RR;
+                best_i = i; best_inversed = 0;
+            }
+        }
+        else if( d > 0 )
+        {
+            RL += w; RR -= w;
+            // RL + LR is the agreeing weight if the surrogate directions are swapped
+            if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
+            {
+                best_val = RL + LR;
+                best_i = i; best_inversed = 1;
+            }
+        }
+    }
+
+    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
+        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
+        best_inversed, (float)best_val ) : 0;
+}
+
+
+// Builds a surrogate split on the categorical variable <vi>: every category
+// is sent to the side where the primary split (read from data->direction)
+// sends the larger part of the weight of that category.
+// Returns the split, or 0 (after returning the split to data->split_heap)
+// if its quality does not exceed node->maxlr.
+CvDTreeSplit*
+CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
+{
+    const int* cat_labels = data->get_cat_var_data(node, vi);
+    const char* dir = (char*)data->direction->data.ptr;
+    const double* weights = ensemble->get_subtree_weights()->data.db;
+    int n = node->sample_count;
+    // LL - total weight of samples that both the primary and the surrogate splits send to the left
+    // LR - ... primary split sends to the left and the surrogate split sends to the right
+    // RL - ... primary split sends to the right and the surrogate split sends to the left
+    // RR - ... both send to the right
+    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
+    int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
+    double best_val = 0;
+    // lc and rc share one buffer; each of them is shifted by one element
+    // so that the category index -1 is a valid slot as well
+    double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
+    double* rc = lc + mi + 1;
+    
+    for( i = -1; i < mi; i++ )
+        lc[i] = rc[i] = 0;
+
+    // 1. for each category calculate the weight of samples
+    // sent to the left (lc) and to the right (rc) by the primary split
+    // (at first lc accumulates R - L and rc accumulates R + L,
+    //  they are converted to the actual weights in the next loop)
+    for( i = 0; i < n; i++ )
+    {
+        int idx = cat_labels[i];
+        double w = weights[i];
+        int d = dir[i];
+        double sum = lc[idx] + d*w;
+        double sum_abs = rc[idx] + (d & 1)*w;
+        lc[idx] = sum; rc[idx] = sum_abs;
+    }
+
+    for( i = 0; i < mi; i++ )
+    {
+        double sum = lc[i];
+        double sum_abs = rc[i];
+        lc[i] = (sum_abs - sum) * 0.5;
+        rc[i] = (sum_abs + sum) * 0.5;
+    }
+
+    // 2. now form the split.
+    // in each category send all the samples to the same direction as majority
+    for( i = 0; i < mi; i++ )
+    {
+        double lval = lc[i], rval = rc[i];
+        if( lval > rval )
+        {
+            split->subset[i >> 5] |= 1 << (i & 31);
+            best_val += lval;
+        }
+        else
+            best_val += rval;
+    }
+
+    // the split is only useful if it beats the value stored in node->maxlr
+    split->quality = (float)best_val;
+    if( split->quality <= node->maxlr )
+        cvSetRemoveByPtr( data->split_heap, split ), split = 0;
+
+    return split;
+}
+
+
+// Computes the value stored in <node> and fills the per-node weight buffer
+// of the ensemble (subtree_weights):
+//  * subtree_weights[i], i < count, receives the boosting weight of the i-th
+//    sample of the node (looked up through the sample labels);
+//  * subtree_weights[count] and subtree_weights[count+1] receive the summary
+//    weights rcw[0] and rcw[1].
+// For classification trees rcw[k] is the total weight of class k; for
+// regression trees rcw[0] is the total weight and rcw[1] stays 0.
+void
+CvBoostTree::calc_node_value( CvDTreeNode* node )
+{
+    int i, count = node->sample_count;
+    const double* weights = ensemble->get_weights()->data.db;
+    // labels[i] is the index of the i-th node sample in the ensemble-wide weight array
+    const int* labels = data->get_labels(node);
+    double* subtree_weights = ensemble->get_subtree_weights()->data.db;
+    double rcw[2] = {0,0};
+    int boost_type = ensemble->get_params().boost_type;
+
+    if( data->is_classifier )
+    {
+        // class labels are used as indices into rcw[], i.e. they must be 0 or 1
+        const int* responses = data->get_class_labels(node);
+        
+        for( i = 0; i < count; i++ )
+        {
+            int idx = labels[i];
+            double w = weights[idx];
+            rcw[responses[i]] += w;
+            subtree_weights[i] = w;
+        }
+
+        // the node votes for the class with the larger total weight
+        node->class_idx = rcw[1] > rcw[0];
+
+        if( boost_type == CvBoost::DISCRETE )
+        {
+            // ignore cat_map for responses, and use {-1,1},
+            // as the whole ensemble response is computed as sign(sum_i(weak_response_i))
+            node->value = node->class_idx*2 - 1;
+        }
+        else
+        {
+            // weighted fraction of class 1 in the node
+            double p = rcw[1]/(rcw[0] + rcw[1]);
+            assert( boost_type == CvBoost::REAL );
+            
+            // store log-ratio of the probability
+            node->value = 0.5*log_ratio(p);
+        }
+    }
+    else
+    {
+        // in case of regression tree:
+        //  * node value is sum_i(w_i*Y_i)/sum_i(w_i), where Y_i is i-th response and
+        //    w_i is the weight of the i-th sample of the node.
+        //  * node risk is the weighted sum of squared errors:
+        //    sum_i(w_i*(Y_i - <node_value>)^2) = sum_i(w_i*Y_i^2) - <node_value>*sum_i(w_i*Y_i)
+        double sum = 0, sum2 = 0, iw;
+        const float* values = data->get_ord_responses(node);
+        
+        for( i = 0; i < count; i++ )
+        {
+            int idx = labels[i];
+            double w = weights[idx];
+            double t = values[i];
+            rcw[0] += w;
+            subtree_weights[i] = w;
+            sum += t*w;
+            sum2 += t*t*w;
+        }
+
+        iw = 1./rcw[0];
+        node->value = sum*iw;
+        node->node_risk = sum2 - (sum*iw)*sum;
+        
+        // renormalize the risk, as in try_split_node the unweighted formula
+        // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
+        node->node_risk *= count*iw*count*iw;
+    }
+
+    // store summary weights
+    subtree_weights[count] = rcw[0];
+    subtree_weights[count+1] = rcw[1];
+}
+
+
+// Reads the tree from the file node <fnode> using the base-class reader and the
+// training-data descriptor <_data>, then remembers the ensemble <_ensemble>
+// in the <ensemble> member.
+void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
+{
+    CvDTree::read( fs, fnode, _data );
+    ensemble = _ensemble;
+}
+
+
+/////////////////////////////////// CvBoost /////////////////////////////////////
+
+// Default constructor: creates an empty, untrained ensemble.
+CvBoost::CvBoost()
+{
+    default_model_name = "my_boost_tree";
+
+    // nothing has been allocated yet
+    weak = 0;
+    data = 0;
+    subtree_weights = 0;
+    weights = 0;
+    subsample_mask = 0;
+    weak_eval = 0;
+    sum_response = 0;
+    orig_response = 0;
+
+    // brings the remaining state (e.g. have_subsample) to its initial value
+    clear();
+}
+
+
+void CvBoost::prune( CvSlice slice )
+{
+    if( weak )
+    {
+        CvSeqReader reader;
+        int i, count = cvSliceLength( slice, weak );
+        
+        cvStartReadSeq( weak, &reader );
+        cvSetSeqReaderPos( &reader, slice.start_index );
+
+        for( i = 0; i < count; i++ )
+        {
+            CvBoostTree* w;
+            CV_READ_SEQ_ELEM( w, reader );
+            delete w;
+        }
+
+        cvSeqRemoveSlice( weak, slice );
+    }
+}
+
+
+void CvBoost::clear()
+{
+    if( weak )
+    {
+        prune( CV_WHOLE_SEQ );
+        cvReleaseMemStorage( &weak->storage );
+    }
+    if( data )
+        delete data;
+    weak = 0;
+    data = 0;
+    cvReleaseMat( &orig_response );
+    cvReleaseMat( &sum_response );
+    cvReleaseMat( &weak_eval );
+    cvReleaseMat( &subsample_mask );
+    cvReleaseMat( &weights );
+    have_subsample = false;
+}
+
+
+CvBoost::~CvBoost()
+{
+    clear();
+}
+
+
+CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
+                  const CvMat* _responses, const CvMat* _var_idx,
+                  const CvMat* _sample_idx, const CvMat* _var_type,
+                  const CvMat* _missing_mask, CvBoostParams _params )
+{
+    weak = 0;
+    data = 0;
+    default_model_name = "my_boost_tree";
+    orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
+
+    train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
+           _var_type, _missing_mask, _params );
+}
+
+
+bool
+CvBoost::set_params( const CvBoostParams& _params )
+{
+    bool ok = false;
+    
+    CV_FUNCNAME( "CvBoost::set_params" );
+
+    __BEGIN__;
+
+    params = _params;
+    if( params.boost_type != DISCRETE && params.boost_type != REAL &&
+        params.boost_type != LOGIT && params.boost_type != GENTLE )
+        CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
+
+    params.weak_count = MAX( params.weak_count, 1 );
+    params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
+    params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
+    if( params.weight_trim_rate < FLT_EPSILON )
+        params.weight_trim_rate = 1.f;
+
+    if( params.boost_type == DISCRETE &&
+        params.split_criteria != GINI && params.split_criteria != MISCLASS )
+        params.split_criteria = MISCLASS;
+    if( params.boost_type == REAL &&
+        params.split_criteria != GINI && params.split_criteria != MISCLASS )
+        params.split_criteria = GINI;
+    if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
+        params.split_criteria != SQERR )
+        params.split_criteria = SQERR;
+    
+    ok = true;
+    
+    __END__;
+
+    return ok;
+}
+
+
+bool
+CvBoost::train( const CvMat* _train_data, int _tflag,
+              const CvMat* _responses, const CvMat* _var_idx,
+              const CvMat* _sample_idx, const CvMat* _var_type,
+              const CvMat* _missing_mask,
+              CvBoostParams _params, bool _update )
+{
+    bool ok = false;
+    CvMemStorage* storage = 0;
+
+    CV_FUNCNAME( "CvBoost::train" );
+
+    __BEGIN__;
+
+    int i;
+
+    set_params( _params );
+
+    if( !_update || !data )
+    {
+        clear();
+        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
+            _sample_idx, _var_type, _missing_mask, _params, true, true );
 
+        if( data->get_num_classes() != 2 )
+            CV_ERROR( CV_StsNotImplemented,
+            "Boosted trees can only be used for 2-class classification." );
+        CV_CALL( storage = cvCreateMemStorage() );
+        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
+        storage = 0;
+    }
+    else
+    {
+        data->set_data( _train_data, _tflag, _responses, _var_idx,
+            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
+    }
+
+    update_weights( 0 );
+
+    for( i = 0; i < params.weak_count; i++ )
+    {
+        CvBoostTree* tree = new CvBoostTree;
+        if( !tree->train( data, subsample_mask, this ) )
+        {
+            delete tree;
+            continue;
+        }
+        cvSeqPush( weak, &tree );
+        update_weights( tree );
+        trim_weights();
+    }
+
+    data->is_classifier = true;
+    ok = true;
+
+    __END__;
+
+    return ok;
+}
+
+
+void
+CvBoost::update_weights( CvBoostTree* tree )
+{
+    CV_FUNCNAME( "CvBoost::update_weights" );
+
+    __BEGIN__;
+
+    int i, count = data->sample_count;
+
+    if( !tree ) // before training the first tree, initialize weights and other parameters
+    {
+        const int* class_labels = data->get_class_labels(data->data_root);
+        // in case of logitboost and gentle adaboost each weak tree is a regression tree,
+        // so we need to convert class labels to floating-point values
+        float* responses = data->get_ord_responses(data->data_root);
+        int* labels = data->get_labels(data->data_root);
+        double w0 = 1./count;
+        
+        cvReleaseMat( &orig_response );
+        cvReleaseMat( &sum_response );
+        cvReleaseMat( &weak_eval );
+        cvReleaseMat( &subsample_mask );
+        cvReleaseMat( &weights );
+
+        CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
+        CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
+        CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
+        CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
+        CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
+
+        for( i = 0; i < count; i++ )
+        {
+            // save original categorical responses {0,1}, convert them to {-1,1}
+            orig_response->data.i[i] = class_labels[i]*2 - 1;
+            // make all the samples active at start.
+            // later, in trim_weights() deactivate/reactive again some, if need
+            subsample_mask->data.ptr[i] = (uchar)1;
+            // make all the initial weights the same.
+            weights->data.db[i] = w0;
+            // set the labels to find (from within weak tree learning proc)
+            // the particular sample weight, and where to store the response.
+            labels[i] = i;
+        }
+
+        if( params.boost_type == LOGIT )
+        {
+            CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
+            
+            for( i = 0; i < count; i++ )
+            {
+                sum_response->data.db[i] = 0;
+                responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
+            }
+
+            // in case of logitboost each weak tree is a regression tree.
+            // the target function values are recalculated for each of the trees
+            data->is_classifier = false;
+        }
+        else if( params.boost_type == GENTLE )
+        {
+            for( i = 0; i < count; i++ )
+                responses[i] = (float)orig_response->data.i[i];
+
+            data->is_classifier = false;
+        }
+    }
+    else
+    {
+        double sumw = 0.;
+        
+        // at this moment, for all the samples that participated in the training of the most
+        // recent weak classifier we know the responses. For other samples we need to compute them
+        if( have_subsample )
+        {
+            float* values = (float*)(data->buf->data.ptr + data->buf->step);
+            uchar* missing = data->buf->data.ptr + data->buf->step*2;
+            CvMat _sample, _mask;
+
+            // invert the subsample mask
+            cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
+            data->get_vectors( subsample_mask, values, missing, 0 );
+            //data->get_vectors( 0, values, missing, 0 );
+
+            _sample = cvMat( 1, data->var_count, CV_32F );
+            _mask = cvMat( 1, data->var_count, CV_8U );
+
+            // run tree through all the non-processed samples
+            for( i = 0; i < count; i++ )
+                if( subsample_mask->data.ptr[i] )
+                {
+                    _sample.data.fl = values;
+                    _mask.data.ptr = missing;
+                    values += _sample.cols;
+                    missing += _mask.cols;
+                    weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
+                }
+        }
+
+        // now update weights and other parameters for each type of boosting
+        if( params.boost_type == DISCRETE )
+        {
+            // Discrete AdaBoost:
+            //   weak_eval[i] (=f(x_i)) is in {-1,1}
+            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
+            //   C = log((1-err)/err)
+            //   w_i *= exp(C*(f(x_i) != y_i))
+            
+            double C, err = 0.;
+            double scale[] = { 1., 0. };
+
+            for( i = 0; i < count; i++ )
+            {
+                double w = weights->data.db[i];
+                sumw += w;
+                err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
+            }
+            
+            if( sumw != 0 )
+                err /= sumw;
+            C = err = -log_ratio( err );
+            scale[1] = exp(err);
+    
+            sumw = 0;
+            for( i = 0; i < count; i++ )
+            {
+                double w = weights->data.db[i]*
+                    scale[weak_eval->data.db[i] != orig_response->data.i[i]];
+                sumw += w;
+                weights->data.db[i] = w;
+            }
+
+            tree->scale( C );
+        }
+        else if( params.boost_type == REAL )
+        {
+            // Real AdaBoost:
+            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
+            //   w_i *= exp(-y_i*f(x_i))
+            
+            for( i = 0; i < count; i++ )
+                weak_eval->data.db[i] *= -orig_response->data.i[i];
+
+            cvExp( weak_eval, weak_eval );
+
+            for( i = 0; i < count; i++ )
+            {
+                double w = weights->data.db[i]*weak_eval->data.db[i];
+                sumw += w;
+                weights->data.db[i] = w;
+            }
+        }
+        else if( params.boost_type == LOGIT )
+        {
+            // LogitBoost:
+            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
+            //   sum_response = F(x_i).
+            //   F(x_i) += 0.5*f(x_i)
+            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i)))
+            //   reuse weak_eval: weak_eval[i] <- p(x_i)
+            //   w_i = p(x_i)*1(1 - p(x_i))
+            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
+            //   store z_i to the data->data_root as the new target responses
+
+            const double lb_weight_thresh = 1e-4;
+            const double lb_z_max = 3.;
+            float* responses = data->get_ord_responses(data->data_root);
+
+            for( i = 0; i < count; i++ )
+            {
+                double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
+                sum_response->data.db[i] = s;
+                weak_eval->data.db[i] = -2*s;
+            }
+
+            cvExp( weak_eval, weak_eval );
+            
+            for( i = 0; i < count; i++ )
+            {
+                double p = 1./(1. + weak_eval->data.db[i]);
+                double w = p*(1 - p), z;
+                w = MAX( w, lb_weight_thresh );
+                weights->data.db[i] = (float)w;
+                sumw += w;
+                if( orig_response->data.i[i] > 0 )
+                {
+                    z = 1./p;
+                    responses[i] = (float)MAX(z, lb_z_max);
+                }
+                else
+                {
+                    z = 1./(1-p);
+                    responses[i] = (float)-MAX(z, lb_z_max);
+                }
+            }
+        }
+        else
+        {
+            // Gentle AdaBoost:
+            //   weak_eval[i] = f(x_i) in [-1,1]
+            //   w_i *= exp(-y_i*f(x_i))
+            assert( params.boost_type == GENTLE );
+            
+            for( i = 0; i < count; i++ )
+                weak_eval->data.db[i] *= -orig_response->data.i[i];
+
+            cvExp( weak_eval, weak_eval );
+
+            for( i = 0; i < count; i++ )
+            {
+                double w = weights->data.db[i] * weak_eval->data.db[i];
+                weights->data.db[i] = w;
+                sumw += w;
+            }
+        }
+
+        // renormalize weights
+        if( sumw > FLT_EPSILON )
+        {
+            sumw = 1./sumw;
+            for( i = 0; i < count; ++i )
+                weights->data.db[i] *= sumw;
+        }
+    }
+
+    __END__;
+}
+
+
+static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
+
+
+void
+CvBoost::trim_weights()
+{
+    CV_FUNCNAME( "CvBoost::trim_weights" );
+
+    __BEGIN__;
+
+    int i, count = data->sample_count, nz_count = 0;
+    double sum, threshold;
+
+    if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
+        EXIT;
+
+    // use weak_eval as temporary buffer for sorted weights
+    cvCopy( weights, weak_eval );
+
+    icvSort_64f( weak_eval->data.db, count, 0 );
+
+    // as weight trimming occurs immediately after updating the weights,
+    // where they are renormalized, we assume that the weight sum = 1.
+    sum = 1. - params.weight_trim_rate;
+
+    for( i = 0; i < count; i++ )
+    {
+        double w = weak_eval->data.db[i];
+        if( sum > w )
+            break;
+        sum -= w;
+    }
+
+    threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
+
+    for( i = 0; i < count; i++ )
+    {
+        double w = weights->data.db[i];
+        int f = w > threshold;
+        subsample_mask->data.ptr[i] = (uchar)f;
+        nz_count += f;
+    }
+
+    have_subsample = nz_count < count;
+
+    __END__;
+}
+
+
+float
+CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
+                  CvMat* weak_responses, CvSlice slice,
+                  bool raw_mode ) const
+{
+    float* buf = 0;
+    bool allocated = false;
+    float value = -FLT_MAX;
+    
+    CV_FUNCNAME( "CvBoost::predict" );
+
+    __BEGIN__;
+
+    int i, weak_count, var_count;
+    CvMat sample, missing;
+    CvSeqReader reader;
+    double sum = 0;
+    int cls_idx;
+    int wstep = 0;
+    const int* vtype;
+    const int* cmap;
+    const int* cofs;
+
+    if( !weak )
+        CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
+
+    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
+        _sample->cols != 1 && _sample->rows != 1 ||
+        _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
+        _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
+            CV_ERROR( CV_StsBadArg,
+        "the input sample must be 1d floating-point vector with the same "
+        "number of elements as the total number of variables used for training" );
+
+    if( _missing )
+    {
+        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
+            !CV_ARE_SIZES_EQ(_missing, _sample) )
+            CV_ERROR( CV_StsBadArg,
+            "the missing data mask must be 8-bit vector of the same size as input sample" );
+    }
+
+    weak_count = cvSliceLength( slice, weak );
+    if( weak_count >= weak->total )
+    {
+        weak_count = weak->total;
+        slice.start_index = 0;
+    }
+
+    if( weak_responses )
+    {
+        if( !CV_IS_MAT(weak_responses) ||
+            CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
+            weak_responses->cols != 1 && weak_responses->rows != 1 ||
+            weak_responses->cols + weak_responses->rows - 1 != weak_count )
+            CV_ERROR( CV_StsBadArg,
+            "The output matrix of weak classifier responses must be valid "
+            "floating-point vector of the same number of components as the length of input slice" );
+        wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
+    }
+
+    var_count = data->var_count;
+    vtype = data->var_type->data.i;
+    cmap = data->cat_map->data.i;
+    cofs = data->cat_ofs->data.i;
+
+    // if need, preprocess the input vector
+    if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
+    {
+        int bufsize;
+        int step, mstep = 0;
+        const float* src_sample;
+        const uchar* src_mask = 0;
+        float* dst_sample;
+        uchar* dst_mask;
+        const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
+        bool have_mask = _missing != 0;
+
+        bufsize = var_count*(sizeof(float) + sizeof(uchar));
+        if( bufsize <= CV_MAX_LOCAL_SIZE )
+            buf = (float*)cvStackAlloc( bufsize );
+        else
+        {
+            CV_CALL( buf = (float*)cvAlloc( bufsize ));
+            allocated = true;
+        }
+        dst_sample = buf;
+        dst_mask = (uchar*)(buf + var_count);
+
+        src_sample = _sample->data.fl;
+        step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
+
+        if( _missing )
+        {
+            src_mask = _missing->data.ptr;
+            mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
+        }
+
+        for( i = 0; i < var_count; i++ )
+        {
+            int idx = vidx ? vidx[i] : i;
+            float val = src_sample[idx*step];
+            int ci = vtype[i];
+            uchar m = src_mask ? src_mask[i] : (uchar)0;
+
+            if( ci >= 0 )
+            {
+                int a = cofs[ci], b = cofs[ci+1], c = a;
+                int ival = cvRound(val);
+                if( ival != val )
+                    CV_ERROR( CV_StsBadArg,
+                    "one of input categorical variable is not an integer" );
+
+                while( a < b )
+                {
+                    c = (a + b) >> 1;
+                    if( ival < cmap[c] )
+                        b = c;
+                    else if( ival > cmap[c] )
+                        a = c+1;
+                    else
+                        break;
+                }
+
+                if( c < 0 || ival != cmap[c] )
+                {
+                    m = 1;
+                    have_mask = true;
+                }
+                else
+                {
+                    val = (float)(c - cofs[ci]);
+                }
+            }
+
+            dst_sample[i] = val;
+            dst_mask[i] = m;
+        }
+
+        sample = cvMat( 1, var_count, CV_32F, dst_sample );
+        _sample = &sample;
+
+        if( have_mask )
+        {
+            missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
+            _missing = &missing;
+        }
+    }
+
+    cvStartReadSeq( weak, &reader );
+    cvSetSeqReaderPos( &reader, slice.start_index );
+
+    for( i = 0; i < weak_count; i++ )
+    {
+        CvBoostTree* wtree;
+        double val;
+
+        CV_READ_SEQ_ELEM( wtree, reader );
+
+        val = wtree->predict( _sample, _missing, true )->value;
+        if( weak_responses )
+            weak_responses->data.fl[i*wstep] = (float)val;
+
+        sum += val;
+    }
+
+    cls_idx = sum >= 0;
+    if( raw_mode )
+        value = (float)cls_idx;
+    else
+        value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
+
+    __END__;
+
+    if( allocated )
+        cvFree( &buf );
+
+    return value;
+}
+
+
+
+void CvBoost::write_params( CvFileStorage* fs )
+{
+    CV_FUNCNAME( "CvBoost::write_params" );
+
+    __BEGIN__;
+
+    const char* boost_type_str =
+        params.boost_type == DISCRETE ? "DiscreteAdaboost" :
+        params.boost_type == REAL ? "RealAdaboost" :
+        params.boost_type == LOGIT ? "LogitBoost" :
+        params.boost_type == GENTLE ? "GentleAdaboost" : 0;
+
+    const char* split_crit_str =
+        params.split_criteria == DEFAULT ? "Default" :
+        params.split_criteria == GINI ? "Gini" :
+        params.boost_type == MISCLASS ? "Misclassification" :
+        params.boost_type == SQERR ? "SquaredErr" : 0;
+
+    if( boost_type_str )
+        cvWriteString( fs, "boosting_type", boost_type_str );
+    else
+        cvWriteInt( fs, "boosting_type", params.boost_type );
+
+    if( split_crit_str )
+        cvWriteString( fs, "splitting_criteria", split_crit_str );
+    else
+        cvWriteInt( fs, "splitting_criteria", params.split_criteria );
+
+    cvWriteInt( fs, "ntrees", params.weak_count );
+    cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
+
+    data->write_params( fs );
+
+    __END__;
+}
+
+
+void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
+{
+    CV_FUNCNAME( "CvBoost::read_params" );
+
+    __BEGIN__;
+
+    CvFileNode* temp;
+
+    if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
+        return;
+
+    data = new CvDTreeTrainData();
+    CV_CALL( data->read_params(fs, fnode));
+    data->shared = true;
+
+    params.max_depth = data->params.max_depth;
+    params.min_sample_count = data->params.min_sample_count;
+    params.max_categories = data->params.max_categories;
+    params.priors = data->params.priors;
+    params.regression_accuracy = data->params.regression_accuracy;
+    params.use_surrogates = data->params.use_surrogates;
+
+    temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
+    if( !temp )
+        return;
+
+    if( temp && CV_NODE_IS_STRING(temp->tag) )
+    {
+        const char* boost_type_str = cvReadString( temp, "" );
+        params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
+                            strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
+                            strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
+                            strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
+    }
+    else
+        params.boost_type = cvReadInt( temp, -1 );
+
+    if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
+        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
+
+    temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
+    if( temp && CV_NODE_IS_STRING(temp->tag) )
+    {
+        const char* split_crit_str = cvReadString( temp, "" );
+        params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
+                                strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
+                                strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
+                                strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
+    }
+    else
+        params.split_criteria = cvReadInt( temp, -1 );
+
+    if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
+        CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
+
+    params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
+    params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
+
+    __END__;
+}
+
+
+
+void
+CvBoost::read( CvFileStorage* fs, CvFileNode* node )
+{
+    CV_FUNCNAME( "CvRTrees::read" );
+
+    __BEGIN__;
+
+    CvSeqReader reader;
+    CvFileNode* trees_fnode;
+    CvMemStorage* storage;
+    int i, ntrees;
+
+    clear();
+    read_params( fs, node );
+
+    if( !data )
+        EXIT;
+        
+    trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
+    if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
+        CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
+
+    cvStartReadSeq( trees_fnode->data.seq, &reader );
+    ntrees = trees_fnode->data.seq->total;
+
+    if( ntrees != params.weak_count )
+        CV_ERROR( CV_StsUnmatchedSizes,
+        "The number of trees stored does not match <ntrees> tag value" );
+
+    CV_CALL( storage = cvCreateMemStorage() );
+    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
+
+    for( i = 0; i < ntrees; i++ )
+    {
+        CvBoostTree* tree = new CvBoostTree();
+        CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
+        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
+        cvSeqPush( weak, &tree );
+    }
+
+    __END__;
+}
+
+
+void
+CvBoost::write( CvFileStorage* fs, const char* name )
+{
+    CV_FUNCNAME( "CvBoost::write" );
+
+    __BEGIN__;
+    
+    CvSeqReader reader;
+    int i;
+
+    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
+
+    if( !weak )
+        CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
+        
+    write_params( fs );
+    cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
+
+    cvStartReadSeq( weak, &reader );
+
+    for( i = 0; i < weak->total; i++ )
+    {
+        CvBoostTree* tree;
+        CV_READ_SEQ_ELEM( tree, reader );
+        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
+        tree->write( fs );
+        cvEndWriteStruct( fs );
+    }
+
+    cvEndWriteStruct( fs );
+    cvEndWriteStruct( fs );
+
+    __END__;
+}
+
+
+// Returns the matrix of the per-sample boosting weights
+// (0 until update_weights() has allocated it). The matrix is owned by the ensemble.
+CvMat*
+CvBoost::get_weights()
+{
+    return weights;
+}
+
+
+// Returns the working buffer that the weak trees fill with the weights of the
+// samples of the node being processed (see CvBoostTree::calc_node_value).
+// The matrix is owned by the ensemble.
+CvMat*
+CvBoost::get_subtree_weights()
+{
+    return subtree_weights;
+}
+
+
+// Returns the matrix with the responses of the most recent weak classifier
+// (the weak_eval member). The matrix is owned by the ensemble.
+CvMat*
+CvBoost::get_weak_response()
+{
+    return weak_eval;
+}
+
+
+// Returns the boosting parameters currently stored in the ensemble.
+const CvBoostParams&
+CvBoost::get_params() const
+{
+    return params;
+}
+
+/* End of file. */
index b8a4cae424f22204ea29493eba587dba8ad8ec6b..d37d50314616349a510b7d0dd14ca523ffef724f 100644 (file)
@@ -141,55 +141,13 @@ CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
 }
 
 
-void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData** pdata )
+void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
 {
-    CV_FUNCNAME( "CvForestTree::read" );
-
-    __BEGIN__;
-
+    CvDTree::read( fs, fnode, _data );
     forest = _forest;
-
-    if( pdata && *pdata )
-        data = *pdata;
-    else
-    {
-        CV_CALL(read_train_data_params( fs, fnode ));
-        data->shared = true;
-        *pdata = data;
-    }
-
-    CV_CALL(pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -2 ));
-    if( pruned_tree_idx < -1 )
-        CV_ERROR( CV_StsParseError, "<best_tree_idx> is absent" );
-
-    CV_CALL(read_tree_nodes( fs, cvGetFileNodeByName( fs, fnode, "nodes" )));
-
-    __END__;
 }
 
 
-void CvForestTree::write( CvFileStorage* fs, const char* name, bool write_td_tp )
-{
-    CV_FUNCNAME( "CvForestTree::write" );
-
-    __BEGIN__;
-
-    cvStartWriteStruct( fs, name, CV_NODE_MAP );
-
-    if( write_td_tp )
-        CV_CALL(write_train_data_params( fs ));
-
-    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
-
-    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
-    CV_CALL(write_tree_nodes( fs ));
-    cvEndWriteStruct( fs ); // nodes
-
-    cvEndWriteStruct( fs ); // name
-
-    __END__;
-}
-
 //////////////////////////////////////////////////////////////////////////////////////////
 //                                  Random trees                                        //
 //////////////////////////////////////////////////////////////////////////////////////////
@@ -200,6 +158,7 @@ CvRTrees::CvRTrees()
     oob_error        = 0;
     ntrees           = 0;
     trees            = NULL;
+    data             = NULL;
     active_var_mask  = NULL;
     var_importance   = NULL;
     proximities      = NULL;
@@ -211,15 +170,12 @@ CvRTrees::CvRTrees()
 void CvRTrees::clear()
 {
     int k;
-    for( k = 1; k < ntrees; k++ )
+    for( k = 0; k < ntrees; k++ )
         delete trees[k];
+    cvFree( &trees );
 
-    if( trees && *trees )
-    {
-        trees[0]->share_data( false );
-        delete trees[0];
-        cvFree( &trees );
-    }
+    delete data;
+    data = 0;
 
     cvReleaseMat( &active_var_mask );
     cvReleaseMat( &var_importance );
@@ -250,7 +206,6 @@ bool CvRTrees::train( const CvMat* _train_data, int _tflag,
                         const CvMat* _missing_mask, CvRTParams params )
 {
     bool result = false;
-    CvDTreeTrainData* train_data = 0;
 
     CV_FUNCNAME("CvRTrees::train");
     __BEGIN__;
@@ -263,11 +218,11 @@ bool CvRTrees::train( const CvMat* _train_data, int _tflag,
         params.regression_accuracy, params.use_surrogates, params.max_categories,
         params.cv_folds, params.use_1se_rule, false, params.priors );
     
-    train_data = new CvDTreeTrainData();
-    CV_CALL(train_data->set_data( _train_data, _tflag, _responses, _var_idx,
+    data = new CvDTreeTrainData();
+    CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
         _sample_idx, _var_type, _missing_mask, tree_params, true));
 
-    var_count = train_data->var_count;
+    var_count = data->var_count;
     if( params.nactive_vars > var_count )
         params.nactive_vars = var_count;
     else if( params.nactive_vars == 0 )
@@ -285,7 +240,7 @@ bool CvRTrees::train( const CvMat* _train_data, int _tflag,
     }
     if( params.calc_proximities )
     {
-        const int n = train_data->sample_count;
+        const int n = data->sample_count;
         CV_CALL(proximities = cvCreateMat( 1, n*(n-1)/2, CV_32FC1) );
         cvZero( proximities );
     }
@@ -297,7 +252,7 @@ bool CvRTrees::train( const CvMat* _train_data, int _tflag,
         cvZero( &submask2 );
     }
 
-    CV_CALL(result = grow_forest(train_data, params.term_crit ));
+    CV_CALL(result = grow_forest( params.term_crit ));
 
     result = true;
 
@@ -307,7 +262,7 @@ bool CvRTrees::train( const CvMat* _train_data, int _tflag,
 }
 
 
-bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria term_crit )
+bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
 {
     bool result = false;
 
@@ -331,7 +286,7 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
     const int max_ntrees = term_crit.max_iter;
     const double max_oob_err = term_crit.epsilon;
     
-    const int dims = train_data->var_count;
+    const int dims = data->var_count;
     float maximal_response = 0;
 
     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
@@ -341,13 +296,13 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
 
-    nsamples = train_data->sample_count;
-    nclasses = train_data->get_num_classes();
+    nsamples = data->sample_count;
+    nclasses = data->get_num_classes();
 
     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
     memset( trees, 0, sizeof(trees[0])*max_ntrees );
 
-    if( train_data->is_classifier )
+    if( data->is_classifier )
     {
         CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
         cvZero(oob_sample_votes);
@@ -377,7 +332,7 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
         memset( predicted_nodes_ptr, 0, sizeof(CvDTreeNode*)*nsamples );
     }
 
-    CV_CALL(train_data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
+    CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
 
     {
         double minval, maxval;
@@ -390,8 +345,7 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
     while( ntrees < max_ntrees )
     {
         int i, j, oob_samples_count = 0;
-        float ncorrect_responses = 0; // used for estimation of variable importance
-        float* true_resp = 0;
+        double ncorrect_responses = 0; // used for estimation of variable importance
         CvMat sample, missing;
         CvDTreeNode** predicted_nodes = 0;
         CvForestTree* tree = 0;
@@ -406,17 +360,16 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
 
         trees[ntrees] = new CvForestTree();
         tree = trees[ntrees];
-        CV_CALL(tree->train( train_data, sample_idx_for_tree, this ));
+        CV_CALL(tree->train( data, sample_idx_for_tree, this ));
 
         // form array of OOB samples indices and get these samples
         sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
         missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
-        true_resp = true_resp_ptr;
         predicted_nodes = predicted_nodes_ptr;
 
         oob_error = 0;
         for( i = 0; i < nsamples; i++,
-            sample.data.fl += dims, missing.data.ptr += dims, true_resp++ )
+            sample.data.fl += dims, missing.data.ptr += dims )
         {
             CvDTreeNode* predicted_node = 0;
             if( proximities )
@@ -432,21 +385,22 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
             if( !predicted_node )
                 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
 
-            if( !train_data->is_classifier ) //regression
+            if( !data->is_classifier ) //regression
             {
-                float avg_resp, resp = (float)predicted_node->value;
-                oob_predictions_sum.data.fl[i] += resp;
+                double avg_resp, resp = predicted_node->value;
+                oob_predictions_sum.data.fl[i] += (float)resp;
                 oob_num_of_predictions.data.fl[i] += 1;
 
                 // compute oob error
-                avg_resp=oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
-                oob_error += powf(avg_resp - *true_resp, 2);
-
-                ncorrect_responses += expf( -powf((resp - *true_resp)/maximal_response, 2) );
+                avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
+                avg_resp -= true_resp_ptr[i];
+                oob_error += avg_resp*avg_resp;
+                resp = (resp - true_resp_ptr[i])/maximal_response;
+                ncorrect_responses += exp( -resp*resp );
             }
             else //classification
             {
-                float prdct_resp;
+                double prdct_resp;
                 CvPoint max_loc;
                 CvMat votes;
 
@@ -456,10 +410,10 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
                 // compute oob error
                 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
 
-                prdct_resp = (float)train_data->cat_map->data.i[max_loc.x];
-                oob_error += (fabs(prdct_resp - *true_resp) < FLT_EPSILON) ? 0 : 1;
+                prdct_resp = data->cat_map->data.i[max_loc.x];
+                oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
 
-                ncorrect_responses += ((int)predicted_node->value == (int)*true_resp);
+                ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
             }
             oob_samples_count++;
         }
@@ -474,7 +428,7 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
             memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
             for( m = 0; m < dims; m++ )
             {
-                float ncorrect_responses_permuted = 0;
+                double ncorrect_responses_permuted = 0;
                 // randomly permute values of the m-th variable in the oob samples
                 float* mth_var_ptr = oob_samples_perm_ptr + m;
 
@@ -502,18 +456,22 @@ bool CvRTrees::grow_forest( CvDTreeTrainData* train_data, const CvTermCriteria t
                 for( i = 0; i < nsamples; i++,
                     sample.data.fl += dims, missing.data.ptr += dims )
                 {
-                    float predct_resp, true_resp;
+                    double predct_resp, true_resp;
 
                     if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
                         continue;
 
-                    predct_resp = (float)tree->predict(&sample, &missing, true)->value;
+                    predct_resp = tree->predict(&sample, &missing, true)->value;
                     true_resp   = true_resp_ptr[i];
-                    ncorrect_responses_permuted += train_data->is_classifier ?
-                        (int)true_resp == (int)predct_resp 
-                        : expf( -powf((true_resp - predct_resp)/maximal_response, 2));
+                    if( data->is_classifier )
+                        ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
+                    else
+                    {
+                        true_resp = (true_resp - predct_resp)/maximal_response;
+                        ncorrect_responses_permuted += exp( -true_resp*true_resp );
+                    }
                 }
-                var_importance->data.fl[m] += (ncorrect_responses
+                var_importance->data.fl[m] += (float)(ncorrect_responses
                     - ncorrect_responses_permuted);
             }
         }
@@ -655,14 +613,18 @@ void CvRTrees::write( CvFileStorage* fs, const char* name )
 
     cvWriteInt( fs, "ntrees", ntrees );
 
+    CV_CALL(data->write_params( fs ));
+
     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
 
-    CV_CALL(trees[0]->write( fs, 0, true ));
-    for( k = 1; k < ntrees; k++ )
-       CV_CALL(trees[k]->write( fs, 0, false ));
+    for( k = 0; k < ntrees; k++ )
+    {
+        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
+        CV_CALL( trees[k]->write( fs ));
+        cvEndWriteStruct( fs );
+    }
 
     cvEndWriteStruct( fs ); //trees
-
     cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
 
     __END__;
@@ -678,7 +640,6 @@ void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
     int nactive_vars, var_count, k;
     CvSeqReader reader;
     CvFileNode* trees_fnode = 0;
-    CvDTreeTrainData* train_data = 0;
 
     clear();
 
@@ -700,6 +661,10 @@ void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
     memset( trees, 0, sizeof(trees[0])*ntrees );
 
+    data = new CvDTreeTrainData();
+    data->read_params( fs, fnode );
+    data->shared = true;
+
     trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
         CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
@@ -708,16 +673,18 @@ void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
     if( reader.seq->total != ntrees )
         CV_ERROR( CV_StsParseError,
         "<ntrees> is not equal to the number of trees saved in file" );
+
     for( k = 0; k < ntrees; k++ )
     {
         trees[k] = new CvForestTree();
-        CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, &train_data ));
+        CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data ));
         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
     }
 
-    var_count = trees[0]->get_var_count();
+    var_count = data->var_count;
     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
-    { // initialize active variables mask
+    {
+        // initialize active variables mask
         CvMat submask1, submask2;
         cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
         cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
index f623b45ec7a0906abcab9564feae2ddfddba4687..5c66717b4672b2e88ea0ab2458b4661ff00be72f 100644 (file)
@@ -57,15 +57,15 @@ CvDTreeTrainData::CvDTreeTrainData()
 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                       const CvMat* _responses, const CvMat* _var_idx,
                       const CvMat* _sample_idx, const CvMat* _var_type,
-                      const CvMat* _missing_mask,
-                      CvDTreeParams _params, bool _shared, bool _add_weights )
+                      const CvMat* _missing_mask, const CvDTreeParams& _params,
+                      bool _shared, bool _add_labels )
 {
     var_idx = var_type = cat_count = cat_ofs = cat_map =
         priors = counts = buf = direction = split_buf = 0;
     tree_storage = temp_storage = 0;
     
     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
-              _var_type, _missing_mask, _params, _shared, _add_weights );
+              _var_type, _missing_mask, _params, _shared, _add_labels ); // all arguments are forwarded as is; the members above are nulled first, presumably because set_data() starts with clear()
 }
 
 
@@ -124,13 +124,14 @@ static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
 
 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
-    const CvMat* _var_type, const CvMat* _missing_mask, CvDTreeParams _params,
-    bool _shared, bool _add_weights )
+    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
+    bool _shared, bool _add_labels, bool _update_data )
 {
     CvMat* sample_idx = 0;
     CvMat* var_type0 = 0;
     CvMat* tmp_map = 0;
     int** int_ptr = 0;
+    CvDTreeTrainData* data = 0;
 
     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
 
@@ -144,6 +145,40 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
     char err[100];
     const int *sidx = 0, *vidx = 0;
 
+    if( _update_data && data_root )
+    {
+        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
+            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
+        
+        // compare new and old train data
+        if( !(data->var_count == var_count &&
+            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
+            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
+            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
+            CV_ERROR( CV_StsBadArg,
+            "The new training data must have the same types and the input and output variables "
+            "and the same categories for categorical variables" );
+
+        cvReleaseMat( &priors );
+        cvReleaseMat( &buf );
+        cvReleaseMat( &direction );
+        cvReleaseMat( &split_buf );
+        cvReleaseMemStorage( &temp_storage );
+
+        priors = data->priors; data->priors = 0;
+        buf = data->buf; data->buf = 0;
+        buf_count = data->buf_count; buf_size = data->buf_size;
+        sample_count = data->sample_count;
+
+        direction = data->direction; data->direction = 0;
+        split_buf = data->split_buf; data->split_buf = 0;
+        temp_storage = data->temp_storage; data->temp_storage = 0;
+        nv_heap = data->nv_heap; cv_heap = data->cv_heap;
+
+        data_root = new_node( 0, sample_count, 0, 0 );
+        EXIT;
+    }
+
     clear();
 
     var_all = 0;
@@ -219,11 +254,9 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
 
     // in case of single ordered predictor we need dummy cv_labels
     // for safe split_node_data() operation
-    have_cv_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0;
-    have_weights = _add_weights;
+    have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels;
 
-    buf_size = (ord_var_count*2 + cat_var_count + 1 +
-        (have_cv_labels ? 1 : 0) + (have_weights ? 1 : 0))*sample_count + 2;
+    buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
     shared = _shared;
     buf_count = shared ? 3 : 2;
     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
@@ -447,7 +480,7 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
 
     if( cv_n )
     {
-        int* dst = get_cv_labels(data_root);
+        int* dst = get_labels(data_root);
         CvRNG* r = &rng;
 
         for( i = vi = 0; i < sample_count; i++ )
@@ -473,7 +506,7 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
     have_priors = is_classifier && params.priors;
     if( is_classifier )
     {
-        int m = get_num_classes(), rows = 4;
+        int m = get_num_classes();
         double sum = 0;
         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
         for( i = 0; i < m; i++ )
@@ -488,16 +521,7 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
         if( have_priors )
             cvScale( priors, priors, 1./sum );
 
-        if( cat_var_count > 0 || params.cv_folds > 0 )
-        {
-            // need storage for cjk (see find_split_cat_gini) and risks/errors
-            rows += MAX( max_c_count, params.cv_folds ) + 1;
-            // add buffer for k-means clustering
-            if( m > 2 && max_c_count > params.max_categories )
-                rows += params.max_categories + (max_c_count+m-1)/m;
-        }
-
-        CV_CALL( counts = cvCreateMat( rows, m, CV_32SC2 ));
+        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
     }
 
     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
@@ -505,6 +529,9 @@ void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
 
     __END__;
 
+    if( data )
+        delete data;
+
     cvFree( &int_ptr );
     cvReleaseMat( &sample_idx );
     cvReleaseMat( &var_type0 );
@@ -553,6 +580,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
         int* co, cur_ofs = 0;
         int vi, i, total = data_root->sample_count;
         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
+        int work_var_count = get_work_var_count();
         root = new_node( 0, count, 1, 0 );
 
         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
@@ -571,7 +599,7 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
                 co[i*2+1] = -1;
         }
 
-        for( vi = 0; vi <= var_count + (have_cv_labels ? 1 : 0); vi++ )
+        for( vi = 0; vi < work_var_count; vi++ )
         {
             int ci = get_var_type(vi);
 
@@ -678,7 +706,8 @@ void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
         }
     }
 
-    memset( missing, 1, count*var_count );
+    if( missing )
+        memset( missing, 1, count*var_count );
 
     for( vi = 0; vi < var_count; vi++ )
     {
@@ -686,21 +715,25 @@ void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
         if( ci >= 0 ) // categorical
         {
             float* dst = values + vi;
-            uchar* m = missing + vi;
+            uchar* m = missing ? missing + vi : 0;
             const int* src = get_cat_var_data(data_root, vi);
 
-            for( i = 0; i < count; i++, dst += var_count, m += var_count )
+            for( i = 0; i < count; i++, dst += var_count )
             {
                 int idx = sidx ? sidx[i] : i;
                 int val = src[idx];
                 *dst = (float)val;
-                *m = val < 0;
+                if( m )
+                {
+                    *m = val < 0;
+                    m += var_count;
+                }
             }
         }
         else // ordered
         {
             float* dst = values + vi;
-            uchar* m = missing + vi;
+            uchar* m = missing ? missing + vi : 0;
             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
             int count1 = data_root->get_num_valid(vi);
 
@@ -721,7 +754,8 @@ void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
                     {
                         dst[cur_ofs] = val;
-                        m[cur_ofs] = 0;
+                        if( m )
+                            m[cur_ofs] = 0;
                     }
                 }
             }
@@ -729,24 +763,27 @@ void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
     }
 
     // copy responses
-    if( is_classifier )
+    if( responses )
     {
-        const int* src = get_class_labels(data_root);
-        for( i = 0; i < count; i++ )
+        if( is_classifier )
         {
-            int idx = sidx ? sidx[i] : i;
-            int val = get_class_idx ? src[idx] :
-                cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
-            responses[i] = (float)val;
+            const int* src = get_class_labels(data_root);
+            for( i = 0; i < count; i++ )
+            {
+                int idx = sidx ? sidx[i] : i;
+                int val = get_class_idx ? src[idx] :
+                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
+                responses[i] = (float)val;
+            }
         }
-    }
-    else
-    {
-        const float* src = get_ord_responses(data_root);
-        for( i = 0; i < count; i++ )
+        else
         {
-            int idx = sidx ? sidx[i] : i;
-            responses[i] = src[idx];
+            const float* src = get_ord_responses(data_root);
+            for( i = 0; i < count; i++ )
+            {
+                int idx = sidx ? sidx[i] : i;
+                responses[i] = src[idx];
+            }
         }
     }
 
@@ -884,7 +921,7 @@ void CvDTreeTrainData::clear()
     node_heap = split_heap = 0;
 
     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
-    have_cv_labels = have_priors = is_classifier = false;
+    have_labels = have_priors = is_classifier = false;
 
     buf_count = buf_size = 0;
     shared = false;
@@ -907,6 +944,11 @@ int CvDTreeTrainData::get_var_type(int vi) const
 }
 
 
+int CvDTreeTrainData::get_work_var_count() const
+{   // var_count input variables plus one more column, and an additional one when labels are kept
+    return var_count + (have_labels ? 2 : 1);
+}
+
 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
 {
     int oi = ~get_var_type(vi);
@@ -928,9 +970,9 @@ float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
 }
 
 
-int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n )
+int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
 {
-    return params.cv_folds > 0 ? get_cat_var_data( n, var_count + 1 ) : 0;
+    return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0; // labels are fetched as variable #(var_count+1); 0 is returned when have_labels is false
 }
 
 
@@ -943,13 +985,6 @@ int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
 }
 
 
-float* CvDTreeTrainData::get_weights( CvDTreeNode* n )
-{
-    return have_weights ?
-        (float*)get_cat_var_data( n, var_count + 1 + (params.cv_folds > 0) ) : 0;
-}
-
-
 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
 {
     int idx = n->buf_idx + 1;
@@ -959,6 +994,214 @@ int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
 }
 
 
+void CvDTreeTrainData::write_params( CvFileStorage* fs )
+{
+    CV_FUNCNAME( "CvDTreeTrainData::write_params" );
+
+    __BEGIN__;
+
+    int vi, vcount = var_count;
+
+    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
+    cvWriteInt( fs, "var_all", var_all );
+    cvWriteInt( fs, "var_count", var_count );
+    cvWriteInt( fs, "ord_var_count", ord_var_count );
+    cvWriteInt( fs, "cat_var_count", cat_var_count );
+
+    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
+    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
+
+    if( is_classifier )
+    {
+        cvWriteInt( fs, "max_categories", params.max_categories );
+    }
+    else
+    {
+        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
+    }
+
+    cvWriteInt( fs, "max_depth", params.max_depth );
+    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
+    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
+    
+    if( params.cv_folds > 1 )
+    {
+        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
+        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
+    }
+
+    if( priors )
+        cvWrite( fs, "priors", priors );
+
+    cvEndWriteStruct( fs );
+
+    if( var_idx )
+        cvWrite( fs, "var_idx", var_idx );
+    
+    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
+
+    for( vi = 0; vi < vcount; vi++ )
+        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
+
+    cvEndWriteStruct( fs );
+
+    if( cat_count && (cat_var_count > 0 || is_classifier) )
+    {
+        CV_ASSERT( cat_count != 0 );
+        cvWrite( fs, "cat_count", cat_count );
+        cvWrite( fs, "cat_map", cat_map );
+    }
+
+    __END__;
+}
+
+
+void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
+{
+    CV_FUNCNAME( "CvDTreeTrainData::read_params" );
+
+    __BEGIN__;
+    
+    CvFileNode *tparams_node, *vartype_node;
+    CvSeqReader reader;
+    int vi, max_split_size, tree_block_size;
+
+    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
+    var_all = cvReadIntByName( fs, node, "var_all" );
+    var_count = cvReadIntByName( fs, node, "var_count", var_all );
+    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
+    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
+
+    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
+
+    if( tparams_node ) // training parameters are not necessary
+    {
+        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
+
+        if( is_classifier )
+        {
+            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
+        }
+        else
+        {
+            params.regression_accuracy =
+                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
+        }
+
+        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
+        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
+        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
+    
+        if( params.cv_folds > 1 )
+        {
+            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
+            params.truncate_pruned_tree =
+                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
+        }
+
+        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
+        if( priors && !CV_IS_MAT(priors) )
+            CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
+    }
+
+    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
+    if( var_idx )
+    {
+        if( !CV_IS_MAT(var_idx) ||
+            var_idx->cols != 1 && var_idx->rows != 1 ||
+            var_idx->cols + var_idx->rows - 1 != var_count ||
+            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
+            CV_ERROR( CV_StsParseError,
+                "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
+
+        for( vi = 0; vi < var_count; vi++ )
+            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
+                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
+    }
+    
+    ////// read var type
+    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
+
+    cat_var_count = 0;
+    ord_var_count = -1;
+    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
+
+    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
+        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
+    else
+    {
+        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
+            vartype_node->data.seq->total != var_count )
+            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
+
+        cvStartReadSeq( vartype_node->data.seq, &reader );
+    
+        for( vi = 0; vi < var_count; vi++ )
+        {
+            CvFileNode* n = (CvFileNode*)reader.ptr;
+            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
+                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
+            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
+            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
+        }
+    }
+    var_type->data.i[var_count] = cat_var_count;
+
+    ord_var_count = ~ord_var_count;
+    if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
+        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
+    //////
+
+    if( cat_var_count > 0 || is_classifier )
+    {
+        int ccount, total_c_count = 0;
+        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
+        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
+
+        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
+            cat_count->cols != 1 && cat_count->rows != 1 ||
+            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
+            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
+            cat_map->cols != 1 && cat_map->rows != 1 ||
+            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
+            CV_ERROR( CV_StsParseError,
+            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
+
+        ccount = cat_var_count + is_classifier;
+
+        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
+        cat_ofs->data.i[0] = 0;
+        max_c_count = 1;
+
+        for( vi = 0; vi < ccount; vi++ )
+        {
+            int val = cat_count->data.i[vi];
+            if( val <= 0 )
+                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
+            max_c_count = MAX( max_c_count, val );
+            cat_ofs->data.i[vi+1] = total_c_count += val;
+        }
+
+        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
+            CV_ERROR( CV_StsBadSize,
+            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
+    }
+
+    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
+        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
+
+    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
+    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
+    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
+    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
+            sizeof(CvDTreeNode), tree_storage ));
+    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
+            max_split_size, tree_storage ));
+
+    __END__;
+}
+
+
 /////////////////////// Decision Tree /////////////////////////
 
 CvDTree::CvDTree()
@@ -1079,9 +1322,6 @@ bool CvDTree::do_train( const CvMat* _subsample_idx )
 }
 
 
-#define DTREE_CAT_DIR(idx,subset) \
-    (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
-
 void CvDTree::try_split_node( CvDTreeNode* node )
 {
     CvDTreeSplit* best_split = 0;
@@ -1196,7 +1436,7 @@ double CvDTree::calc_node_dir( CvDTreeNode* node )
             for( i = 0; i < n; i++ )
             {
                 int idx = labels[i];
-                int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
+                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                 sum += d; sum_abs += d & 1;
                 dir[i] = (char)d;
             }
@@ -1214,7 +1454,7 @@ double CvDTree::calc_node_dir( CvDTreeNode* node )
             {
                 int idx = labels[i];
                 double w = priors[responses[i]];
-                int d = idx >= 0 ? DTREE_CAT_DIR(idx,subset) : 0;
+                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                 sum += d*w; sum_abs += (d & 1)*w;
                 dir[i] = (char)d;
             }
@@ -1323,8 +1563,8 @@ CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
     int n1 = node->get_num_valid(vi);
     int m = data->get_num_classes();
     const int* rc0 = data->counts->data.i;
-    int* lc = (int*)(rc0 + m);
-    int* rc = lc + m;
+    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
+    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
     int i, best_i = -1;
     double lsum2 = 0, rsum2 = 0, best_val = 0;
     const double* priors = data->have_priors ? data->priors->data.db : 0;
@@ -1359,7 +1599,7 @@ CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
 
             if( sorted[i].val + epsilon < sorted[i+1].val )
             {
-                double val = lsum2/L + rsum2/R;
+                double val = (lsum2*R + rsum2*L)/((double)L*R);
                 if( best_val < val )
                 {
                     best_val = val;
@@ -1391,7 +1631,7 @@ CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
 
             if( sorted[i].val + epsilon < sorted[i+1].val )
             {
-                double val = lsum2/L + rsum2/R;
+                double val = (lsum2*R + rsum2*L)/((double)L*R);
                 if( best_val < val )
                 {
                     best_val = val;
@@ -1513,10 +1753,9 @@ CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
     int n = node->sample_count;
     int m = data->get_num_classes();
     int _mi = data->cat_count->data.i[ci], mi = _mi;
-    const int* rc0 = data->counts->data.i;
-    int* lc = (int*)(rc0 + m);
-    int* rc = lc + m;
-    int* _cjk = rc + m*2, *cjk = _cjk;
+    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
+    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
+    int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
     int* cluster_labels = 0;
     int** int_ptr = 0;
@@ -1545,7 +1784,7 @@ CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
         {
             mi = MIN(data->params.max_categories, n);
             cjk += _mi*m;
-            cluster_labels = cjk + mi*m;
+            cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
         }
         subset_i = 1;
@@ -1640,7 +1879,7 @@ CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
 
         if( L > FLT_EPSILON && R > FLT_EPSILON )
         {
-            double val = lsum2/L + rsum2/R;
+            double val = (lsum2*R + rsum2*L)/((double)L*R);
             if( best_val < val )
             {
                 best_val = val;
@@ -1701,7 +1940,7 @@ CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
 
         if( sorted[i].val + epsilon < sorted[i+1].val )
         {
-            double val = lsum*lsum/L + rsum*rsum/R;
+            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
             if( best_val < val )
             {
                 best_val = val;
@@ -1755,7 +1994,7 @@ CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
 
     icvSortDblPtr( sum_ptr, mi, 0 );
 
-    // revert back to unnormalized sum
+    // revert back to unnormalized sums
     // (there should be a very little loss of accuracy)
     for( i = 0; i < mi; i++ )
         sum[i] *= counts[i];
@@ -1773,7 +2012,7 @@ CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
             
             if( L && R )
             {
-                double val = lsum*lsum/L + rsum*rsum/R;
+                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
                 if( best_val < val )
                 {
                     best_val = val;
@@ -1905,7 +2144,7 @@ CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
         }
     }
 
-    return best_i >= 0 ? data->new_split_ord( vi,
+    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
         best_inversed, (float)best_val ) : 0;
 }
@@ -1933,7 +2172,7 @@ CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
     // sent to the left (lc) and to the right (rc) by the primary split
     if( !data->have_priors )
     {
-        int* _lc = data->counts->data.i + 1;
+        int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
         int* _rc = _lc + mi + 1;
 
         for( i = -1; i < mi; i++ )
@@ -2005,7 +2244,7 @@ CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
 void CvDTree::calc_node_value( CvDTreeNode* node )
 {
     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
-    const int* cv_labels = data->get_cv_labels(node);
+    const int* cv_labels = data->get_labels(node);
 
     if( data->is_classifier )
     {
@@ -2021,7 +2260,7 @@ void CvDTree::calc_node_value( CvDTreeNode* node )
         int* cls_count = data->counts->data.i;
         const int* responses = data->get_class_labels(node);
         int m = data->get_num_classes();
-        int* cv_cls_count = cls_count + m;
+        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
         double max_val = -1, total_weight = 0;
         int max_k = -1;
         double* priors = data->priors->data.db;
@@ -2194,7 +2433,7 @@ void CvDTree::complete_node_dir( CvDTreeNode* node )
                     int idx;
                     if( !dir[i] && (idx = labels[i]) >= 0 )
                     {
-                        int d = DTREE_CAT_DIR(idx,subset);
+                        int d = CV_DTREE_CAT_DIR(idx,subset);
                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
                         if( --nz )
                             break;
@@ -2256,6 +2495,13 @@ void CvDTree::split_node_data( CvDTreeNode* node )
     CvDTreeNode *left = 0, *right = 0;
     int* new_idx = data->split_buf->data.i;
     int new_buf_idx = data->get_child_buf_idx( node );
+    int work_var_count = data->get_work_var_count();
+
+    // speed things up a little, especially for tree ensembles with lots of small trees:
+    //   do not physically split the input data between the left and right child nodes
+    //   when we are not going to split them further,
+    //   as calc_node_value() does not require input features anyway.
+    bool split_input_data;
 
     complete_node_dir(node);
 
@@ -2270,7 +2516,11 @@ void CvDTree::split_node_data( CvDTreeNode* node )
 
     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
-        (data->ord_var_count*2 + data->cat_var_count+1+data->have_cv_labels)*nl );
+        (data->ord_var_count + work_var_count)*nl );
+
+    split_input_data = node->depth + 1 < data->params.max_depth &&
+        (node->left->sample_count > data->params.min_sample_count ||
+        node->right->sample_count > data->params.min_sample_count);
 
     // split ordered variables, keep both halves sorted.
     for( vi = 0; vi < data->var_count; vi++ )
@@ -2280,7 +2530,7 @@ void CvDTree::split_node_data( CvDTreeNode* node )
         CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
         CvPair32s32f tl, tr;
 
-        if( ci >= 0 )
+        if( ci >= 0 || !split_input_data )
             continue;
 
         src = data->get_ord_var_data(node, vi);
@@ -2320,14 +2570,14 @@ void CvDTree::split_node_data( CvDTreeNode* node )
     }
 
     // split categorical vars, responses and cv_labels using new_idx relocation table
-    for( vi = 0; vi <= data->var_count + data->have_cv_labels + data->have_weights; vi++ )
+    for( vi = 0; vi < work_var_count; vi++ )
     {
         int ci = data->get_var_type(vi);
         int n1 = node->get_num_valid(vi), nr1 = 0;
         int *src, *ldst0, *rdst0, *ldst, *rdst;
         int tl, tr;
 
-        if( ci < 0 )
+        if( ci < 0 || (vi < data->var_count && !split_input_data) )
             continue;
 
         src = data->get_cat_var_data(node, vi);
@@ -2634,7 +2884,7 @@ CvDTreeNode* CvDTree::predict( const CvMat* _sample,
         "number of elements as the total number of variables used for training" );
 
     sample = _sample->data.fl;
-    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(data[0]);
+    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
 
     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
     {
@@ -2707,7 +2957,7 @@ CvDTreeNode* CvDTree::predict( const CvMat* _sample,
                         catbuf[ci] = c -= cofs[ci];
                     }
                 }
-                dir = DTREE_CAT_DIR(c, split->subset);
+                dir = CV_DTREE_CAT_DIR(c, split->subset);
             }
 
             if( split->inversed )
@@ -2773,68 +3023,6 @@ const CvMat* CvDTree::get_var_importance()
 }
 
 
-void CvDTree::write_train_data_params( CvFileStorage* fs )
-{
-    CV_FUNCNAME( "CvDTree::write_train_data_params" );
-
-    __BEGIN__;
-
-    int vi, vcount = data->var_count;
-
-    cvWriteInt( fs, "is_classifier", data->is_classifier ? 1 : 0 );
-    cvWriteInt( fs, "var_all", data->var_all );
-    cvWriteInt( fs, "var_count", data->var_count );
-    cvWriteInt( fs, "ord_var_count", data->ord_var_count );
-    cvWriteInt( fs, "cat_var_count", data->cat_var_count );
-
-    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
-    cvWriteInt( fs, "use_surrogates", data->params.use_surrogates ? 1 : 0 );
-
-    if( data->is_classifier )
-    {
-        cvWriteInt( fs, "max_categories", data->params.max_categories );
-    }
-    else
-    {
-        cvWriteReal( fs, "regression_accuracy", data->params.regression_accuracy );
-    }
-
-    cvWriteInt( fs, "max_depth", data->params.max_depth );
-    cvWriteInt( fs, "min_sample_count", data->params.min_sample_count );
-    cvWriteInt( fs, "cross_validation_folds", data->params.cv_folds );
-    
-    if( data->params.cv_folds > 1 )
-    {
-        cvWriteInt( fs, "use_1se_rule", data->params.use_1se_rule ? 1 : 0 );
-        cvWriteInt( fs, "truncate_pruned_tree", data->params.truncate_pruned_tree ? 1 : 0 );
-    }
-
-    if( data->priors )
-        cvWrite( fs, "priors", data->priors );
-
-    cvEndWriteStruct( fs );
-
-    if( data->var_idx )
-        cvWrite( fs, "var_idx", data->var_idx );
-    
-    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
-
-    for( vi = 0; vi < vcount; vi++ )
-        cvWriteInt( fs, 0, data->var_type->data.i[vi] >= 0 );
-
-    cvEndWriteStruct( fs );
-
-    if( data->cat_count && (data->cat_var_count > 0 || data->is_classifier) )
-    {
-        CV_ASSERT( data->cat_count != 0 );
-        cvWrite( fs, "cat_count", data->cat_count );
-        cvWrite( fs, "cat_map", data->cat_map );
-    }
-
-    __END__;
-}
-
-
 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
 {
     int ci;
@@ -2848,7 +3036,7 @@ void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
     {
         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
         for( i = 0; i < n; i++ )
-            to_right += DTREE_CAT_DIR(i,split->subset) > 0;
+            to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
 
         // ad-hoc rule when to use inverse categorical split notation
         // to achieve more compact and clear representation
@@ -2858,7 +3046,7 @@ void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
         for( i = 0; i < n; i++ )
         {
-            int dir = DTREE_CAT_DIR(i,split->subset);
+            int dir = CV_DTREE_CAT_DIR(i,split->subset);
             if( dir*default_dir < 0 )
                 cvWriteInt( fs, 0, i );
         }
@@ -2947,15 +3135,10 @@ void CvDTree::write( CvFileStorage* fs, const char* name )
 
     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
 
-    write_train_data_params( fs );
-
-    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
-    get_var_importance();
-    cvWrite( fs, "var_importance", var_importance );
-
-    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
-    write_tree_nodes( fs );
-    cvEndWriteStruct( fs );
+    data->write_params( fs );
+    if( var_importance )
+        cvWrite( fs, "var_importance", var_importance );
+    write( fs );
 
     cvEndWriteStruct( fs );
 
@@ -2963,152 +3146,18 @@ void CvDTree::write( CvFileStorage* fs, const char* name )
 }
 
 
-void CvDTree::read_train_data_params( CvFileStorage* fs, CvFileNode* node )
+void CvDTree::write( CvFileStorage* fs )
 {
-    CV_FUNCNAME( "CvDTree::read_train_data_params" );
+    //CV_FUNCNAME( "CvDTree::write" );
 
     __BEGIN__;
-    
-    CvDTreeParams params;
-    CvFileNode *tparams_node, *vartype_node;
-    CvSeqReader reader;
-    int is_classifier, vi, cat_var_count, ord_var_count;
-    int max_split_size, tree_block_size;
-
-    data = new CvDTreeTrainData;
-
-    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
-    data->is_classifier = (is_classifier != 0);
-    data->var_all = cvReadIntByName( fs, node, "var_all" );
-    data->var_count = cvReadIntByName( fs, node, "var_count", data->var_all );
-    data->cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
-    data->ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
-
-    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
-
-    if( tparams_node ) // training parameters are not necessary
-    {
-        data->params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
-
-        if( is_classifier )
-        {
-            data->params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
-        }
-        else
-        {
-            data->params.regression_accuracy =
-                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
-        }
 
-        data->params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
-        data->params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
-        data->params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
-    
-        if( data->params.cv_folds > 1 )
-        {
-            data->params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
-            data->params.truncate_pruned_tree =
-                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
-        }
-
-        data->priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
-        if( data->priors && !CV_IS_MAT(data->priors) )
-            CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
-    }
-
-    CV_CALL( data->var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
-    if( data->var_idx )
-    {
-        if( !CV_IS_MAT(data->var_idx) ||
-            data->var_idx->cols != 1 && data->var_idx->rows != 1 ||
-            data->var_idx->cols + data->var_idx->rows - 1 != data->var_count ||
-            CV_MAT_TYPE(data->var_idx->type) != CV_32SC1 )
-            CV_ERROR( CV_StsParseError,
-                "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
-
-        for( vi = 0; vi < data->var_count; vi++ )
-            if( (unsigned)data->var_idx->data.i[vi] >= (unsigned)data->var_all )
-                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
-    }
-    
-    ////// read var type
-    CV_CALL( data->var_type = cvCreateMat( 1, data->var_count + 2, CV_32SC1 ));
-
-    cat_var_count = 0;
-    ord_var_count = -1;
-    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
-
-    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && data->var_count == 1 )
-        data->var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
-    else
-    {
-        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
-            vartype_node->data.seq->total != data->var_count )
-            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
-
-        cvStartReadSeq( vartype_node->data.seq, &reader );
-    
-        for( vi = 0; vi < data->var_count; vi++ )
-        {
-            CvFileNode* n = (CvFileNode*)reader.ptr;
-            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
-                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
-            data->var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
-            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
-        }
-    }
-    
-    ord_var_count = ~ord_var_count;
-    if( cat_var_count != data->cat_var_count || ord_var_count != data->ord_var_count )
-        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
-    //////
-
-    if( data->cat_var_count > 0 || is_classifier )
-    {
-        int ccount, max_c_count = 0, total_c_count = 0;
-        CV_CALL( data->cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
-        CV_CALL( data->cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
-
-        if( !CV_IS_MAT(data->cat_count) || !CV_IS_MAT(data->cat_map) ||
-            data->cat_count->cols != 1 && data->cat_count->rows != 1 ||
-            CV_MAT_TYPE(data->cat_count->type) != CV_32SC1 ||
-            data->cat_count->cols + data->cat_count->rows - 1 != cat_var_count + is_classifier ||
-            data->cat_map->cols != 1 && data->cat_map->rows != 1 ||
-            CV_MAT_TYPE(data->cat_map->type) != CV_32SC1 )
-            CV_ERROR( CV_StsParseError,
-            "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
-
-        ccount = cat_var_count + is_classifier;
-
-        CV_CALL( data->cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
-        data->cat_ofs->data.i[0] = 0;
-
-        for( vi = 0; vi < ccount; vi++ )
-        {
-            int val = data->cat_count->data.i[vi];
-            if( val <= 0 )
-                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
-            max_c_count = MAX( max_c_count, val );
-            data->cat_ofs->data.i[vi+1] = total_c_count += val;
-        }
-
-        if( data->cat_map->cols + data->cat_map->rows - 1 != total_c_count )
-            CV_ERROR( CV_StsBadSize,
-            "cat_map vector length is not equal to the total number of categories in all categorical vars" );
-        
-        data->max_c_count = max_c_count;
-    }
-
-    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
-        (MAX(0,data->max_c_count - 33)/32)*sizeof(int),sizeof(void*));
+    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
+    get_var_importance();
 
-    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
-    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
-    CV_CALL( data->tree_storage = cvCreateMemStorage( tree_block_size ));
-    CV_CALL( data->node_heap = cvCreateSet( 0, sizeof(data->node_heap[0]),
-            sizeof(CvDTreeNode), data->tree_storage ));
-    CV_CALL( data->split_heap = cvCreateSet( 0, sizeof(data->split_heap[0]),
-            max_split_size, data->tree_storage ));
+    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
+    write_tree_nodes( fs );
+    cvEndWriteStruct( fs );
 
     __END__;
 }
@@ -3303,6 +3352,16 @@ void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
 
 
 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
+{
+    CvDTreeTrainData* _data = new CvDTreeTrainData();
+    _data->read_params( fs, fnode );
+
+    read( fs, fnode, _data );
+}
+
+
+// a special entry point for reading weak decision trees from tree ensembles
+void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
 {
     CV_FUNCNAME( "CvDTree::read" );
 
@@ -3311,16 +3370,14 @@ void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
     CvFileNode* tree_nodes;
 
     clear();
-    read_train_data_params( fs, fnode );
+    data = _data;
 
-    tree_nodes = cvGetFileNodeByName( fs, fnode, "nodes" );
+    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
 
-    pruned_tree_idx = cvReadIntByName( fs, fnode, "best_tree_idx", -1 );
-
+    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
     read_tree_nodes( fs, tree_nodes );
-    get_var_importance(); // recompute variable importance
 
     __END__;
 }