]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/include/opencv/ml.h
OpenMP parallelization of find_best_split (mll)
[opencv.git] / opencv / include / opencv / ml.h
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #ifndef __ML_H__
42 #define __ML_H__
43
44 // disable deprecation warning which appears in VisualStudio 8.0
45 #if _MSC_VER >= 1400
46 #pragma warning( disable : 4996 )
47 #endif
48
49 #ifndef SKIP_INCLUDES
50
51   #include "cxcore.h"
52   #include <limits.h>
53
54   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
55     #include <windows.h>
56   #endif
57
58 #else // SKIP_INCLUDES
59
60   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
61     #define CV_CDECL __cdecl
62     #define CV_STDCALL __stdcall
63   #else
64     #define CV_CDECL
65     #define CV_STDCALL
66   #endif
67
68   #ifndef CV_EXTERN_C
69     #ifdef __cplusplus
70       #define CV_EXTERN_C extern "C"
71       #define CV_DEFAULT(val) = val
72     #else
73       #define CV_EXTERN_C
74       #define CV_DEFAULT(val)
75     #endif
76   #endif
77
78   #ifndef CV_EXTERN_C_FUNCPTR
79     #ifdef __cplusplus
80       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
81     #else
82       #define CV_EXTERN_C_FUNCPTR(x) typedef x
83     #endif
84   #endif
85
86   #ifndef CV_INLINE
87     #if defined __cplusplus
88       #define CV_INLINE inline
89     #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__
90       #define CV_INLINE __inline
91     #else
92       #define CV_INLINE static
93     #endif
94   #endif /* CV_INLINE */
95
96   #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS
97     #define CV_EXPORTS __declspec(dllexport)
98   #else
99     #define CV_EXPORTS
100   #endif
101
102   #ifndef CVAPI
103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
104   #endif
105
106 #endif // SKIP_INCLUDES
107
108
109 #ifdef __cplusplus
110
111 // Apple defines a check() macro somewhere in the debug headers
112 // that interferes with a method definiton in this header
113
114 using namespace std;
115 #undef check
116
117 /****************************************************************************************\
118 *                               Main struct definitions                                  *
119 \****************************************************************************************/
120
121 /* log(2*PI) */
122 #define CV_LOG2PI (1.8378770664093454835606594728112)
123
124 /* columns of <trainData> matrix are training samples */
125 #define CV_COL_SAMPLE 0
126
127 /* rows of <trainData> matrix are training samples */
128 #define CV_ROW_SAMPLE 1
129
130 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
131
132 struct CvVectors
133 {
134     int type;
135     int dims, count;
136     CvVectors* next;
137     union
138     {
139         uchar** ptr;
140         float** fl;
141         double** db;
142     } data;
143 };
144
145 #if 0
146 /* A structure, representing the lattice range of statmodel parameters.
147    It is used for optimizing statmodel parameters by cross-validation method.
148    The lattice is logarithmic, so <step> must be greater then 1. */
149 typedef struct CvParamLattice
150 {
151     double min_val;
152     double max_val;
153     double step;
154 }
155 CvParamLattice;
156
157 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
158                                          double log_step )
159 {
160     CvParamLattice pl;
161     pl.min_val = MIN( min_val, max_val );
162     pl.max_val = MAX( min_val, max_val );
163     pl.step = MAX( log_step, 1. );
164     return pl;
165 }
166
167 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
168 {
169     CvParamLattice pl = {0,0,0};
170     return pl;
171 }
172 #endif
173
174 /* Variable type */
175 #define CV_VAR_NUMERICAL    0
176 #define CV_VAR_ORDERED      0
177 #define CV_VAR_CATEGORICAL  1
178
179 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
180 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
181 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
182 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
183 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
184 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
185 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
186 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
187 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
188
189 #define CV_TRAIN_ERROR  0
190 #define CV_TEST_ERROR   1
191
192 class CV_EXPORTS CvStatModel
193 {
194 public:
195     CvStatModel();
196     virtual ~CvStatModel();
197
198     virtual void clear();
199
200     virtual void save( const char* filename, const char* name=0 );
201     virtual void load( const char* filename, const char* name=0 );
202
203     virtual void write( CvFileStorage* storage, const char* name );
204     virtual void read( CvFileStorage* storage, CvFileNode* node );
205
206 protected:
207     const char* default_model_name;
208 };
209
210
211 /****************************************************************************************\
212 *                                 Normal Bayes Classifier                                *
213 \****************************************************************************************/
214
215 /* The structure, representing the grid range of statmodel parameters.
216    It is used for optimizing statmodel accuracy by varying model parameters,
217    the accuracy estimate being computed by cross-validation.
218    The grid is logarithmic, so <step> must be greater then 1. */
219 struct CV_EXPORTS CvParamGrid
220 {
221     // SVM params type
222     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
223
224     CvParamGrid()
225     {
226         min_val = max_val = step = 0;
227     }
228
229     CvParamGrid( double _min_val, double _max_val, double log_step )
230     {
231         min_val = _min_val;
232         max_val = _max_val;
233         step = log_step;
234     }
235     //CvParamGrid( int param_id );
236     bool check() const;
237
238     double min_val;
239     double max_val;
240     double step;
241 };
242
243 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
244 {
245 public:
246     CvNormalBayesClassifier();
247     virtual ~CvNormalBayesClassifier();
248
249     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
250         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
251
252     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
253         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
254
255     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
256     virtual void clear();
257
258     virtual void write( CvFileStorage* storage, const char* name );
259     virtual void read( CvFileStorage* storage, CvFileNode* node );
260
261 protected:
262     int     var_count, var_all;
263     CvMat*  var_idx;
264     CvMat*  cls_labels;
265     CvMat** count;
266     CvMat** sum;
267     CvMat** productsum;
268     CvMat** avg;
269     CvMat** inv_eigen_values;
270     CvMat** cov_rotate_mats;
271     CvMat*  c;
272 };
273
274
275 /****************************************************************************************\
276 *                          K-Nearest Neighbour Classifier                                *
277 \****************************************************************************************/
278
279 // k Nearest Neighbors
280 class CV_EXPORTS CvKNearest : public CvStatModel
281 {
282 public:
283
284     CvKNearest();
285     virtual ~CvKNearest();
286
287     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
288                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
289
290     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
291                         const CvMat* _sample_idx=0, bool is_regression=false,
292                         int _max_k=32, bool _update_base=false );
293
294     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
295         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
296
297     virtual void clear();
298     int get_max_k() const;
299     int get_var_count() const;
300     int get_sample_count() const;
301     bool is_regression() const;
302
303 protected:
304
305     virtual float write_results( int k, int k1, int start, int end,
306         const float* neighbor_responses, const float* dist, CvMat* _results,
307         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
308
309     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
310         float* neighbor_responses, const float** neighbors, float* dist ) const;
311
312
313     int max_k, var_count;
314     int total;
315     bool regression;
316     CvVectors* samples;
317 };
318
319 /****************************************************************************************\
320 *                                   Support Vector Machines                              *
321 \****************************************************************************************/
322
323 // SVM training parameters
324 struct CV_EXPORTS CvSVMParams
325 {
326     CvSVMParams();
327     CvSVMParams( int _svm_type, int _kernel_type,
328                  double _degree, double _gamma, double _coef0,
329                  double _C, double _nu, double _p,
330                  CvMat* _class_weights, CvTermCriteria _term_crit );
331
332     int         svm_type;
333     int         kernel_type;
334     double      degree; // for poly
335     double      gamma;  // for poly/rbf/sigmoid
336     double      coef0;  // for poly/sigmoid
337
338     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
339     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
340     double      p; // for CV_SVM_EPS_SVR
341     CvMat*      class_weights; // for CV_SVM_C_SVC
342     CvTermCriteria term_crit; // termination criteria
343 };
344
345
346 struct CV_EXPORTS CvSVMKernel
347 {
348     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
349                                        const float* another, float* results );
350     CvSVMKernel();
351     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
352     virtual bool create( const CvSVMParams* _params, Calc _calc_func );
353     virtual ~CvSVMKernel();
354
355     virtual void clear();
356     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
357
358     const CvSVMParams* params;
359     Calc calc_func;
360
361     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
362                                     const float* another, float* results,
363                                     double alpha, double beta );
364
365     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
366                               const float* another, float* results );
367     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
368                            const float* another, float* results );
369     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
370                             const float* another, float* results );
371     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
372                                const float* another, float* results );
373 };
374
375
376 struct CvSVMKernelRow
377 {
378     CvSVMKernelRow* prev;
379     CvSVMKernelRow* next;
380     float* data;
381 };
382
383
384 struct CvSVMSolutionInfo
385 {
386     double obj;
387     double rho;
388     double upper_bound_p;
389     double upper_bound_n;
390     double r;   // for Solver_NU
391 };
392
393 class CV_EXPORTS CvSVMSolver
394 {
395 public:
396     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
397     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
398     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
399
400     CvSVMSolver();
401
402     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
403                  int alpha_count, double* alpha, double Cp, double Cn,
404                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
405                  SelectWorkingSet select_working_set, CalcRho calc_rho );
406     virtual bool create( int count, int var_count, const float** samples, schar* y,
407                  int alpha_count, double* alpha, double Cp, double Cn,
408                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
409                  SelectWorkingSet select_working_set, CalcRho calc_rho );
410     virtual ~CvSVMSolver();
411
412     virtual void clear();
413     virtual bool solve_generic( CvSVMSolutionInfo& si );
414
415     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
416                               double Cp, double Cn, CvMemStorage* storage,
417                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
418     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
419                                CvMemStorage* storage, CvSVMKernel* kernel,
420                                double* alpha, CvSVMSolutionInfo& si );
421     virtual bool solve_one_class( int count, int var_count, const float** samples,
422                                   CvMemStorage* storage, CvSVMKernel* kernel,
423                                   double* alpha, CvSVMSolutionInfo& si );
424
425     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
426                                 CvMemStorage* storage, CvSVMKernel* kernel,
427                                 double* alpha, CvSVMSolutionInfo& si );
428
429     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
430                                CvMemStorage* storage, CvSVMKernel* kernel,
431                                double* alpha, CvSVMSolutionInfo& si );
432
433     virtual float* get_row_base( int i, bool* _existed );
434     virtual float* get_row( int i, float* dst );
435
436     int sample_count;
437     int var_count;
438     int cache_size;
439     int cache_line_size;
440     const float** samples;
441     const CvSVMParams* params;
442     CvMemStorage* storage;
443     CvSVMKernelRow lru_list;
444     CvSVMKernelRow* rows;
445
446     int alpha_count;
447
448     double* G;
449     double* alpha;
450
451     // -1 - lower bound, 0 - free, 1 - upper bound
452     schar* alpha_status;
453
454     schar* y;
455     double* b;
456     float* buf[2];
457     double eps;
458     int max_iter;
459     double C[2];  // C[0] == Cn, C[1] == Cp
460     CvSVMKernel* kernel;
461
462     SelectWorkingSet select_working_set_func;
463     CalcRho calc_rho_func;
464     GetRow get_row_func;
465
466     virtual bool select_working_set( int& i, int& j );
467     virtual bool select_working_set_nu_svm( int& i, int& j );
468     virtual void calc_rho( double& rho, double& r );
469     virtual void calc_rho_nu_svm( double& rho, double& r );
470
471     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
472     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
473     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
474 };
475
476
477 struct CvSVMDecisionFunc
478 {
479     double rho;
480     int sv_count;
481     double* alpha;
482     int* sv_index;
483 };
484
485
486 // SVM model
487 class CV_EXPORTS CvSVM : public CvStatModel
488 {
489 public:
490     // SVM type
491     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
492
493     // SVM kernel type
494     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
495
496     // SVM params type
497     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
498
499     CvSVM();
500     virtual ~CvSVM();
501
502     CvSVM( const CvMat* _train_data, const CvMat* _responses,
503            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
504            CvSVMParams _params=CvSVMParams() );
505
506     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
507                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
508                         CvSVMParams _params=CvSVMParams() );
509     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
510         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
511         int k_fold = 10,
512         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
513         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
514         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
515         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
516         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
517         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
518
519     virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const;
520
521     virtual int get_support_vector_count() const;
522     virtual const float* get_support_vector(int i) const;
523     virtual CvSVMParams get_params() const { return params; };
524     virtual void clear();
525
526     static CvParamGrid get_default_grid( int param_id );
527
528     virtual void write( CvFileStorage* storage, const char* name );
529     virtual void read( CvFileStorage* storage, CvFileNode* node );
530     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
531
532 protected:
533
534     virtual bool set_params( const CvSVMParams& _params );
535     virtual bool train1( int sample_count, int var_count, const float** samples,
536                     const void* _responses, double Cp, double Cn,
537                     CvMemStorage* _storage, double* alpha, double& rho );
538     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
539                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
540     virtual void create_kernel();
541     virtual void create_solver();
542
543     virtual void write_params( CvFileStorage* fs );
544     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
545
546     CvSVMParams params;
547     CvMat* class_labels;
548     int var_all;
549     float** sv;
550     int sv_total;
551     CvMat* var_idx;
552     CvMat* class_weights;
553     CvSVMDecisionFunc* decision_func;
554     CvMemStorage* storage;
555
556     CvSVMSolver* solver;
557     CvSVMKernel* kernel;
558 };
559
560 /****************************************************************************************\
561 *                              Expectation - Maximization                                *
562 \****************************************************************************************/
563
564 struct CV_EXPORTS CvEMParams
565 {
566     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
567         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
568     {
569         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
570     }
571
572     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
573                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
574                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
575                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
576                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
577                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
578     {}
579
580     int nclusters;
581     int cov_mat_type;
582     int start_step;
583     const CvMat* probs;
584     const CvMat* weights;
585     const CvMat* means;
586     const CvMat** covs;
587     CvTermCriteria term_crit;
588 };
589
590
591 class CV_EXPORTS CvEM : public CvStatModel
592 {
593 public:
594     // Type of covariation matrices
595     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
596
597     // The initial step
598     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
599
600     CvEM();
601     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
602           CvEMParams params=CvEMParams(), CvMat* labels=0 );
603     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
604
605     virtual ~CvEM();
606
607     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
608                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
609
610     virtual float predict( const CvMat* sample, CvMat* probs ) const;
611     virtual void clear();
612
613     int           get_nclusters() const;
614     const CvMat*  get_means()     const;
615     const CvMat** get_covs()      const;
616     const CvMat*  get_weights()   const;
617     const CvMat*  get_probs()     const;
618
619     inline double         get_log_likelihood     () const { return log_likelihood;     };
620     
621 //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
622 //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
623 //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
624
625 protected:
626
627     virtual void set_params( const CvEMParams& params,
628                              const CvVectors& train_data );
629     virtual void init_em( const CvVectors& train_data );
630     virtual double run_em( const CvVectors& train_data );
631     virtual void init_auto( const CvVectors& samples );
632     virtual void kmeans( const CvVectors& train_data, int nclusters,
633                          CvMat* labels, CvTermCriteria criteria,
634                          const CvMat* means );
635     CvEMParams params;
636     double log_likelihood;
637
638     CvMat* means;
639     CvMat** covs;
640     CvMat* weights;
641     CvMat* probs;
642
643     CvMat* log_weight_div_det;
644     CvMat* inv_eigen_values;
645     CvMat** cov_rotate_mats;
646 };
647
648 /****************************************************************************************\
649 *                                      Decision Tree                                     *
650 \****************************************************************************************/\
651 class CvMLData;
652
653 struct CvPair16u32s
654 {
655     unsigned short* u;
656     int* i;
657 };
658
659
660 #define CV_DTREE_CAT_DIR(idx,subset) \
661     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
662
663 struct CvDTreeSplit
664 {
665     int var_idx;
666     int condensed_idx;
667     int inversed;
668     float quality;
669     CvDTreeSplit* next;
670     union
671     {
672         int subset[2];
673         struct
674         {
675             float c;
676             int split_point;
677         }
678         ord;
679     };
680 };
681
682
683 struct CvDTreeNode
684 {
685     int class_idx;
686     int Tn;
687     double value;
688
689     CvDTreeNode* parent;
690     CvDTreeNode* left;
691     CvDTreeNode* right;
692
693     CvDTreeSplit* split;
694
695     int sample_count;
696     int depth;
697     int* num_valid;
698     int offset;
699     int buf_idx;
700     double maxlr;
701
702     // global pruning data
703     int complexity;
704     double alpha;
705     double node_risk, tree_risk, tree_error;
706
707     // cross-validation pruning data
708     int* cv_Tn;
709     double* cv_node_risk;
710     double* cv_node_error;
711
712     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
713     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
714 };
715
716
717 struct CV_EXPORTS CvDTreeParams
718 {
719     int   max_categories;
720     int   max_depth;
721     int   min_sample_count;
722     int   cv_folds;
723     bool  use_surrogates;
724     bool  use_1se_rule;
725     bool  truncate_pruned_tree;
726     float regression_accuracy;
727     const float* priors;
728
729     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
730         cv_folds(10), use_surrogates(true), use_1se_rule(true),
731         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
732     {}
733
734     CvDTreeParams( int _max_depth, int _min_sample_count,
735                    float _regression_accuracy, bool _use_surrogates,
736                    int _max_categories, int _cv_folds,
737                    bool _use_1se_rule, bool _truncate_pruned_tree,
738                    const float* _priors ) :
739         max_categories(_max_categories), max_depth(_max_depth),
740         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
741         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
742         truncate_pruned_tree(_truncate_pruned_tree),
743         regression_accuracy(_regression_accuracy),
744         priors(_priors)
745     {}
746 };
747
748
749 struct CV_EXPORTS CvDTreeTrainData
750 {
751     CvDTreeTrainData();
752     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
753                       const CvMat* _responses, const CvMat* _var_idx=0,
754                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
755                       const CvMat* _missing_mask=0,
756                       const CvDTreeParams& _params=CvDTreeParams(),
757                       bool _shared=false, bool _add_labels=false );
758     virtual ~CvDTreeTrainData();
759
760     virtual void set_data( const CvMat* _train_data, int _tflag,
761                           const CvMat* _responses, const CvMat* _var_idx=0,
762                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
763                           const CvMat* _missing_mask=0,
764                           const CvDTreeParams& _params=CvDTreeParams(),
765                           bool _shared=false, bool _add_labels=false,
766                           bool _update_data=false );
767     virtual void do_responses_copy();
768
769     virtual void get_vectors( const CvMat* _subsample_idx,
770          float* values, uchar* missing, float* responses, bool get_class_idx=false );
771
772     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
773
774     virtual void write_params( CvFileStorage* fs );
775     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
776
777     // release all the data
778     virtual void clear();
779
780     int get_num_classes() const;
781     int get_var_type(int vi) const;
782     int get_work_var_count() const {return work_var_count;}
783
784     virtual void get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values );    
785     virtual void get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
786     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
787     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** labels );
788     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
789     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf,
790         const float** ord_values, const int** indices );
791     virtual int get_child_buf_idx( CvDTreeNode* n );
792
793     ////////////////////////////////////
794
795     virtual bool set_params( const CvDTreeParams& params );
796     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
797                                    int storage_idx, int offset );
798
799     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
800                 int split_point, int inversed, float quality );
801     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
802     virtual void free_node_data( CvDTreeNode* node );
803     virtual void free_train_data();
804     virtual void free_node( CvDTreeNode* node );
805
806     // inner arrays for getting predictors and responses
807     float* get_pred_float_buf();
808     int* get_pred_int_buf();
809     float* get_resp_float_buf();
810     int* get_resp_int_buf();
811     int* get_cv_lables_buf();
812     int* get_sample_idx_buf();
813
814     vector<vector<float> > pred_float_buf;
815     vector<vector<int> > pred_int_buf;
816     vector<vector<float> > resp_float_buf;
817     vector<vector<int> > resp_int_buf;
818     vector<vector<int> > cv_lables_buf;
819     vector<vector<int> > sample_idx_buf;
820
821     int sample_count, var_all, var_count, max_c_count;
822     int ord_var_count, cat_var_count, work_var_count;
823     bool have_labels, have_priors;
824     bool is_classifier;
825     int tflag;
826
827     const CvMat* train_data;
828     const CvMat* responses;
829     CvMat* responses_copy; // used in Boosting
830
831     int buf_count, buf_size;
832     bool shared;
833     int is_buf_16u;
834     
835     CvMat* cat_count;
836     CvMat* cat_ofs;
837     CvMat* cat_map;
838
839     CvMat* counts;
840     CvMat* buf;
841     CvMat* direction;
842     CvMat* split_buf;
843
844     CvMat* var_idx;
845     CvMat* var_type; // i-th element =
846                      //   k<0  - ordered
847                      //   k>=0 - categorical, see k-th element of cat_* arrays
848     CvMat* priors;
849     CvMat* priors_mult;
850
851     CvDTreeParams params;
852
853     CvMemStorage* tree_storage;
854     CvMemStorage* temp_storage;
855
856     CvDTreeNode* data_root;
857
858     CvSet* node_heap;
859     CvSet* split_heap;
860     CvSet* cv_heap;
861     CvSet* nv_heap;
862
863     CvRNG rng;
864 };
865
866
867 class CV_EXPORTS CvDTree : public CvStatModel
868 {
869 public:
870     CvDTree();
871     virtual ~CvDTree();
872
873     virtual bool train( const CvMat* _train_data, int _tflag,
874                         const CvMat* _responses, const CvMat* _var_idx=0,
875                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
876                         const CvMat* _missing_mask=0,
877                         CvDTreeParams params=CvDTreeParams() );
878
879     virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() );
880
881     virtual float calc_error( CvMLData* _data, int type = CV_TEST_ERROR ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
882
883     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
884
885     virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
886                                   bool preprocessed_input=false ) const;
887     virtual const CvMat* get_var_importance();
888     virtual void clear();
889
890     virtual void read( CvFileStorage* fs, CvFileNode* node );
891     virtual void write( CvFileStorage* fs, const char* name );
892
893     // special read & write methods for trees in the tree ensembles
894     virtual void read( CvFileStorage* fs, CvFileNode* node,
895                        CvDTreeTrainData* data );
896     virtual void write( CvFileStorage* fs );
897
898     const CvDTreeNode* get_root() const;
899     int get_pruned_tree_idx() const;
900     CvDTreeTrainData* get_data();
901
902 protected:
903
904     virtual bool do_train( const CvMat* _subsample_idx );
905
906     virtual void try_split_node( CvDTreeNode* n );
907     virtual void split_node_data( CvDTreeNode* n );
908     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
909     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
910     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
911     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
912     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
913     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
914     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
915     virtual double calc_node_dir( CvDTreeNode* node );
916     virtual void complete_node_dir( CvDTreeNode* node );
917     virtual void cluster_categories( const int* vectors, int vector_count,
918         int var_count, int* sums, int k, int* cluster_labels );
919
920     virtual void calc_node_value( CvDTreeNode* node );
921
922     virtual void prune_cv();
923     virtual double update_tree_rnc( int T, int fold );
924     virtual int cut_tree( int T, int fold, double min_alpha );
925     virtual void free_prune_data(bool cut_tree);
926     virtual void free_tree();
927
928     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
929     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
930     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
931     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
932     virtual void write_tree_nodes( CvFileStorage* fs );
933     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
934
935     CvDTreeNode* root;
936
937     int pruned_tree_idx;
938     CvMat* var_importance;
939
940     CvDTreeTrainData* data;
941 };
942
943
944 /****************************************************************************************\
945 *                                   Random Trees Classifier                              *
946 \****************************************************************************************/
947
948 class CvRTrees;
949
950 class CV_EXPORTS CvForestTree: public CvDTree
951 {
952 public:
953     CvForestTree();
954     virtual ~CvForestTree();
955
956     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
957
958     virtual int get_var_count() const {return data ? data->var_count : 0;}
959     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
960
961     /* dummy methods to avoid warnings: BEGIN */
962     virtual bool train( const CvMat* _train_data, int _tflag,
963                         const CvMat* _responses, const CvMat* _var_idx=0,
964                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
965                         const CvMat* _missing_mask=0,
966                         CvDTreeParams params=CvDTreeParams() );
967
968     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
969     virtual void read( CvFileStorage* fs, CvFileNode* node );
970     virtual void read( CvFileStorage* fs, CvFileNode* node,
971                        CvDTreeTrainData* data );
972     /* dummy methods to avoid warnings: END */
973
974 protected:
975     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
976     CvRTrees* forest;
977 };
978
979
980 struct CV_EXPORTS CvRTParams : public CvDTreeParams
981 {
982     //Parameters for the forest
983     bool calc_var_importance; // true <=> RF processes variable importance
984     int nactive_vars;
985     CvTermCriteria term_crit;
986
987     CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
988         calc_var_importance(false), nactive_vars(0)
989     {
990         term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
991     }
992
993     CvRTParams( int _max_depth, int _min_sample_count,
994                 float _regression_accuracy, bool _use_surrogates,
995                 int _max_categories, const float* _priors, bool _calc_var_importance,
996                 int _nactive_vars, int max_num_of_trees_in_the_forest,
997                 float forest_accuracy, int termcrit_type ) :
998         CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
999                        _use_surrogates, _max_categories, 0,
1000                        false, false, _priors ),
1001         calc_var_importance(_calc_var_importance),
1002         nactive_vars(_nactive_vars)
1003     {
1004         term_crit = cvTermCriteria(termcrit_type,
1005             max_num_of_trees_in_the_forest, forest_accuracy);
1006     }
1007 };
1008
1009
1010 class CV_EXPORTS CvRTrees : public CvStatModel
1011 {
1012 public:
1013     CvRTrees();
1014     virtual ~CvRTrees();
1015     virtual bool train( const CvMat* _train_data, int _tflag,
1016                         const CvMat* _responses, const CvMat* _var_idx=0,
1017                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1018                         const CvMat* _missing_mask=0,
1019                         CvRTParams params=CvRTParams() );
1020     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1021     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
1022     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
1023     virtual void clear();
1024
1025     virtual const CvMat* get_var_importance();
1026     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
1027         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
1028     
1029     virtual float calc_error( CvMLData* _data, int type = CV_TEST_ERROR ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1030
1031     virtual float get_train_error();    
1032
1033     virtual void read( CvFileStorage* fs, CvFileNode* node );
1034     virtual void write( CvFileStorage* fs, const char* name );
1035
1036     CvMat* get_active_var_mask();
1037     CvRNG* get_rng();
1038
1039     int get_tree_count() const;
1040     CvForestTree* get_tree(int i) const;
1041
1042 protected:
1043
1044     virtual bool grow_forest( const CvTermCriteria term_crit );
1045
1046     // array of the trees of the forest
1047     CvForestTree** trees;
1048     CvDTreeTrainData* data;
1049     int ntrees;
1050     int nclasses;
1051     double oob_error;
1052     CvMat* var_importance;
1053     int nsamples;
1054
1055     CvRNG rng;
1056     CvMat* active_var_mask;
1057 };
1058
1059 /****************************************************************************************\
1060 *                           Extremely randomized trees Classifier                        *
1061 \****************************************************************************************/
1062 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
1063 {
1064     virtual void set_data( const CvMat* _train_data, int _tflag,
1065                           const CvMat* _responses, const CvMat* _var_idx=0,
1066                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1067                           const CvMat* _missing_mask=0,
1068                           const CvDTreeParams& _params=CvDTreeParams(),
1069                           bool _shared=false, bool _add_labels=false,
1070                           bool _update_data=false );
1071     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
1072         const float** ord_values, const int** missing );
1073     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices );
1074     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
1075     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
1076     virtual void get_vectors( const CvMat* _subsample_idx,
1077          float* values, uchar* missing, float* responses, bool get_class_idx=false );
1078     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
1079     const CvMat* missing_mask;
1080 };
1081
1082 class CV_EXPORTS CvForestERTree : public CvForestTree
1083 {
1084 protected:
1085     virtual double calc_node_dir( CvDTreeNode* node );
1086     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* node, int vi, CvDTreeSplit* _split = 0 );
1087     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* node, int vi, CvDTreeSplit* _split = 0 );
1088     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* node, int vi, CvDTreeSplit* _split = 0 );
1089     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* node, int vi, CvDTreeSplit* _split = 0 );
1090     //virtual void complete_node_dir( CvDTreeNode* node );
1091     virtual void split_node_data( CvDTreeNode* n );
1092 };
1093
1094 class CV_EXPORTS CvERTrees : public CvRTrees
1095 {
1096 public:
1097     CvERTrees();
1098     virtual ~CvERTrees();
1099     virtual bool train( const CvMat* _train_data, int _tflag,
1100                         const CvMat* _responses, const CvMat* _var_idx=0,
1101                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1102                         const CvMat* _missing_mask=0,
1103                         CvRTParams params=CvRTParams());
1104     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1105 protected:
1106     virtual bool grow_forest( const CvTermCriteria term_crit );
1107 };
1108
1109
1110 /****************************************************************************************\
1111 *                                   Boosted tree classifier                              *
1112 \****************************************************************************************/
1113
1114 struct CV_EXPORTS CvBoostParams : public CvDTreeParams
1115 {
1116     int boost_type;
1117     int weak_count;
1118     int split_criteria;
1119     double weight_trim_rate;
1120
1121     CvBoostParams();
1122     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1123                    int max_depth, bool use_surrogates, const float* priors );
1124 };
1125
1126
1127 class CvBoost;
1128
1129 class CV_EXPORTS CvBoostTree: public CvDTree
1130 {
1131 public:
1132     CvBoostTree();
1133     virtual ~CvBoostTree();
1134
1135     virtual bool train( CvDTreeTrainData* _train_data,
1136                         const CvMat* subsample_idx, CvBoost* ensemble );
1137
1138     virtual void scale( double s );
1139     virtual void read( CvFileStorage* fs, CvFileNode* node,
1140                        CvBoost* ensemble, CvDTreeTrainData* _data );
1141     virtual void clear();
1142
1143     /* dummy methods to avoid warnings: BEGIN */
1144     virtual bool train( const CvMat* _train_data, int _tflag,
1145                         const CvMat* _responses, const CvMat* _var_idx=0,
1146                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1147                         const CvMat* _missing_mask=0,
1148                         CvDTreeParams params=CvDTreeParams() );
1149     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1150
1151     virtual void read( CvFileStorage* fs, CvFileNode* node );
1152     virtual void read( CvFileStorage* fs, CvFileNode* node,
1153                        CvDTreeTrainData* data );
1154     /* dummy methods to avoid warnings: END */
1155
1156 protected:
1157
1158     virtual void try_split_node( CvDTreeNode* n );
1159     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
1160     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
1161     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
1162     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
1163     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
1164     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, CvDTreeSplit* _split = 0 );
1165     virtual void calc_node_value( CvDTreeNode* n );
1166     virtual double calc_node_dir( CvDTreeNode* n );
1167
1168     CvBoost* ensemble;
1169 };
1170
1171
1172 class CV_EXPORTS CvBoost : public CvStatModel
1173 {
1174 public:
1175     // Boosting type
1176     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1177
1178     // Splitting criteria
1179     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1180
1181     CvBoost();
1182     virtual ~CvBoost();
1183
1184     CvBoost( const CvMat* _train_data, int _tflag,
1185              const CvMat* _responses, const CvMat* _var_idx=0,
1186              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1187              const CvMat* _missing_mask=0,
1188              CvBoostParams params=CvBoostParams() );
1189
1190     virtual bool train( const CvMat* _train_data, int _tflag,
1191              const CvMat* _responses, const CvMat* _var_idx=0,
1192              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1193              const CvMat* _missing_mask=0,
1194              CvBoostParams params=CvBoostParams(),
1195              bool update=false );
1196
1197     virtual bool train( CvMLData* data,
1198              CvBoostParams params=CvBoostParams(),
1199              bool update=false );
1200
1201     virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
1202                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1203                            bool raw_mode=false, bool return_sum=false ) const;
1204
1205     virtual float calc_error( CvMLData* _data, int type = CV_TEST_ERROR ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1206
1207     virtual void prune( CvSlice slice );
1208
1209     virtual void clear();
1210
1211     virtual void write( CvFileStorage* storage, const char* name );
1212     virtual void read( CvFileStorage* storage, CvFileNode* node );
1213     virtual const CvMat* get_active_vars(bool absolute_idx=true);
1214
1215     CvSeq* get_weak_predictors();
1216
1217     CvMat* get_weights();
1218     CvMat* get_subtree_weights();
1219     CvMat* get_weak_response();
1220     const CvBoostParams& get_params() const;
1221     const CvDTreeTrainData* get_data() const;
1222
1223 protected:
1224
1225     virtual bool set_params( const CvBoostParams& _params );
1226     virtual void update_weights( CvBoostTree* tree );
1227     virtual void trim_weights();
1228     virtual void write_params( CvFileStorage* fs );
1229     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1230
1231     CvDTreeTrainData* data;
1232     CvBoostParams params;
1233     CvSeq* weak;
1234
1235     CvMat* active_vars;
1236     CvMat* active_vars_abs;
1237     bool have_active_cat_vars;
1238
1239     CvMat* orig_response;
1240     CvMat* sum_response;
1241     CvMat* weak_eval;
1242     CvMat* subsample_mask;
1243     CvMat* weights;
1244     CvMat* subtree_weights;
1245     bool have_subsample;
1246 };
1247
1248
1249 /****************************************************************************************\
1250 *                              Artificial Neural Networks (ANN)                          *
1251 \****************************************************************************************/
1252
1253 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1254
1255 struct CV_EXPORTS CvANN_MLP_TrainParams
1256 {
1257     CvANN_MLP_TrainParams();
1258     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1259                            double param1, double param2=0 );
1260     ~CvANN_MLP_TrainParams();
1261
1262     enum { BACKPROP=0, RPROP=1 };
1263
1264     CvTermCriteria term_crit;
1265     int train_method;
1266
1267     // backpropagation parameters
1268     double bp_dw_scale, bp_moment_scale;
1269
1270     // rprop parameters
1271     double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1272 };
1273
1274
1275 class CV_EXPORTS CvANN_MLP : public CvStatModel
1276 {
1277 public:
1278     CvANN_MLP();
1279     CvANN_MLP( const CvMat* _layer_sizes,
1280                int _activ_func=SIGMOID_SYM,
1281                double _f_param1=0, double _f_param2=0 );
1282
1283     virtual ~CvANN_MLP();
1284
1285     virtual void create( const CvMat* _layer_sizes,
1286                          int _activ_func=SIGMOID_SYM,
1287                          double _f_param1=0, double _f_param2=0 );
1288
1289     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
1290                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
1291                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1292                        int flags=0 );
1293     virtual float predict( const CvMat* _inputs,
1294                            CvMat* _outputs ) const;
1295
1296     virtual void clear();
1297
1298     // possible activation functions
1299     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1300
1301     // available training flags
1302     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1303
1304     virtual void read( CvFileStorage* fs, CvFileNode* node );
1305     virtual void write( CvFileStorage* storage, const char* name );
1306
1307     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1308     const CvMat* get_layer_sizes() { return layer_sizes; }
1309     double* get_weights(int layer)
1310     {
1311         return layer_sizes && weights &&
1312             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1313     }
1314
1315 protected:
1316
1317     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1318             const CvMat* _sample_weights, const CvMat* _sample_idx,
1319             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1320
1321     // sequential random backpropagation
1322     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1323
1324     // RPROP algorithm
1325     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1326
1327     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1328     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1329     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1330                                  double _f_param1=0, double _f_param2=0 );
1331     virtual void init_weights();
1332     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1333     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1334     virtual void calc_input_scale( const CvVectors* vecs, int flags );
1335     virtual void calc_output_scale( const CvVectors* vecs, int flags );
1336
1337     virtual void write_params( CvFileStorage* fs );
1338     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1339
1340     CvMat* layer_sizes;
1341     CvMat* wbuf;
1342     CvMat* sample_weights;
1343     double** weights;
1344     double f_param1, f_param2;
1345     double min_val, max_val, min_val1, max_val1;
1346     int activ_func;
1347     int max_count, max_buf_sz;
1348     CvANN_MLP_TrainParams params;
1349     CvRNG rng;
1350 };
1351
1352 #if 0
1353 /****************************************************************************************\
1354 *                            Convolutional Neural Network                                *
1355 \****************************************************************************************/
1356 typedef struct CvCNNLayer CvCNNLayer;
1357 typedef struct CvCNNetwork CvCNNetwork;
1358
1359 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
1360 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
1361 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3
1362
1363 #define CV_CNN_GRAD_ESTIM_RANDOM        0
1364 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1
1365
1366 #define ICV_CNN_LAYER                0x55550000
1367 #define ICV_CNN_CONVOLUTION_LAYER    0x00001111
1368 #define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
1369 #define ICV_CNN_FULLCONNECT_LAYER    0x00003333
1370
1371 #define ICV_IS_CNN_LAYER( layer )                                          \
1372     ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
1373         == ICV_CNN_LAYER ))
1374
1375 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer )                              \
1376     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1377         & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )
1378
1379 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer )                              \
1380     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1381         & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )
1382
1383 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer )                              \
1384     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1385         & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )
1386
1387 typedef void (CV_CDECL *CvCNNLayerForward)
1388     ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
1389
1390 typedef void (CV_CDECL *CvCNNLayerBackward)
1391     ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
1392
1393 typedef void (CV_CDECL *CvCNNLayerRelease)
1394     (CvCNNLayer** layer);
1395
1396 typedef void (CV_CDECL *CvCNNetworkAddLayer)
1397     (CvCNNetwork* network, CvCNNLayer* layer);
1398
1399 typedef void (CV_CDECL *CvCNNetworkRelease)
1400     (CvCNNetwork** network);
1401
1402 #define CV_CNN_LAYER_FIELDS()           \
1403     /* Indicator of the layer's type */ \
1404     int flags;                          \
1405                                         \
1406     /* Number of input images */        \
1407     int n_input_planes;                 \
1408     /* Height of each input image */    \
1409     int input_height;                   \
1410     /* Width of each input image */     \
1411     int input_width;                    \
1412                                         \
1413     /* Number of output images */       \
1414     int n_output_planes;                \
1415     /* Height of each output image */   \
1416     int output_height;                  \
1417     /* Width of each output image */    \
1418     int output_width;                   \
1419                                         \
1420     /* Learning rate at the first iteration */                      \
1421     float init_learn_rate;                                          \
1422     /* Dynamics of learning rate decreasing */                      \
1423     int learn_rate_decrease_type;                                   \
1424     /* Trainable weights of the layer (including bias) */           \
1425     /* i-th row is a set of weights of the i-th output plane */     \
1426     CvMat* weights;                                                 \
1427                                                                     \
1428     CvCNNLayerForward  forward;                                     \
1429     CvCNNLayerBackward backward;                                    \
1430     CvCNNLayerRelease  release;                                     \
1431     /* Pointers to the previous and next layers in the network */   \
1432     CvCNNLayer* prev_layer;                                         \
1433     CvCNNLayer* next_layer
1434
1435 typedef struct CvCNNLayer
1436 {
1437     CV_CNN_LAYER_FIELDS();
1438 }CvCNNLayer;
1439
1440 typedef struct CvCNNConvolutionLayer
1441 {
1442     CV_CNN_LAYER_FIELDS();
1443     // Kernel size (height and width) for convolution.
1444     int K;
1445     // connections matrix, (i,j)-th element is 1 iff there is a connection between
1446     // i-th plane of the current layer and j-th plane of the previous layer;
1447     // (i,j)-th element is equal to 0 otherwise
1448     CvMat *connect_mask;
1449     // value of the learning rate for updating weights at the first iteration
1450 }CvCNNConvolutionLayer;
1451
1452 typedef struct CvCNNSubSamplingLayer
1453 {
1454     CV_CNN_LAYER_FIELDS();
1455     // ratio between the heights (or widths - ratios are supposed to be equal)
1456     // of the input and output planes
1457     int sub_samp_scale;
1458     // amplitude of sigmoid activation function
1459     float a;
1460     // scale parameter of sigmoid activation function
1461     float s;
1462     // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
1463     // - is the vector used in computing of the activation function in backward
1464     CvMat* exp2ssumWX;
1465     // (x1+x2+x3+x4), where x1,...x4 are some elements of X
1466     // - is the vector used in computing of the activation function in backward
1467     CvMat* sumX;
1468 }CvCNNSubSamplingLayer;
1469
1470 // Structure of the last layer.
1471 typedef struct CvCNNFullConnectLayer
1472 {
1473     CV_CNN_LAYER_FIELDS();
1474     // amplitude of sigmoid activation function
1475     float a;
1476     // scale parameter of sigmoid activation function
1477     float s;
1478     // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
1479     // activation function and it's derivative by the formulae
1480     // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
1481     // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
1482     CvMat* exp2ssumWX;
1483 }CvCNNFullConnectLayer;
1484
1485 typedef struct CvCNNetwork
1486 {
1487     int n_layers;
1488     CvCNNLayer* layers;
1489     CvCNNetworkAddLayer add_layer;
1490     CvCNNetworkRelease release;
1491 }CvCNNetwork;
1492
1493 typedef struct CvCNNStatModel
1494 {
1495     CV_STAT_MODEL_FIELDS();
1496     CvCNNetwork* network;
1497     // etalons are allocated as rows, the i-th etalon has label cls_labeles[i]
1498     CvMat* etalons;
1499     // classes labels
1500     CvMat* cls_labels;
1501 }CvCNNStatModel;
1502
1503 typedef struct CvCNNStatModelParams
1504 {
1505     CV_STAT_MODEL_PARAM_FIELDS();
1506     // network must be created by the functions cvCreateCNNetwork and <add_layer>
1507     CvCNNetwork* network;
1508     CvMat* etalons;
1509     // termination criteria
1510     int max_iter;
1511     int start_iter;
1512     int grad_estim_type;
1513 }CvCNNStatModelParams;
1514
1515 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
1516     int n_input_planes, int input_height, int input_width,
1517     int n_output_planes, int K,
1518     float init_learn_rate, int learn_rate_decrease_type,
1519     CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
1520
1521 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
1522     int n_input_planes, int input_height, int input_width,
1523     int sub_samp_scale, float a, float s,
1524     float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
1525
1526 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
1527     int n_inputs, int n_outputs, float a, float s,
1528     float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
1529
1530 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
1531
1532 CVAPI(CvStatModel*) cvTrainCNNClassifier(
1533             const CvMat* train_data, int tflag,
1534             const CvMat* responses,
1535             const CvStatModelParams* params,
1536             const CvMat* CV_DEFAULT(0),
1537             const CvMat* sample_idx CV_DEFAULT(0),
1538             const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
1539
1540 /****************************************************************************************\
1541 *                               Estimate classifiers algorithms                          *
1542 \****************************************************************************************/
1543 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
1544                     ( const CvStatModel* estimateModel );
1545
1546 typedef int (CV_CDECL *CvStatModelEstimateNextStep)
1547                     ( CvStatModel* estimateModel );
1548
1549 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
1550                     ( CvStatModel* estimateModel,
1551                 const CvStatModel* model,
1552                 const CvMat*       features,
1553                       int          sample_t_flag,
1554                 const CvMat*       responses );
1555
1556 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
1557                     ( CvStatModel* estimateModel,
1558                 const CvStatModel* model );
1559
1560 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
1561                     ( const CvStatModel* estimateModel,
1562                             float*       correlation );
1563
1564 typedef void (CV_CDECL *CvStatModelEstimateReset)
1565                     ( CvStatModel* estimateModel );
1566
1567 //-------------------------------- Cross-validation --------------------------------------
1568 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
1569     CV_STAT_MODEL_PARAM_FIELDS();                                 \
1570     int     k_fold;                                               \
1571     int     is_regression;                                        \
1572     CvRNG*  rng
1573
1574 typedef struct CvCrossValidationParams
1575 {
1576     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
1577 } CvCrossValidationParams;
1578
1579 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
1580     CvStatModelEstimateGetMat               getTrainIdxMat; \
1581     CvStatModelEstimateGetMat               getCheckIdxMat; \
1582     CvStatModelEstimateNextStep             nextStep;       \
1583     CvStatModelEstimateCheckClassifier      check;          \
1584     CvStatModelEstimateGetCurrentResult     getResult;      \
1585     CvStatModelEstimateReset                reset;          \
1586     int     is_regression;                                  \
1587     int     folds_all;                                      \
1588     int     samples_all;                                    \
1589     int*    sampleIdxAll;                                   \
1590     int*    folds;                                          \
1591     int     max_fold_size;                                  \
1592     int         current_fold;                               \
1593     int         is_checked;                                 \
1594     CvMat*      sampleIdxTrain;                             \
1595     CvMat*      sampleIdxEval;                              \
1596     CvMat*      predict_results;                            \
1597     int     correct_results;                                \
1598     int     all_results;                                    \
1599     double  sq_error;                                       \
1600     double  sum_correct;                                    \
1601     double  sum_predict;                                    \
1602     double  sum_cc;                                         \
1603     double  sum_pp;                                         \
1604     double  sum_cp
1605
1606 typedef struct CvCrossValidationModel
1607 {
1608     CV_STAT_MODEL_FIELDS();
1609     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
1610 } CvCrossValidationModel;
1611
1612 CVAPI(CvStatModel*)
1613 cvCreateCrossValidationEstimateModel
1614            ( int                samples_all,
1615        const CvStatModelParams* estimateParams CV_DEFAULT(0),
1616        const CvMat*             sampleIdx CV_DEFAULT(0) );
1617
1618 CVAPI(float)
1619 cvCrossValidation( const CvMat*             trueData,
1620                          int                tflag,
1621                    const CvMat*             trueClasses,
1622                          CvStatModel*     (*createClassifier)( const CvMat*,
1623                                                                      int,
1624                                                                const CvMat*,
1625                                                                const CvStatModelParams*,
1626                                                                const CvMat*,
1627                                                                const CvMat*,
1628                                                                const CvMat*,
1629                                                                const CvMat* ),
1630                    const CvStatModelParams* estimateParams CV_DEFAULT(0),
1631                    const CvStatModelParams* trainParams CV_DEFAULT(0),
1632                    const CvMat*             compIdx CV_DEFAULT(0),
1633                    const CvMat*             sampleIdx CV_DEFAULT(0),
1634                          CvStatModel**      pCrValModel CV_DEFAULT(0),
1635                    const CvMat*             typeMask CV_DEFAULT(0),
1636                    const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
1637 #endif
1638
1639 /****************************************************************************************\
1640 *                           Auxilary functions declarations                              *
1641 \****************************************************************************************/
1642
1643 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
1644    average row vector, <cov> - symmetric covariation matrix */
1645 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1646                            CvRNG* rng CV_DEFAULT(0) );
1647
1648 /* Generates sample from gaussian mixture distribution */
1649 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1650                                CvMat* covs[],
1651                                float weights[],
1652                                int clsnum,
1653                                CvMat* sample,
1654                                CvMat* sampClasses CV_DEFAULT(0) );
1655
1656 #define CV_TS_CONCENTRIC_SPHERES 0
1657
1658 /* creates test set */
1659 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1660                  int num_samples,
1661                  int num_features,
1662                  CvMat** responses,
1663                  int num_classes, ... );
1664
1665
1666 #endif
1667
1668 /****************************************************************************************\
1669 *                                      Data                                             *
1670 \****************************************************************************************/
1671
1672 #include <map>
1673 #include <string>
1674 #include <iostream>
1675 using namespace std;
1676
1677 #define CV_COUNT     0
1678 #define CV_PORTION   1
1679
1680 struct CV_EXPORTS CvTrainTestSplit
1681 {
1682 public:
1683     CvTrainTestSplit();
1684     CvTrainTestSplit( int _train_sample_count, bool _mix = true);
1685     CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
1686
1687     union
1688     {
1689         int count;
1690         float portion;
1691     } train_sample_part;
1692     int train_sample_part_mode;
1693
1694     union
1695     {
1696         int *count;
1697         float *portion;
1698     } *class_part;
1699     int class_part_mode;
1700
1701     bool mix;    
1702 };
1703
1704 class CV_EXPORTS CvMLData
1705 {
1706 public:
1707     CvMLData();
1708     virtual ~CvMLData();
1709
1710     // returns:
1711     // 0 - OK  
1712     // 1 - file can not be opened or is not correct
1713     int read_csv(const char* filename);
1714
1715     const CvMat* get_values(){ return values; };
1716
1717     const CvMat* get_response();
1718
1719     const CvMat* get_missing(){ return missing; };
1720
1721     void set_response_idx( int idx ); // idx < 0 to set all vars as predictors
1722     int get_response_idx() { return response_idx; }
1723
1724     const CvMat* get_train_sample_idx() { return train_sample_idx; };
1725     const CvMat* get_test_sample_idx() { return test_sample_idx; };
1726     void mix_train_and_test_idx();
1727     void set_train_test_split( const CvTrainTestSplit * spl);
1728     
1729     const CvMat* get_var_idx();
1730     void chahge_var_idx( int vi, bool state );
1731
1732     const CvMat* get_var_types();
1733     int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
1734     // following 2 methods enable to change vars type
1735     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
1736     // with numerical labels; in the other cases var types are correctly determined automatically
1737     void set_var_types( const char* str );  // str examples:
1738                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1739                                             // "cat", "ord" (all vars are categorical/ordered)
1740     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }    
1741  
1742     void set_delimiter( char ch );
1743     char get_delimiter() { return delimiter; };
1744
1745     void set_miss_ch( char ch );
1746     char get_miss_ch() { return miss_ch; };
1747     
1748 protected:
1749     virtual void clear();
1750
1751     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
1752     void free_train_test_idx();
1753     
1754     char delimiter;
1755     char miss_ch;
1756     //char flt_separator;
1757
1758     CvMat* values;
1759     CvMat* missing;
1760     CvMat* var_types;
1761     CvMat* var_idx_mask;
1762
1763     CvMat* response_out; // header
1764     CvMat* var_idx_out; // mat
1765     CvMat* var_types_out; // mat
1766
1767     int response_idx;
1768
1769     int train_sample_count;
1770     bool mix;
1771    
1772     int total_class_count;
1773     map<string, int> *class_map;
1774
1775     CvMat* train_sample_idx;
1776     CvMat* test_sample_idx;
1777     int* sample_idx; // data of train_sample_idx and test_sample_idx
1778
1779     CvRNG rng;
1780 };
1781
1782 #endif /*__ML_H__*/
1783 /* End of file. */