]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/include/opencv/ml.h
renamed all the _[A-Z] variables to avoid possible name conflicts.
[opencv.git] / opencv / include / opencv / ml.h
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #ifndef __OPENCV_ML_H__
42 #define __OPENCV_ML_H__
43
44 // disable deprecation warning which appears in VisualStudio 8.0
45 #if _MSC_VER >= 1400
46 #pragma warning( disable : 4996 )
47 #endif
48
49 #ifndef SKIP_INCLUDES
50
51   #include "cxcore.h"
52   #include <limits.h>
53
54   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
55     #include <windows.h>
56   #endif
57
58 #else // SKIP_INCLUDES
59
60   #if defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64
61     #define CV_CDECL __cdecl
62     #define CV_STDCALL __stdcall
63   #else
64     #define CV_CDECL
65     #define CV_STDCALL
66   #endif
67
68   #ifndef CV_EXTERN_C
69     #ifdef __cplusplus
70       #define CV_EXTERN_C extern "C"
71       #define CV_DEFAULT(val) = val
72     #else
73       #define CV_EXTERN_C
74       #define CV_DEFAULT(val)
75     #endif
76   #endif
77
78   #ifndef CV_EXTERN_C_FUNCPTR
79     #ifdef __cplusplus
80       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
81     #else
82       #define CV_EXTERN_C_FUNCPTR(x) typedef x
83     #endif
84   #endif
85
86   #ifndef CV_INLINE
87     #if defined __cplusplus
88       #define CV_INLINE inline
89     #elif (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && !defined __GNUC__
90       #define CV_INLINE __inline
91     #else
92       #define CV_INLINE static
93     #endif
94   #endif /* CV_INLINE */
95
96   #if (defined WIN32 || defined _WIN32 || defined WIN64 || defined _WIN64) && defined CVAPI_EXPORTS
97     #define CV_EXPORTS __declspec(dllexport)
98   #else
99     #define CV_EXPORTS
100   #endif
101
102   #ifndef CVAPI
103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
104   #endif
105
106 #endif // SKIP_INCLUDES
107
108
109 #ifdef __cplusplus
110
111 // Apple defines a check() macro somewhere in the debug headers
112 // that interferes with a method definiton in this header
113 #undef check
114
115 /****************************************************************************************\
116 *                               Main struct definitions                                  *
117 \****************************************************************************************/
118
119 /* log(2*PI) */
120 #define CV_LOG2PI (1.8378770664093454835606594728112)
121
122 /* columns of <trainData> matrix are training samples */
123 #define CV_COL_SAMPLE 0
124
125 /* rows of <trainData> matrix are training samples */
126 #define CV_ROW_SAMPLE 1
127
128 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
129
130 struct CvVectors
131 {
132     int type;
133     int dims, count;
134     CvVectors* next;
135     union
136     {
137         uchar** ptr;
138         float** fl;
139         double** db;
140     } data;
141 };
142
143 #if 0
144 /* A structure, representing the lattice range of statmodel parameters.
145    It is used for optimizing statmodel parameters by cross-validation method.
146    The lattice is logarithmic, so <step> must be greater then 1. */
147 typedef struct CvParamLattice
148 {
149     double min_val;
150     double max_val;
151     double step;
152 }
153 CvParamLattice;
154
155 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
156                                          double log_step )
157 {
158     CvParamLattice pl;
159     pl.min_val = MIN( min_val, max_val );
160     pl.max_val = MAX( min_val, max_val );
161     pl.step = MAX( log_step, 1. );
162     return pl;
163 }
164
165 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
166 {
167     CvParamLattice pl = {0,0,0};
168     return pl;
169 }
170 #endif
171
172 /* Variable type */
173 #define CV_VAR_NUMERICAL    0
174 #define CV_VAR_ORDERED      0
175 #define CV_VAR_CATEGORICAL  1
176
177 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
178 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
179 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
180 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
181 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
182 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
183 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
184 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
185 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
186
187 #define CV_TRAIN_ERROR  0
188 #define CV_TEST_ERROR   1
189
190 class CV_EXPORTS CvStatModel
191 {
192 public:
193     CvStatModel();
194     virtual ~CvStatModel();
195
196     virtual void clear();
197
198     virtual void save( const char* filename, const char* name=0 ) const;
199     virtual void load( const char* filename, const char* name=0 );
200
201     virtual void write( CvFileStorage* storage, const char* name ) const;
202     virtual void read( CvFileStorage* storage, CvFileNode* node );
203
204 protected:
205     const char* default_model_name;
206 };
207
208 /****************************************************************************************\
209 *                                 Normal Bayes Classifier                                *
210 \****************************************************************************************/
211
212 /* The structure, representing the grid range of statmodel parameters.
213    It is used for optimizing statmodel accuracy by varying model parameters,
214    the accuracy estimate being computed by cross-validation.
215    The grid is logarithmic, so <step> must be greater then 1. */
216
217 class CvMLData;
218
219 struct CV_EXPORTS CvParamGrid
220 {
221     // SVM params type
222     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
223
224     CvParamGrid()
225     {
226         min_val = max_val = step = 0;
227     }
228
229     CvParamGrid( double _min_val, double _max_val, double log_step )
230     {
231         min_val = _min_val;
232         max_val = _max_val;
233         step = log_step;
234     }
235     //CvParamGrid( int param_id );
236     bool check() const;
237
238     double min_val;
239     double max_val;
240     double step;
241 };
242
243 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
244 {
245 public:
246     CvNormalBayesClassifier();
247     virtual ~CvNormalBayesClassifier();
248
249     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
250         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
251     
252     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
253         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
254    
255     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
256     virtual void clear();
257
258 #ifndef SWIG
259     CvNormalBayesClassifier( const cv::Mat& _train_data, const cv::Mat& _responses,
260                             const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat() );
261     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
262                        const cv::Mat& _var_idx = cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
263                        bool update=false );
264     virtual float predict( const cv::Mat& _samples, cv::Mat* results=0 ) const;
265 #endif
266     
267     virtual void write( CvFileStorage* storage, const char* name ) const;
268     virtual void read( CvFileStorage* storage, CvFileNode* node );
269
270 protected:
271     int     var_count, var_all;
272     CvMat*  var_idx;
273     CvMat*  cls_labels;
274     CvMat** count;
275     CvMat** sum;
276     CvMat** productsum;
277     CvMat** avg;
278     CvMat** inv_eigen_values;
279     CvMat** cov_rotate_mats;
280     CvMat*  c;
281 };
282
283
284 /****************************************************************************************\
285 *                          K-Nearest Neighbour Classifier                                *
286 \****************************************************************************************/
287
288 // k Nearest Neighbors
289 class CV_EXPORTS CvKNearest : public CvStatModel
290 {
291 public:
292
293     CvKNearest();
294     virtual ~CvKNearest();
295
296     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
297                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
298
299     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
300                         const CvMat* _sample_idx=0, bool is_regression=false,
301                         int _max_k=32, bool _update_base=false );
302     
303     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
304         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
305     
306 #ifndef SWIG
307     CvKNearest( const cv::Mat& _train_data, const cv::Mat& _responses,
308                const cv::Mat& _sample_idx=cv::Mat(), bool _is_regression=false, int max_k=32 );
309     
310     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
311                        const cv::Mat& _sample_idx=cv::Mat(), bool is_regression=false,
312                        int _max_k=32, bool _update_base=false );    
313     
314     virtual float find_nearest( const cv::Mat& _samples, int k, cv::Mat* results=0,
315                                 const float** neighbors=0,
316                                 cv::Mat* neighbor_responses=0,
317                                 cv::Mat* dist=0 ) const;
318 #endif
319     
320     virtual void clear();
321     int get_max_k() const;
322     int get_var_count() const;
323     int get_sample_count() const;
324     bool is_regression() const;
325
326 protected:
327
328     virtual float write_results( int k, int k1, int start, int end,
329         const float* neighbor_responses, const float* dist, CvMat* _results,
330         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
331
332     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
333         float* neighbor_responses, const float** neighbors, float* dist ) const;
334
335
336     int max_k, var_count;
337     int total;
338     bool regression;
339     CvVectors* samples;
340 };
341
342 /****************************************************************************************\
343 *                                   Support Vector Machines                              *
344 \****************************************************************************************/
345
346 // SVM training parameters
347 struct CV_EXPORTS CvSVMParams
348 {
349     CvSVMParams();
350     CvSVMParams( int _svm_type, int _kernel_type,
351                  double _degree, double _gamma, double _coef0,
352                  double Cvalue, double _nu, double _p,
353                  CvMat* _class_weights, CvTermCriteria _term_crit );
354
355     int         svm_type;
356     int         kernel_type;
357     double      degree; // for poly
358     double      gamma;  // for poly/rbf/sigmoid
359     double      coef0;  // for poly/sigmoid
360
361     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
362     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
363     double      p; // for CV_SVM_EPS_SVR
364     CvMat*      class_weights; // for CV_SVM_C_SVC
365     CvTermCriteria term_crit; // termination criteria
366 };
367
368
369 struct CV_EXPORTS CvSVMKernel
370 {
371     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
372                                        const float* another, float* results );
373     CvSVMKernel();
374     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
375     virtual bool create( const CvSVMParams* _params, Calc _calc_func );
376     virtual ~CvSVMKernel();
377
378     virtual void clear();
379     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
380
381     const CvSVMParams* params;
382     Calc calc_func;
383
384     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
385                                     const float* another, float* results,
386                                     double alpha, double beta );
387
388     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
389                               const float* another, float* results );
390     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
391                            const float* another, float* results );
392     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
393                             const float* another, float* results );
394     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
395                                const float* another, float* results );
396 };
397
398
399 struct CvSVMKernelRow
400 {
401     CvSVMKernelRow* prev;
402     CvSVMKernelRow* next;
403     float* data;
404 };
405
406
407 struct CvSVMSolutionInfo
408 {
409     double obj;
410     double rho;
411     double upper_bound_p;
412     double upper_bound_n;
413     double r;   // for Solver_NU
414 };
415
416 class CV_EXPORTS CvSVMSolver
417 {
418 public:
419     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
420     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
421     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
422
423     CvSVMSolver();
424
425     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
426                  int alpha_count, double* alpha, double Cp, double Cn,
427                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
428                  SelectWorkingSet select_working_set, CalcRho calc_rho );
429     virtual bool create( int count, int var_count, const float** samples, schar* y,
430                  int alpha_count, double* alpha, double Cp, double Cn,
431                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
432                  SelectWorkingSet select_working_set, CalcRho calc_rho );
433     virtual ~CvSVMSolver();
434
435     virtual void clear();
436     virtual bool solve_generic( CvSVMSolutionInfo& si );
437
438     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
439                               double Cp, double Cn, CvMemStorage* storage,
440                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
441     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
442                                CvMemStorage* storage, CvSVMKernel* kernel,
443                                double* alpha, CvSVMSolutionInfo& si );
444     virtual bool solve_one_class( int count, int var_count, const float** samples,
445                                   CvMemStorage* storage, CvSVMKernel* kernel,
446                                   double* alpha, CvSVMSolutionInfo& si );
447
448     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
449                                 CvMemStorage* storage, CvSVMKernel* kernel,
450                                 double* alpha, CvSVMSolutionInfo& si );
451
452     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
453                                CvMemStorage* storage, CvSVMKernel* kernel,
454                                double* alpha, CvSVMSolutionInfo& si );
455
456     virtual float* get_row_base( int i, bool* _existed );
457     virtual float* get_row( int i, float* dst );
458
459     int sample_count;
460     int var_count;
461     int cache_size;
462     int cache_line_size;
463     const float** samples;
464     const CvSVMParams* params;
465     CvMemStorage* storage;
466     CvSVMKernelRow lru_list;
467     CvSVMKernelRow* rows;
468
469     int alpha_count;
470
471     double* G;
472     double* alpha;
473
474     // -1 - lower bound, 0 - free, 1 - upper bound
475     schar* alpha_status;
476
477     schar* y;
478     double* b;
479     float* buf[2];
480     double eps;
481     int max_iter;
482     double C[2];  // C[0] == Cn, C[1] == Cp
483     CvSVMKernel* kernel;
484
485     SelectWorkingSet select_working_set_func;
486     CalcRho calc_rho_func;
487     GetRow get_row_func;
488
489     virtual bool select_working_set( int& i, int& j );
490     virtual bool select_working_set_nu_svm( int& i, int& j );
491     virtual void calc_rho( double& rho, double& r );
492     virtual void calc_rho_nu_svm( double& rho, double& r );
493
494     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
495     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
496     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
497 };
498
499
500 struct CvSVMDecisionFunc
501 {
502     double rho;
503     int sv_count;
504     double* alpha;
505     int* sv_index;
506 };
507
508
509 // SVM model
510 class CV_EXPORTS CvSVM : public CvStatModel
511 {
512 public:
513     // SVM type
514     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
515
516     // SVM kernel type
517     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
518
519     // SVM params type
520     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
521
522     CvSVM();
523     virtual ~CvSVM();
524
525     CvSVM( const CvMat* _train_data, const CvMat* _responses,
526            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
527            CvSVMParams _params=CvSVMParams() );
528
529     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
530                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
531                         CvSVMParams _params=CvSVMParams() );
532     
533     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
534         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
535         int k_fold = 10,
536         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
537         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
538         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
539         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
540         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
541         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
542
543     virtual float predict( const CvMat* _sample, bool returnDFVal=false ) const;
544
545 #ifndef SWIG
546     CvSVM( const cv::Mat& _train_data, const cv::Mat& _responses,
547           const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
548           CvSVMParams _params=CvSVMParams() );
549     
550     virtual bool train( const cv::Mat& _train_data, const cv::Mat& _responses,
551                        const cv::Mat& _var_idx=cv::Mat(), const cv::Mat& _sample_idx=cv::Mat(),
552                        CvSVMParams _params=CvSVMParams() );
553     
554     virtual bool train_auto( const cv::Mat& _train_data, const cv::Mat& _responses,
555                             const cv::Mat& _var_idx, const cv::Mat& _sample_idx, CvSVMParams _params,
556                             int k_fold = 10,
557                             CvParamGrid C_grid      = get_default_grid(CvSVM::C),
558                             CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
559                             CvParamGrid p_grid      = get_default_grid(CvSVM::P),
560                             CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
561                             CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
562                             CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
563     virtual float predict( const cv::Mat& _sample, bool returnDFVal=false ) const;    
564 #endif
565     
566     virtual int get_support_vector_count() const;
567     virtual const float* get_support_vector(int i) const;
568     virtual CvSVMParams get_params() const { return params; };
569     virtual void clear();
570
571     static CvParamGrid get_default_grid( int param_id );
572
573     virtual void write( CvFileStorage* storage, const char* name ) const;
574     virtual void read( CvFileStorage* storage, CvFileNode* node );
575     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
576
577 protected:
578
579     virtual bool set_params( const CvSVMParams& _params );
580     virtual bool train1( int sample_count, int var_count, const float** samples,
581                     const void* _responses, double Cp, double Cn,
582                     CvMemStorage* _storage, double* alpha, double& rho );
583     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
584                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
585     virtual void create_kernel();
586     virtual void create_solver();
587
588     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
589
590     virtual void write_params( CvFileStorage* fs ) const;
591     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
592
593     CvSVMParams params;
594     CvMat* class_labels;
595     int var_all;
596     float** sv;
597     int sv_total;
598     CvMat* var_idx;
599     CvMat* class_weights;
600     CvSVMDecisionFunc* decision_func;
601     CvMemStorage* storage;
602
603     CvSVMSolver* solver;
604     CvSVMKernel* kernel;
605 };
606
607 /****************************************************************************************\
608 *                              Expectation - Maximization                                *
609 \****************************************************************************************/
610
611 struct CV_EXPORTS CvEMParams
612 {
613     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
614         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
615     {
616         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
617     }
618
619     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
620                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
621                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
622                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
623                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
624                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
625     {}
626
627     int nclusters;
628     int cov_mat_type;
629     int start_step;
630     const CvMat* probs;
631     const CvMat* weights;
632     const CvMat* means;
633     const CvMat** covs;
634     CvTermCriteria term_crit;
635 };
636
637
638 class CV_EXPORTS CvEM : public CvStatModel
639 {
640 public:
641     // Type of covariation matrices
642     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
643
644     // The initial step
645     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
646
647     CvEM();
648     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
649           CvEMParams params=CvEMParams(), CvMat* labels=0 );
650     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
651
652     virtual ~CvEM();
653
654     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
655                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
656
657     virtual float predict( const CvMat* sample, CvMat* probs ) const;
658
659 #ifndef SWIG
660     CvEM( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
661          CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
662     
663     virtual bool train( const cv::Mat& samples, const cv::Mat& sample_idx=cv::Mat(),
664                        CvEMParams params=CvEMParams(), cv::Mat* labels=0 );
665     
666     virtual float predict( const cv::Mat& sample, cv::Mat* probs ) const;
667 #endif
668     
669     virtual void clear();
670
671     int           get_nclusters() const;
672     const CvMat*  get_means()     const;
673     const CvMat** get_covs()      const;
674     const CvMat*  get_weights()   const;
675     const CvMat*  get_probs()     const;
676
677     inline double         get_log_likelihood     () const { return log_likelihood;     };
678     
679 //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
680 //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
681 //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
682
683 protected:
684
685     virtual void set_params( const CvEMParams& params,
686                              const CvVectors& train_data );
687     virtual void init_em( const CvVectors& train_data );
688     virtual double run_em( const CvVectors& train_data );
689     virtual void init_auto( const CvVectors& samples );
690     virtual void kmeans( const CvVectors& train_data, int nclusters,
691                          CvMat* labels, CvTermCriteria criteria,
692                          const CvMat* means );
693     CvEMParams params;
694     double log_likelihood;
695
696     CvMat* means;
697     CvMat** covs;
698     CvMat* weights;
699     CvMat* probs;
700
701     CvMat* log_weight_div_det;
702     CvMat* inv_eigen_values;
703     CvMat** cov_rotate_mats;
704 };
705
706 /****************************************************************************************\
707 *                                      Decision Tree                                     *
708 \****************************************************************************************/\
709 struct CvPair16u32s
710 {
711     unsigned short* u;
712     int* i;
713 };
714
715
716 #define CV_DTREE_CAT_DIR(idx,subset) \
717     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
718
719 struct CvDTreeSplit
720 {
721     int var_idx;
722     int condensed_idx;
723     int inversed;
724     float quality;
725     CvDTreeSplit* next;
726     union
727     {
728         int subset[2];
729         struct
730         {
731             float c;
732             int split_point;
733         }
734         ord;
735     };
736 };
737
738
739 struct CvDTreeNode
740 {
741     int class_idx;
742     int Tn;
743     double value;
744
745     CvDTreeNode* parent;
746     CvDTreeNode* left;
747     CvDTreeNode* right;
748
749     CvDTreeSplit* split;
750
751     int sample_count;
752     int depth;
753     int* num_valid;
754     int offset;
755     int buf_idx;
756     double maxlr;
757
758     // global pruning data
759     int complexity;
760     double alpha;
761     double node_risk, tree_risk, tree_error;
762
763     // cross-validation pruning data
764     int* cv_Tn;
765     double* cv_node_risk;
766     double* cv_node_error;
767
768     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
769     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
770 };
771
772
773 struct CV_EXPORTS CvDTreeParams
774 {
775     int   max_categories;
776     int   max_depth;
777     int   min_sample_count;
778     int   cv_folds;
779     bool  use_surrogates;
780     bool  use_1se_rule;
781     bool  truncate_pruned_tree;
782     float regression_accuracy;
783     const float* priors;
784
785     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
786         cv_folds(10), use_surrogates(true), use_1se_rule(true),
787         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
788     {}
789
790     CvDTreeParams( int _max_depth, int _min_sample_count,
791                    float _regression_accuracy, bool _use_surrogates,
792                    int _max_categories, int _cv_folds,
793                    bool _use_1se_rule, bool _truncate_pruned_tree,
794                    const float* _priors ) :
795         max_categories(_max_categories), max_depth(_max_depth),
796         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
797         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
798         truncate_pruned_tree(_truncate_pruned_tree),
799         regression_accuracy(_regression_accuracy),
800         priors(_priors)
801     {}
802 };
803
804
805 struct CV_EXPORTS CvDTreeTrainData
806 {
807     CvDTreeTrainData();
808     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
809                       const CvMat* _responses, const CvMat* _var_idx=0,
810                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
811                       const CvMat* _missing_mask=0,
812                       const CvDTreeParams& _params=CvDTreeParams(),
813                       bool _shared=false, bool _add_labels=false );
814     virtual ~CvDTreeTrainData();
815
816     virtual void set_data( const CvMat* _train_data, int _tflag,
817                           const CvMat* _responses, const CvMat* _var_idx=0,
818                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
819                           const CvMat* _missing_mask=0,
820                           const CvDTreeParams& _params=CvDTreeParams(),
821                           bool _shared=false, bool _add_labels=false,
822                           bool _update_data=false );
823     virtual void do_responses_copy();
824
825     virtual void get_vectors( const CvMat* _subsample_idx,
826          float* values, uchar* missing, float* responses, bool get_class_idx=false );
827
828     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
829
830     virtual void write_params( CvFileStorage* fs ) const;
831     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
832
833     // release all the data
834     virtual void clear();
835
836     int get_num_classes() const;
837     int get_var_type(int vi) const;
838     int get_work_var_count() const {return work_var_count;}
839
840     virtual void get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values );    
841     virtual void get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
842     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
843     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** labels );
844     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
845     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf,
846         const float** ord_values, const int** indices );
847     virtual int get_child_buf_idx( CvDTreeNode* n );
848
849     ////////////////////////////////////
850
851     virtual bool set_params( const CvDTreeParams& params );
852     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
853                                    int storage_idx, int offset );
854
855     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
856                 int split_point, int inversed, float quality );
857     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
858     virtual void free_node_data( CvDTreeNode* node );
859     virtual void free_train_data();
860     virtual void free_node( CvDTreeNode* node );
861
862     // inner arrays for getting predictors and responses
863     float* get_pred_float_buf();
864     int* get_pred_int_buf();
865     float* get_resp_float_buf();
866     int* get_resp_int_buf();
867     int* get_cv_lables_buf();
868     int* get_sample_idx_buf();
869
870     std::vector<std::vector<float> > pred_float_buf;
871     std::vector<std::vector<int> > pred_int_buf;
872     std::vector<std::vector<float> > resp_float_buf;
873     std::vector<std::vector<int> > resp_int_buf;
874     std::vector<std::vector<int> > cv_lables_buf;
875     std::vector<std::vector<int> > sample_idx_buf;
876
877     int sample_count, var_all, var_count, max_c_count;
878     int ord_var_count, cat_var_count, work_var_count;
879     bool have_labels, have_priors;
880     bool is_classifier;
881     int tflag;
882
883     const CvMat* train_data;
884     const CvMat* responses;
885     CvMat* responses_copy; // used in Boosting
886
887     int buf_count, buf_size;
888     bool shared;
889     int is_buf_16u;
890     
891     CvMat* cat_count;
892     CvMat* cat_ofs;
893     CvMat* cat_map;
894
895     CvMat* counts;
896     CvMat* buf;
897     CvMat* direction;
898     CvMat* split_buf;
899
900     CvMat* var_idx;
901     CvMat* var_type; // i-th element =
902                      //   k<0  - ordered
903                      //   k>=0 - categorical, see k-th element of cat_* arrays
904     CvMat* priors;
905     CvMat* priors_mult;
906
907     CvDTreeParams params;
908
909     CvMemStorage* tree_storage;
910     CvMemStorage* temp_storage;
911
912     CvDTreeNode* data_root;
913
914     CvSet* node_heap;
915     CvSet* split_heap;
916     CvSet* cv_heap;
917     CvSet* nv_heap;
918
919     CvRNG rng;
920 };
921
922
923 class CV_EXPORTS CvDTree : public CvStatModel
924 {
925 public:
926     CvDTree();
927     virtual ~CvDTree();
928
929     virtual bool train( const CvMat* _train_data, int _tflag,
930                         const CvMat* _responses, const CvMat* _var_idx=0,
931                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
932                         const CvMat* _missing_mask=0,
933                         CvDTreeParams params=CvDTreeParams() );
934
935     virtual bool train( CvMLData* _data, CvDTreeParams _params=CvDTreeParams() );
936
937     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
938
939     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
940
941     virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
942                                   bool preprocessed_input=false ) const;
943
944 #ifndef SWIG
945     virtual bool train( const cv::Mat& _train_data, int _tflag,
946                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
947                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
948                        const cv::Mat& _missing_mask=cv::Mat(),
949                        CvDTreeParams params=CvDTreeParams() );
950     
951     virtual CvDTreeNode* predict( const cv::Mat& _sample, const cv::Mat& _missing_data_mask=cv::Mat(),
952                                   bool preprocessed_input=false ) const;
953 #endif
954     
955     virtual const CvMat* get_var_importance();
956     virtual void clear();
957
958     virtual void read( CvFileStorage* fs, CvFileNode* node );
959     virtual void write( CvFileStorage* fs, const char* name ) const;
960
961     // special read & write methods for trees in the tree ensembles
962     virtual void read( CvFileStorage* fs, CvFileNode* node,
963                        CvDTreeTrainData* data );
964     virtual void write( CvFileStorage* fs ) const;
965
966     const CvDTreeNode* get_root() const;
967     int get_pruned_tree_idx() const;
968     CvDTreeTrainData* get_data();
969
970 protected:
971
972     virtual bool do_train( const CvMat* _subsample_idx );
973
974     virtual void try_split_node( CvDTreeNode* n );
975     virtual void split_node_data( CvDTreeNode* n );
976     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
977     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
978                             float init_quality = 0, CvDTreeSplit* _split = 0 );
979     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
980                             float init_quality = 0, CvDTreeSplit* _split = 0 );
981     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
982                             float init_quality = 0, CvDTreeSplit* _split = 0 );
983     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
984                             float init_quality = 0, CvDTreeSplit* _split = 0 );
985     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
986     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
987     virtual double calc_node_dir( CvDTreeNode* node );
988     virtual void complete_node_dir( CvDTreeNode* node );
989     virtual void cluster_categories( const int* vectors, int vector_count,
990         int var_count, int* sums, int k, int* cluster_labels );
991
992     virtual void calc_node_value( CvDTreeNode* node );
993
994     virtual void prune_cv();
995     virtual double update_tree_rnc( int T, int fold );
996     virtual int cut_tree( int T, int fold, double min_alpha );
997     virtual void free_prune_data(bool cut_tree);
998     virtual void free_tree();
999
1000     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
1001     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
1002     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
1003     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
1004     virtual void write_tree_nodes( CvFileStorage* fs ) const;
1005     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
1006
1007     CvDTreeNode* root;
1008     CvMat* var_importance;
1009     CvDTreeTrainData* data;
1010
1011 public:
1012     int pruned_tree_idx;
1013 };
1014
1015
1016 /****************************************************************************************\
1017 *                                   Random Trees Classifier                              *
1018 \****************************************************************************************/
1019
1020 class CvRTrees;
1021
1022 class CV_EXPORTS CvForestTree: public CvDTree
1023 {
1024 public:
1025     CvForestTree();
1026     virtual ~CvForestTree();
1027
1028     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
1029
1030     virtual int get_var_count() const {return data ? data->var_count : 0;}
1031     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
1032
1033     /* dummy methods to avoid warnings: BEGIN */
1034     virtual bool train( const CvMat* _train_data, int _tflag,
1035                         const CvMat* _responses, const CvMat* _var_idx=0,
1036                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1037                         const CvMat* _missing_mask=0,
1038                         CvDTreeParams params=CvDTreeParams() );
1039
1040     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1041     virtual void read( CvFileStorage* fs, CvFileNode* node );
1042     virtual void read( CvFileStorage* fs, CvFileNode* node,
1043                        CvDTreeTrainData* data );
1044     /* dummy methods to avoid warnings: END */
1045
1046 protected:
1047     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
1048     CvRTrees* forest;
1049 };
1050
1051
1052 struct CV_EXPORTS CvRTParams : public CvDTreeParams
1053 {
1054     //Parameters for the forest
1055     bool calc_var_importance; // true <=> RF processes variable importance
1056     int nactive_vars;
1057     CvTermCriteria term_crit;
1058
1059     CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
1060         calc_var_importance(false), nactive_vars(0)
1061     {
1062         term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
1063     }
1064
1065     CvRTParams( int _max_depth, int _min_sample_count,
1066                 float _regression_accuracy, bool _use_surrogates,
1067                 int _max_categories, const float* _priors, bool _calc_var_importance,
1068                 int _nactive_vars, int max_num_of_trees_in_the_forest,
1069                 float forest_accuracy, int termcrit_type ) :
1070         CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
1071                        _use_surrogates, _max_categories, 0,
1072                        false, false, _priors ),
1073         calc_var_importance(_calc_var_importance),
1074         nactive_vars(_nactive_vars)
1075     {
1076         term_crit = cvTermCriteria(termcrit_type,
1077             max_num_of_trees_in_the_forest, forest_accuracy);
1078     }
1079 };
1080
1081
1082 class CV_EXPORTS CvRTrees : public CvStatModel
1083 {
1084 public:
1085     CvRTrees();
1086     virtual ~CvRTrees();
1087     virtual bool train( const CvMat* _train_data, int _tflag,
1088                         const CvMat* _responses, const CvMat* _var_idx=0,
1089                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1090                         const CvMat* _missing_mask=0,
1091                         CvRTParams params=CvRTParams() );
1092     
1093     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1094     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
1095     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
1096
1097 #ifndef SWIG
1098     virtual bool train( const cv::Mat& _train_data, int _tflag,
1099                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1100                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1101                        const cv::Mat& _missing_mask=cv::Mat(),
1102                        CvRTParams params=CvRTParams() );
1103     virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1104     virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1105 #endif
1106     
1107     virtual void clear();
1108
1109     virtual const CvMat* get_var_importance();
1110     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
1111         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
1112     
1113     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1114
1115     virtual float get_train_error();    
1116
1117     virtual void read( CvFileStorage* fs, CvFileNode* node );
1118     virtual void write( CvFileStorage* fs, const char* name ) const;
1119
1120     CvMat* get_active_var_mask();
1121     CvRNG* get_rng();
1122
1123     int get_tree_count() const;
1124     CvForestTree* get_tree(int i) const;
1125
1126 protected:
1127
1128     virtual bool grow_forest( const CvTermCriteria term_crit );
1129
1130     // array of the trees of the forest
1131     CvForestTree** trees;
1132     CvDTreeTrainData* data;
1133     int ntrees;
1134     int nclasses;
1135     double oob_error;
1136     CvMat* var_importance;
1137     int nsamples;
1138
1139     CvRNG rng;
1140     CvMat* active_var_mask;
1141 };
1142
1143 /****************************************************************************************\
1144 *                           Extremely randomized trees Classifier                        *
1145 \****************************************************************************************/
1146 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
1147 {
1148     virtual void set_data( const CvMat* _train_data, int _tflag,
1149                           const CvMat* _responses, const CvMat* _var_idx=0,
1150                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1151                           const CvMat* _missing_mask=0,
1152                           const CvDTreeParams& _params=CvDTreeParams(),
1153                           bool _shared=false, bool _add_labels=false,
1154                           bool _update_data=false );
1155     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
1156         const float** ord_values, const int** missing );
1157     virtual void get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices );
1158     virtual void get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels );
1159     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values );
1160     virtual void get_vectors( const CvMat* _subsample_idx,
1161          float* values, uchar* missing, float* responses, bool get_class_idx=false );
1162     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
1163     const CvMat* missing_mask;
1164 };
1165
1166 class CV_EXPORTS CvForestERTree : public CvForestTree
1167 {
1168 protected:
1169     virtual double calc_node_dir( CvDTreeNode* node );
1170     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
1171         float init_quality = 0, CvDTreeSplit* _split = 0 );
1172     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1173         float init_quality = 0, CvDTreeSplit* _split = 0 );
1174     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
1175         float init_quality = 0, CvDTreeSplit* _split = 0 );
1176     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
1177         float init_quality = 0, CvDTreeSplit* _split = 0 );
1178     //virtual void complete_node_dir( CvDTreeNode* node );
1179     virtual void split_node_data( CvDTreeNode* n );
1180 };
1181
1182 class CV_EXPORTS CvERTrees : public CvRTrees
1183 {
1184 public:
1185     CvERTrees();
1186     virtual ~CvERTrees();
1187     virtual bool train( const CvMat* _train_data, int _tflag,
1188                         const CvMat* _responses, const CvMat* _var_idx=0,
1189                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1190                         const CvMat* _missing_mask=0,
1191                         CvRTParams params=CvRTParams());
1192 #ifndef SWIG
1193     virtual bool train( const cv::Mat& _train_data, int _tflag,
1194                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1195                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1196                        const cv::Mat& _missing_mask=cv::Mat(),
1197                        CvRTParams params=CvRTParams());
1198 #endif
1199     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1200 protected:
1201     virtual bool grow_forest( const CvTermCriteria term_crit );
1202 };
1203
1204
1205 /****************************************************************************************\
1206 *                                   Boosted tree classifier                              *
1207 \****************************************************************************************/
1208
1209 struct CV_EXPORTS CvBoostParams : public CvDTreeParams
1210 {
1211     int boost_type;
1212     int weak_count;
1213     int split_criteria;
1214     double weight_trim_rate;
1215
1216     CvBoostParams();
1217     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1218                    int max_depth, bool use_surrogates, const float* priors );
1219 };
1220
1221
1222 class CvBoost;
1223
1224 class CV_EXPORTS CvBoostTree: public CvDTree
1225 {
1226 public:
1227     CvBoostTree();
1228     virtual ~CvBoostTree();
1229
1230     virtual bool train( CvDTreeTrainData* _train_data,
1231                         const CvMat* subsample_idx, CvBoost* ensemble );
1232
1233     virtual void scale( double s );
1234     virtual void read( CvFileStorage* fs, CvFileNode* node,
1235                        CvBoost* ensemble, CvDTreeTrainData* _data );
1236     virtual void clear();
1237
1238     /* dummy methods to avoid warnings: BEGIN */
1239     virtual bool train( const CvMat* _train_data, int _tflag,
1240                         const CvMat* _responses, const CvMat* _var_idx=0,
1241                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1242                         const CvMat* _missing_mask=0,
1243                         CvDTreeParams params=CvDTreeParams() );
1244     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
1245
1246     virtual void read( CvFileStorage* fs, CvFileNode* node );
1247     virtual void read( CvFileStorage* fs, CvFileNode* node,
1248                        CvDTreeTrainData* data );
1249     /* dummy methods to avoid warnings: END */
1250
1251 protected:
1252
1253     virtual void try_split_node( CvDTreeNode* n );
1254     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
1255     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
1256     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 
1257         float init_quality = 0, CvDTreeSplit* _split = 0 );
1258     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1259         float init_quality = 0, CvDTreeSplit* _split = 0 );
1260     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 
1261         float init_quality = 0, CvDTreeSplit* _split = 0 );
1262     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 
1263         float init_quality = 0, CvDTreeSplit* _split = 0 );
1264     virtual void calc_node_value( CvDTreeNode* n );
1265     virtual double calc_node_dir( CvDTreeNode* n );
1266
1267     CvBoost* ensemble;
1268 };
1269
1270
1271 class CV_EXPORTS CvBoost : public CvStatModel
1272 {
1273 public:
1274     // Boosting type
1275     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1276
1277     // Splitting criteria
1278     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1279
1280     CvBoost();
1281     virtual ~CvBoost();
1282
1283     CvBoost( const CvMat* _train_data, int _tflag,
1284              const CvMat* _responses, const CvMat* _var_idx=0,
1285              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1286              const CvMat* _missing_mask=0,
1287              CvBoostParams params=CvBoostParams() );
1288     
1289     virtual bool train( const CvMat* _train_data, int _tflag,
1290              const CvMat* _responses, const CvMat* _var_idx=0,
1291              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
1292              const CvMat* _missing_mask=0,
1293              CvBoostParams params=CvBoostParams(),
1294              bool update=false );
1295     
1296     virtual bool train( CvMLData* data,
1297              CvBoostParams params=CvBoostParams(),
1298              bool update=false );
1299
1300     virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
1301                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1302                            bool raw_mode=false, bool return_sum=false ) const;
1303
1304 #ifndef SWIG
1305     CvBoost( const cv::Mat& _train_data, int _tflag,
1306             const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1307             const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1308             const cv::Mat& _missing_mask=cv::Mat(),
1309             CvBoostParams params=CvBoostParams() );
1310     
1311     virtual bool train( const cv::Mat& _train_data, int _tflag,
1312                        const cv::Mat& _responses, const cv::Mat& _var_idx=cv::Mat(),
1313                        const cv::Mat& _sample_idx=cv::Mat(), const cv::Mat& _var_type=cv::Mat(),
1314                        const cv::Mat& _missing_mask=cv::Mat(),
1315                        CvBoostParams params=CvBoostParams(),
1316                        bool update=false );
1317     
1318     virtual float predict( const cv::Mat& _sample, const cv::Mat& _missing=cv::Mat(),
1319                           cv::Mat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1320                           bool raw_mode=false, bool return_sum=false ) const;
1321 #endif
1322     
1323     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1324
1325     virtual void prune( CvSlice slice );
1326
1327     virtual void clear();
1328
1329     virtual void write( CvFileStorage* storage, const char* name ) const;
1330     virtual void read( CvFileStorage* storage, CvFileNode* node );
1331     virtual const CvMat* get_active_vars(bool absolute_idx=true);
1332
1333     CvSeq* get_weak_predictors();
1334
1335     CvMat* get_weights();
1336     CvMat* get_subtree_weights();
1337     CvMat* get_weak_response();
1338     const CvBoostParams& get_params() const;
1339     const CvDTreeTrainData* get_data() const;
1340
1341 protected:
1342
1343     virtual bool set_params( const CvBoostParams& _params );
1344     virtual void update_weights( CvBoostTree* tree );
1345     virtual void trim_weights();
1346     virtual void write_params( CvFileStorage* fs ) const;
1347     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1348
1349     CvDTreeTrainData* data;
1350     CvBoostParams params;
1351     CvSeq* weak;
1352
1353     CvMat* active_vars;
1354     CvMat* active_vars_abs;
1355     bool have_active_cat_vars;
1356
1357     CvMat* orig_response;
1358     CvMat* sum_response;
1359     CvMat* weak_eval;
1360     CvMat* subsample_mask;
1361     CvMat* weights;
1362     CvMat* subtree_weights;
1363     bool have_subsample;
1364 };
1365
1366
1367 /****************************************************************************************\
1368 *                              Artificial Neural Networks (ANN)                          *
1369 \****************************************************************************************/
1370
1371 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1372
1373 struct CV_EXPORTS CvANN_MLP_TrainParams
1374 {
1375     CvANN_MLP_TrainParams();
1376     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1377                            double param1, double param2=0 );
1378     ~CvANN_MLP_TrainParams();
1379
1380     enum { BACKPROP=0, RPROP=1 };
1381
1382     CvTermCriteria term_crit;
1383     int train_method;
1384
1385     // backpropagation parameters
1386     double bp_dw_scale, bp_moment_scale;
1387
1388     // rprop parameters
1389     double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1390 };
1391
1392
1393 class CV_EXPORTS CvANN_MLP : public CvStatModel
1394 {
1395 public:
1396     CvANN_MLP();
1397     CvANN_MLP( const CvMat* _layer_sizes,
1398                int _activ_func=SIGMOID_SYM,
1399                double _f_param1=0, double _f_param2=0 );
1400
1401     virtual ~CvANN_MLP();
1402
1403     virtual void create( const CvMat* _layer_sizes,
1404                          int _activ_func=SIGMOID_SYM,
1405                          double _f_param1=0, double _f_param2=0 );
1406     
1407     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
1408                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
1409                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1410                        int flags=0 );
1411     virtual float predict( const CvMat* _inputs, CvMat* _outputs ) const;
1412     
1413 #ifndef SWIG
1414     CvANN_MLP( const cv::Mat& _layer_sizes,
1415               int _activ_func=SIGMOID_SYM,
1416               double _f_param1=0, double _f_param2=0 );
1417     
1418     virtual void create( const cv::Mat& _layer_sizes,
1419                         int _activ_func=SIGMOID_SYM,
1420                         double _f_param1=0, double _f_param2=0 );    
1421     
1422     virtual int train( const cv::Mat& _inputs, const cv::Mat& _outputs,
1423                       const cv::Mat& _sample_weights, const cv::Mat& _sample_idx=cv::Mat(),
1424                       CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1425                       int flags=0 );    
1426     
1427     virtual float predict( const cv::Mat& _inputs, cv::Mat& _outputs ) const;
1428 #endif
1429     
1430     virtual void clear();
1431
1432     // possible activation functions
1433     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1434
1435     // available training flags
1436     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1437
1438     virtual void read( CvFileStorage* fs, CvFileNode* node );
1439     virtual void write( CvFileStorage* storage, const char* name ) const;
1440
1441     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1442     const CvMat* get_layer_sizes() { return layer_sizes; }
1443     double* get_weights(int layer)
1444     {
1445         return layer_sizes && weights &&
1446             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1447     }
1448
1449 protected:
1450
1451     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1452             const CvMat* _sample_weights, const CvMat* _sample_idx,
1453             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1454
1455     // sequential random backpropagation
1456     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1457
1458     // RPROP algorithm
1459     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1460
1461     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1462     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1463     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1464                                  double _f_param1=0, double _f_param2=0 );
1465     virtual void init_weights();
1466     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1467     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1468     virtual void calc_input_scale( const CvVectors* vecs, int flags );
1469     virtual void calc_output_scale( const CvVectors* vecs, int flags );
1470
1471     virtual void write_params( CvFileStorage* fs ) const;
1472     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1473
1474     CvMat* layer_sizes;
1475     CvMat* wbuf;
1476     CvMat* sample_weights;
1477     double** weights;
1478     double f_param1, f_param2;
1479     double min_val, max_val, min_val1, max_val1;
1480     int activ_func;
1481     int max_count, max_buf_sz;
1482     CvANN_MLP_TrainParams params;
1483     CvRNG rng;
1484 };
1485
1486 #if 0
1487 /****************************************************************************************\
1488 *                            Convolutional Neural Network                                *
1489 \****************************************************************************************/
1490 typedef struct CvCNNLayer CvCNNLayer;
1491 typedef struct CvCNNetwork CvCNNetwork;
1492
1493 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
1494 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
1495 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3
1496
1497 #define CV_CNN_GRAD_ESTIM_RANDOM        0
1498 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1
1499
1500 #define ICV_CNN_LAYER                0x55550000
1501 #define ICV_CNN_CONVOLUTION_LAYER    0x00001111
1502 #define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
1503 #define ICV_CNN_FULLCONNECT_LAYER    0x00003333
1504
1505 #define ICV_IS_CNN_LAYER( layer )                                          \
1506     ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
1507         == ICV_CNN_LAYER ))
1508
1509 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer )                              \
1510     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1511         & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )
1512
1513 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer )                              \
1514     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1515         & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )
1516
1517 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer )                              \
1518     ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
1519         & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )
1520
1521 typedef void (CV_CDECL *CvCNNLayerForward)
1522     ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
1523
1524 typedef void (CV_CDECL *CvCNNLayerBackward)
1525     ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
1526
1527 typedef void (CV_CDECL *CvCNNLayerRelease)
1528     (CvCNNLayer** layer);
1529
1530 typedef void (CV_CDECL *CvCNNetworkAddLayer)
1531     (CvCNNetwork* network, CvCNNLayer* layer);
1532
1533 typedef void (CV_CDECL *CvCNNetworkRelease)
1534     (CvCNNetwork** network);
1535
1536 #define CV_CNN_LAYER_FIELDS()           \
1537     /* Indicator of the layer's type */ \
1538     int flags;                          \
1539                                         \
1540     /* Number of input images */        \
1541     int n_input_planes;                 \
1542     /* Height of each input image */    \
1543     int input_height;                   \
1544     /* Width of each input image */     \
1545     int input_width;                    \
1546                                         \
1547     /* Number of output images */       \
1548     int n_output_planes;                \
1549     /* Height of each output image */   \
1550     int output_height;                  \
1551     /* Width of each output image */    \
1552     int output_width;                   \
1553                                         \
1554     /* Learning rate at the first iteration */                      \
1555     float init_learn_rate;                                          \
1556     /* Dynamics of learning rate decreasing */                      \
1557     int learn_rate_decrease_type;                                   \
1558     /* Trainable weights of the layer (including bias) */           \
1559     /* i-th row is a set of weights of the i-th output plane */     \
1560     CvMat* weights;                                                 \
1561                                                                     \
1562     CvCNNLayerForward  forward;                                     \
1563     CvCNNLayerBackward backward;                                    \
1564     CvCNNLayerRelease  release;                                     \
1565     /* Pointers to the previous and next layers in the network */   \
1566     CvCNNLayer* prev_layer;                                         \
1567     CvCNNLayer* next_layer
1568
1569 typedef struct CvCNNLayer
1570 {
1571     CV_CNN_LAYER_FIELDS();
1572 }CvCNNLayer;
1573
1574 typedef struct CvCNNConvolutionLayer
1575 {
1576     CV_CNN_LAYER_FIELDS();
1577     // Kernel size (height and width) for convolution.
1578     int K;
1579     // connections matrix, (i,j)-th element is 1 iff there is a connection between
1580     // i-th plane of the current layer and j-th plane of the previous layer;
1581     // (i,j)-th element is equal to 0 otherwise
1582     CvMat *connect_mask;
1583     // value of the learning rate for updating weights at the first iteration
1584 }CvCNNConvolutionLayer;
1585
1586 typedef struct CvCNNSubSamplingLayer
1587 {
1588     CV_CNN_LAYER_FIELDS();
1589     // ratio between the heights (or widths - ratios are supposed to be equal)
1590     // of the input and output planes
1591     int sub_samp_scale;
1592     // amplitude of sigmoid activation function
1593     float a;
1594     // scale parameter of sigmoid activation function
1595     float s;
1596     // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
1597     // - is the vector used in computing of the activation function in backward
1598     CvMat* exp2ssumWX;
1599     // (x1+x2+x3+x4), where x1,...x4 are some elements of X
1600     // - is the vector used in computing of the activation function in backward
1601     CvMat* sumX;
1602 }CvCNNSubSamplingLayer;
1603
1604 // Structure of the last layer.
1605 typedef struct CvCNNFullConnectLayer
1606 {
1607     CV_CNN_LAYER_FIELDS();
1608     // amplitude of sigmoid activation function
1609     float a;
1610     // scale parameter of sigmoid activation function
1611     float s;
1612     // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
1613     // activation function and it's derivative by the formulae
1614     // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
1615     // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
1616     CvMat* exp2ssumWX;
1617 }CvCNNFullConnectLayer;
1618
1619 typedef struct CvCNNetwork
1620 {
1621     int n_layers;
1622     CvCNNLayer* layers;
1623     CvCNNetworkAddLayer add_layer;
1624     CvCNNetworkRelease release;
1625 }CvCNNetwork;
1626
1627 typedef struct CvCNNStatModel
1628 {
1629     CV_STAT_MODEL_FIELDS();
1630     CvCNNetwork* network;
1631     // etalons are allocated as rows, the i-th etalon has label cls_labeles[i]
1632     CvMat* etalons;
1633     // classes labels
1634     CvMat* cls_labels;
1635 }CvCNNStatModel;
1636
1637 typedef struct CvCNNStatModelParams
1638 {
1639     CV_STAT_MODEL_PARAM_FIELDS();
1640     // network must be created by the functions cvCreateCNNetwork and <add_layer>
1641     CvCNNetwork* network;
1642     CvMat* etalons;
1643     // termination criteria
1644     int max_iter;
1645     int start_iter;
1646     int grad_estim_type;
1647 }CvCNNStatModelParams;
1648
1649 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
1650     int n_input_planes, int input_height, int input_width,
1651     int n_output_planes, int K,
1652     float init_learn_rate, int learn_rate_decrease_type,
1653     CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
1654
1655 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
1656     int n_input_planes, int input_height, int input_width,
1657     int sub_samp_scale, float a, float s,
1658     float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
1659
1660 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
1661     int n_inputs, int n_outputs, float a, float s,
1662     float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
1663
1664 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
1665
1666 CVAPI(CvStatModel*) cvTrainCNNClassifier(
1667             const CvMat* train_data, int tflag,
1668             const CvMat* responses,
1669             const CvStatModelParams* params,
1670             const CvMat* CV_DEFAULT(0),
1671             const CvMat* sample_idx CV_DEFAULT(0),
1672             const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
1673
1674 /****************************************************************************************\
1675 *                               Estimate classifiers algorithms                          *
1676 \****************************************************************************************/
1677 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
1678                     ( const CvStatModel* estimateModel );
1679
1680 typedef int (CV_CDECL *CvStatModelEstimateNextStep)
1681                     ( CvStatModel* estimateModel );
1682
1683 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
1684                     ( CvStatModel* estimateModel,
1685                 const CvStatModel* model,
1686                 const CvMat*       features,
1687                       int          sample_t_flag,
1688                 const CvMat*       responses );
1689
1690 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
1691                     ( CvStatModel* estimateModel,
1692                 const CvStatModel* model );
1693
1694 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
1695                     ( const CvStatModel* estimateModel,
1696                             float*       correlation );
1697
1698 typedef void (CV_CDECL *CvStatModelEstimateReset)
1699                     ( CvStatModel* estimateModel );
1700
1701 //-------------------------------- Cross-validation --------------------------------------
1702 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
1703     CV_STAT_MODEL_PARAM_FIELDS();                                 \
1704     int     k_fold;                                               \
1705     int     is_regression;                                        \
1706     CvRNG*  rng
1707
1708 typedef struct CvCrossValidationParams
1709 {
1710     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
1711 } CvCrossValidationParams;
1712
1713 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
1714     CvStatModelEstimateGetMat               getTrainIdxMat; \
1715     CvStatModelEstimateGetMat               getCheckIdxMat; \
1716     CvStatModelEstimateNextStep             nextStep;       \
1717     CvStatModelEstimateCheckClassifier      check;          \
1718     CvStatModelEstimateGetCurrentResult     getResult;      \
1719     CvStatModelEstimateReset                reset;          \
1720     int     is_regression;                                  \
1721     int     folds_all;                                      \
1722     int     samples_all;                                    \
1723     int*    sampleIdxAll;                                   \
1724     int*    folds;                                          \
1725     int     max_fold_size;                                  \
1726     int         current_fold;                               \
1727     int         is_checked;                                 \
1728     CvMat*      sampleIdxTrain;                             \
1729     CvMat*      sampleIdxEval;                              \
1730     CvMat*      predict_results;                            \
1731     int     correct_results;                                \
1732     int     all_results;                                    \
1733     double  sq_error;                                       \
1734     double  sum_correct;                                    \
1735     double  sum_predict;                                    \
1736     double  sum_cc;                                         \
1737     double  sum_pp;                                         \
1738     double  sum_cp
1739
1740 typedef struct CvCrossValidationModel
1741 {
1742     CV_STAT_MODEL_FIELDS();
1743     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
1744 } CvCrossValidationModel;
1745
1746 CVAPI(CvStatModel*)
1747 cvCreateCrossValidationEstimateModel
1748            ( int                samples_all,
1749        const CvStatModelParams* estimateParams CV_DEFAULT(0),
1750        const CvMat*             sampleIdx CV_DEFAULT(0) );
1751
1752 CVAPI(float)
1753 cvCrossValidation( const CvMat*             trueData,
1754                          int                tflag,
1755                    const CvMat*             trueClasses,
1756                          CvStatModel*     (*createClassifier)( const CvMat*,
1757                                                                      int,
1758                                                                const CvMat*,
1759                                                                const CvStatModelParams*,
1760                                                                const CvMat*,
1761                                                                const CvMat*,
1762                                                                const CvMat*,
1763                                                                const CvMat* ),
1764                    const CvStatModelParams* estimateParams CV_DEFAULT(0),
1765                    const CvStatModelParams* trainParams CV_DEFAULT(0),
1766                    const CvMat*             compIdx CV_DEFAULT(0),
1767                    const CvMat*             sampleIdx CV_DEFAULT(0),
1768                          CvStatModel**      pCrValModel CV_DEFAULT(0),
1769                    const CvMat*             typeMask CV_DEFAULT(0),
1770                    const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
1771 #endif
1772
1773 /****************************************************************************************\
1774 *                           Auxilary functions declarations                              *
1775 \****************************************************************************************/
1776
1777 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
1778    average row vector, <cov> - symmetric covariation matrix */
1779 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1780                            CvRNG* rng CV_DEFAULT(0) );
1781
1782 /* Generates sample from gaussian mixture distribution */
1783 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1784                                CvMat* covs[],
1785                                float weights[],
1786                                int clsnum,
1787                                CvMat* sample,
1788                                CvMat* sampClasses CV_DEFAULT(0) );
1789
1790 #define CV_TS_CONCENTRIC_SPHERES 0
1791
1792 /* creates test set */
1793 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1794                  int num_samples,
1795                  int num_features,
1796                  CvMat** responses,
1797                  int num_classes, ... );
1798
1799
1800 #endif
1801
1802 /****************************************************************************************\
1803 *                                      Data                                             *
1804 \****************************************************************************************/
1805
1806 #include <map>
1807 #include <string>
1808 #include <iostream>
1809
1810 #define CV_COUNT     0
1811 #define CV_PORTION   1
1812
1813 struct CV_EXPORTS CvTrainTestSplit
1814 {
1815 public:
1816     CvTrainTestSplit();
1817     CvTrainTestSplit( int _train_sample_count, bool _mix = true);
1818     CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
1819
1820     union
1821     {
1822         int count;
1823         float portion;
1824     } train_sample_part;
1825     int train_sample_part_mode;
1826
1827     union
1828     {
1829         int *count;
1830         float *portion;
1831     } *class_part;
1832     int class_part_mode;
1833
1834     bool mix;    
1835 };
1836
1837 class CV_EXPORTS CvMLData
1838 {
1839 public:
1840     CvMLData();
1841     virtual ~CvMLData();
1842
1843     // returns:
1844     // 0 - OK  
1845     // 1 - file can not be opened or is not correct
1846     int read_csv(const char* filename);
1847
1848     const CvMat* get_values(){ return values; };
1849
1850     const CvMat* get_responses();
1851
1852     const CvMat* get_missing(){ return missing; };
1853
1854     void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
1855                                       // if idx < 0 there will be no response
1856     int get_response_idx() { return response_idx; }
1857
1858     const CvMat* get_train_sample_idx() { return train_sample_idx; };
1859     const CvMat* get_test_sample_idx() { return test_sample_idx; };
1860     void mix_train_and_test_idx();
1861     void set_train_test_split( const CvTrainTestSplit * spl);
1862     
1863     const CvMat* get_var_idx();
1864     void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
1865
1866     const CvMat* get_var_types();
1867     int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
1868     // following 2 methods enable to change vars type
1869     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
1870     // with numerical labels; in the other cases var types are correctly determined automatically
1871     void set_var_types( const char* str );  // str examples:
1872                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1873                                             // "cat", "ord" (all vars are categorical/ordered)
1874     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }    
1875  
1876     void set_delimiter( char ch );
1877     char get_delimiter() { return delimiter; };
1878
1879     void set_miss_ch( char ch );
1880     char get_miss_ch() { return miss_ch; };
1881     
1882 protected:
1883     virtual void clear();
1884
1885     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
1886     void free_train_test_idx();
1887     
1888     char delimiter;
1889     char miss_ch;
1890     //char flt_separator;
1891
1892     CvMat* values;
1893     CvMat* missing;
1894     CvMat* var_types;
1895     CvMat* var_idx_mask;
1896
1897     CvMat* response_out; // header
1898     CvMat* var_idx_out; // mat
1899     CvMat* var_types_out; // mat
1900
1901     int response_idx;
1902
1903     int train_sample_count;
1904     bool mix;
1905    
1906     int total_class_count;
1907     std::map<std::string, int> *class_map;
1908
1909     CvMat* train_sample_idx;
1910     CvMat* test_sample_idx;
1911     int* sample_idx; // data of train_sample_idx and test_sample_idx
1912
1913     CvRNG rng;
1914 };
1915
1916
1917 namespace cv
1918 {
1919     
1920 typedef CvStatModel StatModel;
1921 typedef CvParamGrid ParamGrid;
1922 typedef CvNormalBayesClassifier NormalBayesClassifier;
1923 typedef CvKNearest KNearest;
1924 typedef CvSVMParams SVMParams;
1925 typedef CvSVMKernel SVMKernel;
1926 typedef CvSVMSolver SVMSolver;
1927 typedef CvSVM SVM;
1928 typedef CvEMParams EMParams;
1929 typedef CvEM ExpectationMaximization;
1930 typedef CvDTreeParams DTreeParams;
1931 typedef CvMLData TrainData;
1932 typedef CvDTree DecisionTree;
1933 typedef CvForestTree ForestTree;
1934 typedef CvRTParams RandomTreeParams;
1935 typedef CvRTrees RandomTrees;
1936 typedef CvERTreeTrainData ERTreeTRainData;
1937 typedef CvForestERTree ERTree;
1938 typedef CvERTrees ERTrees;
1939 typedef CvBoostParams BoostParams;
1940 typedef CvBoostTree BoostTree;
1941 typedef CvBoost Boost;
1942 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
1943 typedef CvANN_MLP NeuralNet_MLP;
1944     
1945 }
1946
1947 #endif
1948 /* End of file. */