]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/apps/traincascade/cascadeclassifier.h
87df6f502eaf3fefe8e02e705a3369d3e4a98c95
[opencv.git] / opencv / apps / traincascade / cascadeclassifier.h
1 #ifndef CASCADECLASSIFIER_H
2 #define CASCADECLASSIFIER_H
3
4 #include "features.h"
5 #include "haarfeatures.h"
6 #include "lbpfeatures.h"
7 #include "boost.h"
8 #include "cv.h"
9 #include "cxcore.h"
10
11 #define CC_CASCADE_FILENAME "cascade.xml"
12 #define CC_PARAMS_FILENAME "params.xml"
13
14 #define CC_CASCADE_PARAMS "cascadeParams"
15 #define CC_STAGE_TYPE "stageType"
16 #define CC_FEATURE_TYPE "featureType"
17 #define CC_HEIGHT "height"
18 #define CC_WIDTH  "width"
19
20 #define CC_STAGE_NUM    "stageNum"
21 #define CC_STAGES       "stages"
22 #define CC_STAGE_PARAMS "stageParams"
23
24 #define CC_BOOST            "BOOST"
25 #define CC_BOOST_TYPE       "boostType"
26 #define CC_DISCRETE_BOOST   "DAB"
27 #define CC_REAL_BOOST       "RAB"
28 #define CC_LOGIT_BOOST      "LB"
29 #define CC_GENTLE_BOOST     "GAB"
30 #define CC_MINHITRATE       "minHitRate"
31 #define CC_MAXFALSEALARM    "maxFalseAlarm"
32 #define CC_TRIM_RATE        "weightTrimRate"
33 #define CC_MAX_DEPTH        "maxDepth"
34 #define CC_WEAK_COUNT       "maxWeakCount"
35 #define CC_STAGE_THRESHOLD  "stageThreshold"
36 #define CC_WEAK_CLASSIFIERS "weakClassifiers"
37 #define CC_INTERNAL_NODES   "internalNodes"
38 #define CC_LEAF_VALUES      "leafValues"
39
40 #define CC_FEATURES "features"
41 #define CC_FEATURE_PARAMS "featureParams"
42 #define CC_MAX_CAT_COUNT  "maxCatCount"
43
44 #define CC_HAAR        "HAAR"
45 #define CC_MODE        "mode"
46 #define CC_MODE_BASIC  "BASIC"
47 #define CC_MODE_CORE   "CORE"
48 #define CC_MODE_ALL    "ALL"
49 #define CC_RECTS       "rects"
50 #define CC_TILTED      "tilted"
51
52 #define CC_LBP  "LBP"
53 #define CC_RECT "rect"
54
55 #define CV_NEW_SAVE_FORMAT 0
56 #define CV_OLD_SAVE_FORMAT 1
57
58 struct CvCascadeParams : CvParams
59 {
60     enum { BOOST = 0 };
61     enum { HAAR = 0, LBP = 1 };
62     
63     static const int defaultStageType = BOOST;
64     static const int defaultFeatureType = HAAR;
65
66     CvCascadeParams() : stageType( defaultStageType ), featureType( defaultFeatureType ), winSize( cvSize(24, 24) )
67     { name = CC_CASCADE_PARAMS; }
68     CvCascadeParams( int _stageType, int _featureType ) :
69         stageType( _stageType ), featureType( _featureType ), winSize( cvSize(24, 24) )
70     { name = CC_CASCADE_PARAMS; }
71     virtual ~CvCascadeParams() {}
72     void write( CvFileStorage* fs ) const;
73     bool read( CvFileStorage* fs, CvFileNode* node );
74
75     void printDefaults();
76     void printAttrs();    
77     bool scanAttr( const char* prmName, const char* val );
78
79     int stageType;
80     int featureType;
81     CvSize winSize;
82 };
83
84 class CvCascadeClassifier
85 {
86 public:
87     CvCascadeClassifier();
88     virtual ~CvCascadeClassifier();
89
90     virtual bool train( const char* _cascadeDirName,
91                         const char* _vecFileName,
92                         const char* _bgfileName, 
93                         int _numPos, int _numNeg, 
94                         int _numPrecalcVal, int _numPrecalcIdx,
95                         int _numStages,
96                         const CvCascadeParams& _cascadeParams,
97                         const CvFeatureParams& _featureParams,
98                         const CvCascadeBoostParams& _stageParams,
99                         bool baseFormatSave = false );
100     virtual int predict( int sampleIdx );
101     virtual bool save( const char* cascadeDirName, bool baseFormat = false );
102     const CvCascadeParams* getParams() const { return &cascadeParams; }
103 protected:     
104     void createStageParams();
105     void createCurStage();
106     void createFeatureParams();
107     void createCascadeData();
108
109     virtual void writeParams( CvFileStorage* fs ) const;
110     virtual void writeStages( CvFileStorage* fs, const CvMat* featureMap ) const;
111     virtual void writeFeatures( CvFileStorage* fs, const CvMat* featureMap ) const;
112     
113     virtual bool readParams( CvFileStorage* fs, CvFileNode* node );
114     virtual bool readStages( CvFileStorage* fs, CvFileNode* node );
115
116     virtual bool loadTempInfo( const char* cascadeDirName, const char* _vecFileName, const char* _bgFileName, 
117         int _numPos, int _numNeg );
118     void markFeaturesInMap( CvMat* featureMap);
119
120     int numStages, numCurStages;
121
122     CvCascadeData* cascadeData;    
123     CvCascadeBoost** stageClassifiers;
124
125     CvCascadeParams cascadeParams;
126     CvCascadeBoostParams* stageParams;
127     CvFeatureParams* featureParams;
128 };
129
130 #endif