]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/apps/traincascade/cascadeclassifier.cpp
updated traincascade
[opencv.git] / opencv / apps / traincascade / cascadeclassifier.cpp
1 #include "cascadeclassifier.h"
2 #include <queue>
3
4 using namespace std;
5
6 static const char* stageTypes[] = { CC_BOOST };
7 static const char* featureTypes[] = { CC_HAAR, CC_LBP };
8
9 CvCascadeParams::CvCascadeParams() : stageType( defaultStageType ), 
10     featureType( defaultFeatureType ), winSize( cvSize(24, 24) )
11
12     name = CC_CASCADE_PARAMS; 
13 }
14 CvCascadeParams::CvCascadeParams( int _stageType, int _featureType ) : stageType( _stageType ),
15     featureType( _featureType ), winSize( cvSize(24, 24) )
16
17     name = CC_CASCADE_PARAMS;
18 }
19
20 //---------------------------- CascadeParams --------------------------------------
21
22 void CvCascadeParams::write( FileStorage &fs ) const
23 {
24     String stageTypeStr = stageType == BOOST ? CC_BOOST : String();
25     CV_Assert( !stageTypeStr.empty() );
26     fs << CC_STAGE_TYPE << stageTypeStr;
27     String featureTypeStr = featureType == CvFeatureParams::HAAR ? CC_HAAR :
28                             featureType == CvFeatureParams::LBP ? CC_LBP : 0;
29     CV_Assert( !stageTypeStr.empty() );
30     fs << CC_FEATURE_TYPE << featureTypeStr;
31     fs << CC_HEIGHT << winSize.height;
32     fs << CC_WIDTH << winSize.width;
33 }
34
35 bool CvCascadeParams::read( const FileNode &node )
36 {
37     if ( node.empty() )
38         return false;
39     String stageTypeStr, featureTypeStr;
40     FileNode rnode = node[CC_STAGE_TYPE];
41     if ( !rnode.isString() )
42         return false;
43     rnode >> stageTypeStr;
44     stageType = !stageTypeStr.compare( CC_BOOST ) ? BOOST : -1;
45     if (stageType == -1)
46         return false;
47     rnode = node[CC_FEATURE_TYPE];
48     if ( !rnode.isString() )
49         return false;
50     rnode >> featureTypeStr;
51     featureType = !featureTypeStr.compare( CC_HAAR ) ? CvFeatureParams::HAAR :
52                   !featureTypeStr.compare( CC_LBP ) ? CvFeatureParams::LBP : -1;
53     if (featureType == -1)
54         return false;
55     node[CC_HEIGHT] >> winSize.height;
56     node[CC_WIDTH] >> winSize.width;
57     return winSize.height > 0 && winSize.width > 0;
58 }
59
60 void CvCascadeParams::printDefaults() const
61 {
62     CvParams::printDefaults();
63     cout << "  [-stageType <";
64     for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
65     {
66         cout << (i ? " | " : "") << stageTypes[i];
67         if ( i == defaultStageType )
68             cout << "(default)";
69     }
70     cout << ">]" << endl;
71
72     cout << "  [-featureType <{";
73     for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
74     {
75         cout << (i ? ", " : "") << featureTypes[i];
76         if ( i == defaultStageType )
77             cout << "(default)";
78     }
79     cout << "}>]" << endl;
80     cout << "  [-w <sampleWidth = " << winSize.width << ">]" << endl;
81     cout << "  [-h <sampleHeight = " << winSize.height << ">]" << endl;
82 }
83
84 void CvCascadeParams::printAttrs() const
85 {
86     cout << "stageType: " << stageTypes[stageType] << endl;
87     cout << "featureType: " << featureTypes[featureType] << endl;
88     cout << "sampleWidth: " << winSize.width << endl;
89     cout << "sampleHeight: " << winSize.height << endl;
90 }
91
92 bool CvCascadeParams::scanAttr( const String prmName, const String val )
93 {
94     bool res = true;
95     if( !prmName.compare( "-stageType" ) )
96     {
97         for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
98             if( !val.compare( stageTypes[i] ) )
99                 stageType = i;
100     }
101     else if( !prmName.compare( "-featureType" ) )
102     {
103         for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
104             if( !val.compare( featureTypes[i] ) )
105                 featureType = i;
106     }
107     else if( !prmName.compare( "-w" ) )
108     {
109         winSize.width = atoi( val.c_str() );
110     }
111     else if( !prmName.compare( "-h" ) )
112     {
113         winSize.height = atoi( val.c_str() );
114     }
115     else
116         res = false;
117     return res;
118 }
119
120 //---------------------------- CascadeClassifier --------------------------------------
121
122 bool CvCascadeClassifier::train( const String _cascadeDirName,
123                                 const String _posFilename,
124                                 const String _negFilename, 
125                                 int _numPos, int _numNeg, 
126                                 int _numPrecalcVal, int _numPrecalcIdx,
127                                 int _numStages,
128                                 const CvCascadeParams& _cascadeParams,
129                                 const CvFeatureParams& _featureParams,
130                                 const CvCascadeBoostParams& _stageParams,
131                                 bool baseFormatSave )
132 {   
133     if( _cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty() )
134         CV_Error( CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL" );
135
136     String dirName;
137     if ( _cascadeDirName.find('/') )
138         dirName = _cascadeDirName + '/';
139     else
140         dirName = _cascadeDirName + '\\';
141
142     numPos = _numPos;
143     numNeg = _numNeg;
144     numStages = _numStages;
145     imgReader.create( _posFilename, _negFilename, cascadeParams.winSize );
146     if ( !load( dirName ) )
147     {
148         cascadeParams = _cascadeParams;
149         featureParams = CvFeatureParams::create(cascadeParams.featureType);
150         featureParams->init(_featureParams);
151         stageParams = new CvCascadeBoostParams;
152         *stageParams = _stageParams;
153         featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
154         featureEvaluator->init( (CvFeatureParams*)featureParams, numPos + numNeg, cascadeParams.winSize );
155         stageClassifiers.reserve( numStages );
156     }
157     cout << "PARAMETERS:" << endl;
158     cout << "cascadeDirName: " << _cascadeDirName << endl;
159     cout << "vecFileName: " << _posFilename << endl;
160     cout << "bgFileName: " << _negFilename << endl;
161     cout << "numPos: " << _numPos << endl;
162     cout << "numNeg: " << _numNeg << endl;
163     cout << "numStages: " << numStages << endl;
164     cout << "numPrecalcValues : " << _numPrecalcVal << endl;
165     cout << "numPrecalcIndices : " << _numPrecalcIdx << endl;
166     cascadeParams.printAttrs();
167     stageParams->printAttrs();
168     featureParams->printAttrs();
169
170     int startNumStages = (int)stageClassifiers.size();
171     if ( startNumStages > 1 )
172         cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
173     else if ( startNumStages == 1)
174         cout << endl << "Stage 0 is loaded" << endl;
175     
176     double requiredLeafFARate = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) /
177                                 (double)stageParams->max_depth;
178     double tempLeafFARate;
179     
180     for( int i = startNumStages; i < numStages; i++ )
181     {
182         cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
183         cout << "<BEGIN" << endl;
184
185         if ( !updateTrainingSet( tempLeafFARate ) ) 
186         {
187             cout << "Train dataset for temp stage can not be filled."
188                 "Branch training terminated." << endl;
189             break;
190         }
191         if( tempLeafFARate <= requiredLeafFARate )
192         {
193             cout << "Required leaf false alarm rate achieved. "
194                  "Branch training terminated." << endl;
195             break;
196         }
197
198         CvCascadeBoost* tempStage = new CvCascadeBoost;
199         tempStage->train( (CvFeatureEvaluator*)featureEvaluator,
200                            curNumSamples, _numPrecalcVal, _numPrecalcIdx,
201                           *((CvCascadeBoostParams*)stageParams) );
202         stageClassifiers.push_back( tempStage );
203
204         cout << "END>" << endl;
205         
206         // save params
207         String filename;
208         if ( i == 0) 
209         {
210             filename = dirName + CC_PARAMS_FILENAME;
211             FileStorage fs( filename, FileStorage::WRITE);
212             if ( !fs.isOpened() )
213                 return false;
214             fs << FileStorage::getDefaultObjectName(filename) << "{";
215             writeParams( fs );
216             fs << "}";
217         }
218         // save temp stage
219         char buf[10];
220         sprintf(buf, "%s%d", "stage", i );
221         filename = dirName + buf + ".xml";
222         FileStorage fs( filename, FileStorage::WRITE );
223         if ( !fs.isOpened() )
224             return false;
225         fs << FileStorage::getDefaultObjectName(filename) << "{";
226         tempStage->write( fs, Mat() );
227         fs << "}";
228     }
229     save( dirName + CC_CASCADE_FILENAME, baseFormatSave );
230     return true;
231 }
232
233 int CvCascadeClassifier::predict( int sampleIdx )
234 {
235     CV_DbgAssert( sampleIdx < numPos + numNeg );
236     for (Vector<Ptr<CvCascadeBoost>>::iterator it = stageClassifiers.begin();
237         it != stageClassifiers.end(); it++ )
238     {
239         if ( (*it)->predict( sampleIdx ) == 0.f )
240             return 0;
241     }
242     return 1;
243 }
244
245 bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio)
246 {
247     int64 posConsumed = 0, negConsumed = 0;
248     imgReader.restart();
249     int posCount = fillPassedSamles( 0, numPos, true, posConsumed );
250     if( !posCount )
251         return false;
252     cout << "POS count : consumed   " << posCount << " : " << (int)posConsumed << endl;
253
254     int negCount = fillPassedSamles( numPos, numNeg, false, negConsumed );
255     if ( !negCount )
256         return false;
257     curNumSamples = posCount + negCount;
258     acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );
259     cout << "NEG count : acceptanceRatio    " << negCount << " : " << acceptanceRatio << endl;
260     return true;
261 }
262
263 int CvCascadeClassifier::fillPassedSamles( int first, int count, bool isPositive, int64& consumed )
264 {
265     int getcount = 0;
266     Mat img(cascadeParams.winSize, CV_8UC1);
267     for( int i = first; i < first + count; i++ )
268     {
269         for( ; ; )
270         {
271             bool isGetImg = isPositive ? imgReader.getPos( img ) :
272                                            imgReader.getNeg( img );
273             if( !isGetImg ) 
274                 return getcount;
275             consumed++;
276
277             featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
278             if( predict( i ) == 1.0F )
279             {
280                 getcount++;
281                 break;
282             }
283         }
284     }
285     return getcount;
286 }
287
288 void CvCascadeClassifier::writeParams( FileStorage &fs ) const
289 {
290     cascadeParams.write( fs );
291     fs << CC_STAGE_PARAMS << "{"; stageParams->write( fs ); fs << "}";
292     fs << CC_FEATURE_PARAMS << "{"; featureParams->write( fs ); fs << "}";
293 }
294
295 void CvCascadeClassifier::writeFeatures( FileStorage &fs, const Mat& featureMap ) const
296 {
297     ((CvFeatureEvaluator*)((Ptr<CvFeatureEvaluator>)featureEvaluator))->writeFeatures( fs, featureMap ); 
298 }
299
300 void CvCascadeClassifier::writeStages( FileStorage &fs, const Mat& featureMap ) const
301 {
302     //char cmnt[30];
303     //int i = 0;
304     fs << CC_STAGES << "["; 
305     for( Vector<Ptr<CvCascadeBoost>>::const_iterator it = stageClassifiers.begin();
306         it != stageClassifiers.end(); it++/*, i++*/ )
307     {
308         /*sprintf( cmnt, "stage %d", i );
309         CV_CALL( cvWriteComment( fs, cmnt, 0 ) );*/
310         fs << "{";
311         ((CvCascadeBoost*)((Ptr<CvCascadeBoost>)*it))->write( fs, featureMap );
312         fs << "}";
313     }
314     fs << "]";
315 }
316
317 bool CvCascadeClassifier::readParams( const FileNode &node )
318 {
319     if ( !node.isMap() || !cascadeParams.read( node ) )
320         return false;
321     
322     stageParams = new CvCascadeBoostParams;
323     FileNode rnode = node[CC_STAGE_PARAMS];
324     if ( !stageParams->read( rnode ) )
325         return false;
326     
327     featureParams = CvFeatureParams::create(cascadeParams.featureType);
328     rnode = node[CC_FEATURE_PARAMS];
329     if ( !featureParams->read( rnode ) )
330         return false;
331     return true;    
332 }
333
334 bool CvCascadeClassifier::readStages( const FileNode &node)
335 {
336     FileNode rnode = node[CC_STAGES];
337     if (!rnode.empty() || !rnode.isSeq())
338         return false;
339     stageClassifiers.reserve(numStages);
340     FileNodeIterator it = rnode.begin();
341     for( int i = 0; i < min( (int)rnode.size(), numStages ); i++, it++ )
342     {
343         CvCascadeBoost* tempStage = new CvCascadeBoost;
344         if ( !tempStage->read( *it, (CvFeatureEvaluator *)featureEvaluator, *((CvCascadeBoostParams*)stageParams) ) )
345         {
346             delete tempStage;
347             return false;
348         }
349         stageClassifiers.push_back(tempStage);
350     }
351     return true;
352 }
353
354 // For old Haar Classifier file saving
355 #define ICV_HAAR_SIZE_NAME            "size"
356 #define ICV_HAAR_STAGES_NAME          "stages"
357 #define ICV_HAAR_TREES_NAME             "trees"
358 #define ICV_HAAR_FEATURE_NAME             "feature"
359 #define ICV_HAAR_RECTS_NAME                 "rects"
360 #define ICV_HAAR_TILTED_NAME                "tilted"
361 #define ICV_HAAR_THRESHOLD_NAME           "threshold"
362 #define ICV_HAAR_LEFT_NODE_NAME           "left_node"
363 #define ICV_HAAR_LEFT_VAL_NAME            "left_val"
364 #define ICV_HAAR_RIGHT_NODE_NAME          "right_node"
365 #define ICV_HAAR_RIGHT_VAL_NAME           "right_val"
366 #define ICV_HAAR_STAGE_THRESHOLD_NAME   "stage_threshold"
367 #define ICV_HAAR_PARENT_NAME            "parent"
368 #define ICV_HAAR_NEXT_NAME              "next"
369
370 void CvCascadeClassifier::save( const String filename, bool baseFormat )
371 {
372     FileStorage fs( filename, FileStorage::WRITE );
373
374     if ( !fs.isOpened() )
375         return;
376
377     fs << FileStorage::getDefaultObjectName(filename) << "{";
378     if ( !baseFormat )
379     {
380         Mat featureMap; 
381         getUsedFeaturesIdxMap( featureMap );
382         writeParams( fs );
383         fs << CC_STAGE_NUM << (int)stageClassifiers.size();
384         writeStages( fs, featureMap );
385         writeFeatures( fs, featureMap );
386     }
387     else
388     {
389         //char buf[256];
390         CvSeq* weak;
391         if ( cascadeParams.featureType != CvFeatureParams::HAAR )
392             CV_Error( CV_StsBadFunc, "old file format is used for Haar-like features only");
393         fs << ICV_HAAR_SIZE_NAME << "[:" << cascadeParams.winSize.width << 
394             cascadeParams.winSize.height << "]";
395         fs << ICV_HAAR_STAGES_NAME << "[";
396         for( size_t si = 0; si < stageClassifiers.size(); si++ )
397         {
398             fs << "{"; //stage
399             /*sprintf( buf, "stage %d", si );
400             CV_CALL( cvWriteComment( fs, buf, 1 ) );*/
401             weak = stageClassifiers[si]->get_weak_predictors();
402             fs << ICV_HAAR_TREES_NAME << "[";
403             for( int wi = 0; wi < weak->total; wi++ )
404             {
405                 int inner_node_idx = -1, total_inner_node_idx = -1;
406                 queue<const CvDTreeNode*> inner_nodes_queue;
407                 CvCascadeBoostTree* tree = *((CvCascadeBoostTree**) cvGetSeqElem( weak, wi ));
408                 
409                 fs << "[";
410                 /*sprintf( buf, "tree %d", wi );
411                 CV_CALL( cvWriteComment( fs, buf, 1 ) );*/
412
413                 const CvDTreeNode* tempNode;
414                 
415                 inner_nodes_queue.push( tree->get_root() );
416                 total_inner_node_idx++;
417                 
418                 while (!inner_nodes_queue.empty())
419                 {
420                     tempNode = inner_nodes_queue.front();
421                     inner_node_idx++;
422
423                     fs << "{";
424                     fs << ICV_HAAR_FEATURE_NAME << "{";
425                     ((CvHaarEvaluator*)((CvFeatureEvaluator*)featureEvaluator))->writeFeature( fs, tempNode->split->var_idx );
426                     fs << "}";
427
428                     fs << ICV_HAAR_THRESHOLD_NAME << tempNode->split->ord.c;
429
430                     if( tempNode->left->left || tempNode->left->right )
431                     {
432                         inner_nodes_queue.push( tempNode->left );
433                         total_inner_node_idx++;
434                         fs << ICV_HAAR_LEFT_NODE_NAME << total_inner_node_idx;
435                     }
436                     else
437                         fs << ICV_HAAR_LEFT_VAL_NAME << tempNode->left->value;
438
439                     if( tempNode->right->left || tempNode->right->right )
440                     {
441                         inner_nodes_queue.push( tempNode->right );
442                         total_inner_node_idx++;
443                         fs << ICV_HAAR_RIGHT_NODE_NAME << total_inner_node_idx;
444                     }
445                     else
446                         fs << ICV_HAAR_RIGHT_VAL_NAME << tempNode->right->value;
447                     fs << "}"; // ICV_HAAR_FEATURE_NAME
448                     inner_nodes_queue.pop();
449                 }
450                 fs << "]";
451             }
452             fs << "]"; //ICV_HAAR_TREES_NAME
453             fs << ICV_HAAR_STAGE_THRESHOLD_NAME << stageClassifiers[si]->getThreshold();
454             fs << ICV_HAAR_PARENT_NAME << (int)si-1 << ICV_HAAR_NEXT_NAME << -1;
455             fs << "}"; //stage
456         } /* for each stage */
457         fs << "]"; //ICV_HAAR_STAGES_NAME
458     }
459     fs << "}";
460 }
461
462 bool CvCascadeClassifier::load( const String cascadeDirName )
463 {
464     FileStorage fs( cascadeDirName + CC_PARAMS_FILENAME, FileStorage::READ );
465     if ( !fs.isOpened() )
466         return false;
467     FileNode node = fs.getFirstTopLevelNode();
468     if ( !readParams( node ) )
469         return false;
470     featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
471     featureEvaluator->init( ((CvFeatureParams*)featureParams), numPos + numNeg, cascadeParams.winSize );
472     fs.release();
473
474     char buf[10];
475     for ( int si = 0; si < numStages; si++ )
476     {
477         sprintf( buf, "%s%d", "stage", si);
478         fs.open( cascadeDirName + buf + ".xml", FileStorage::READ );
479         node = fs.getFirstTopLevelNode();
480         if ( !fs.isOpened() )
481             break;
482         CvCascadeBoost *tempStage = new CvCascadeBoost; 
483
484         if ( !tempStage->read( node, (CvFeatureEvaluator*)featureEvaluator, *((CvCascadeBoostParams*)stageParams )) )
485         {
486             delete tempStage;
487             fs.release();
488             break;
489         }
490         stageClassifiers.push_back(tempStage);
491     }
492     return true;
493 }
494
495 void CvCascadeClassifier::getUsedFeaturesIdxMap( Mat& featureMap )
496 {
497     featureMap.create( 1, featureEvaluator->getNumFeatures(), CV_32SC1 );
498     featureMap.setTo(Scalar(-1));
499     
500     for( Vector<Ptr<CvCascadeBoost>>::const_iterator it = stageClassifiers.begin();
501         it != stageClassifiers.end(); it++ )
502         ((CvCascadeBoost*)((Ptr<CvCascadeBoost>)(*it)))->markUsedFeaturesInMap( featureMap );
503     
504     for( int fi = 0, idx = 0; fi < featureEvaluator->getNumFeatures(); fi++ )
505         if ( featureMap.at<int>(0, fi) >= 0 )
506             featureMap.ptr<int>(0)[fi] = idx++;
507 }