]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/tests/cv/src/acascadeandhog.cpp
added progress points to cascade and hog tests
[opencv.git] / opencv / tests / cv / src / acascadeandhog.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "cvtest.h"
43 using namespace cv;
44 using namespace std;
45
46 //#define GET_STAT
47
48 #define DIST_E              "distE"
49 #define S_E                 "sE"
50 #define NO_PAIR_E           "noPairE"
51 //#define TOTAL_NO_PAIR_E     "totalNoPairE"
52
53 #define DETECTOR_NAMES      "detector_names"
54 #define DETECTORS                       "detectors"
55 #define IMAGE_FILENAMES     "image_filenames"
56 #define VALIDATION          "validation"
57 #define FILENAME                        "fn"
58
59 #define C_SCALE_CASCADE         "scale_cascade"
60
61 class CV_DetectorTest : public CvTest
62 {
63 public:
64     CV_DetectorTest( const char* test_name );
65     virtual int init( CvTS* system );
66 protected:
67     virtual int prepareData( FileStorage& fs );
68     virtual void run( int startFrom );
69     virtual string& getValidationFilename();
70         
71         virtual void readDetector( const FileNode& fn ) = 0;
72         virtual void writeDetector( FileStorage& fs, int di ) = 0;
73     int runTestCase( int detectorIdx, vector<vector<Rect> >& objects );
74     virtual int detectMultiScale( int di, const Mat& img, vector<Rect>& objects ) = 0;
75     int validate( int detectorIdx, vector<vector<Rect> >& objects );
76
77     struct
78     {
79         float dist;
80         float s;
81         float noPair;
82         //float totalNoPair;
83     } eps;
84     vector<string> detectorNames;
85     vector<string> detectorFilenames;
86     vector<string> imageFilenames;
87     vector<Mat> images;
88     string validationFilename;
89     FileStorage validationFS;
90 };
91
92 CV_DetectorTest::CV_DetectorTest( const char* test_name ) : CvTest( test_name, "detectMultiScale" )
93 {
94 }
95
96 int CV_DetectorTest::init(CvTS *system)
97 {
98     clear();
99     ts = system;
100     string dataPath = ts->get_data_path();
101     validationFS.open( dataPath + getValidationFilename(), FileStorage::READ );
102     return prepareData( validationFS );
103 }
104
105 string& CV_DetectorTest::getValidationFilename()
106 {
107     return validationFilename;
108 }
109
110 int CV_DetectorTest::prepareData( FileStorage& _fs )
111 {
112     if( !_fs.isOpened() )
113         test_case_count = -1;
114     else
115     {
116         FileNode fn = _fs.getFirstTopLevelNode();
117
118         fn[DIST_E] >> eps.dist;
119         fn[S_E] >> eps.s;
120         fn[NO_PAIR_E] >> eps.noPair;
121 //        fn[TOTAL_NO_PAIR_E] >> eps.totalNoPair;
122
123         // read detectors
124         if( fn[DETECTOR_NAMES].node->data.seq != 0 )
125         {
126             FileNodeIterator it = fn[DETECTOR_NAMES].begin();
127             for( ; it != fn[DETECTOR_NAMES].end(); )
128             {
129                 string name;
130                 it >> name; 
131                 detectorNames.push_back(name);
132                                 readDetector(fn[DETECTORS][name]);
133             }
134         }
135         test_case_count = (int)detectorNames.size();
136
137         // read images filenames and images
138         string dataPath = ts->get_data_path();
139         if( fn[IMAGE_FILENAMES].node->data.seq != 0 )
140         {
141             for( FileNodeIterator it = fn[IMAGE_FILENAMES].begin(); it != fn[IMAGE_FILENAMES].end(); )
142             {
143                 string filename;
144                 it >> filename;
145                 imageFilenames.push_back(filename);
146                 Mat img = imread( dataPath+filename, 1 );
147                 images.push_back( img );
148             }
149         }
150     }
151     return CvTS::OK;
152 }
153
154 void CV_DetectorTest::run( int start_from )
155 {
156     int code = CvTS::OK;
157     start_from = 0;
158
159 #ifdef GET_STAT
160     validationFS.release();
161     string filename = ts->get_data_path();
162     filename += getValidationFilename();
163     validationFS.open( filename, FileStorage::WRITE );
164     validationFS << FileStorage::getDefaultObjectName(validationFilename) << "{";
165
166     validationFS << DIST_E << eps.dist;
167     validationFS << S_E << eps.s;
168     validationFS << NO_PAIR_E << eps.noPair;
169 //    validationFS << TOTAL_NO_PAIR_E << eps.totalNoPair;
170
171     // write detector names
172     validationFS << DETECTOR_NAMES << "[";
173     vector<string>::const_iterator nit = detectorNames.begin();
174     for( ; nit != detectorNames.end(); ++nit )
175     {
176         validationFS << *nit;
177     }
178     validationFS << "]"; // DETECTOR_NAMES
179
180         // write detectors
181         validationFS << DETECTORS << "{";
182         assert( detectorNames.size() == detectorFilenames.size() );
183         nit = detectorNames.begin();
184         for( int di = 0; di < detectorNames.size(), nit != detectorNames.end(); ++nit, di++ )
185         {
186                 validationFS << *nit << "{";
187                 writeDetector( validationFS, di );
188                 validationFS << "}";
189         }
190         validationFS << "}";
191     
192     // write image filenames
193     validationFS << IMAGE_FILENAMES << "[";
194     vector<string>::const_iterator it = imageFilenames.begin();
195     for( int ii = 0; it != imageFilenames.end(); ++it, ii++ )
196     {
197         char buf[10];
198         sprintf( buf, "%s%d", "img_", ii );
199         cvWriteComment( validationFS.fs, buf, 0 );
200         validationFS << *it;
201     }
202     validationFS << "]"; // IMAGE_FILENAMES
203
204     validationFS << VALIDATION << "{";
205 #endif
206
207     int progress = 0;
208     for( int di = 0; di < test_case_count; di++ )
209     {
210         progress = update_progress( progress, di, test_case_count, 0 );
211 #ifdef GET_STAT
212         validationFS << detectorNames[di] << "{";
213 #endif
214         vector<vector<Rect> > objects;
215         int temp_code = runTestCase( di, objects );
216 #ifndef GET_STAT
217         if (temp_code == CvTS::OK)
218             temp_code = validate( di, objects );
219 #endif
220         if (temp_code != CvTS::OK)
221             code = temp_code;
222 #ifdef GET_STAT
223         validationFS << "}"; // detectorNames[di]
224 #endif
225     }
226
227 #ifdef GET_STAT
228     validationFS << "}"; // VALIDATION
229     validationFS << "}"; // getDefaultObjectName
230 #endif
231     if ( test_case_count <= 0 || imageFilenames.size() <= 0 )
232     {
233         ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
234         code = CvTS::FAIL_INVALID_TEST_DATA;
235     }
236     ts->set_failed_test_info( code );
237 }
238
239 int CV_DetectorTest::runTestCase( int detectorIdx, vector<vector<Rect> >& objects )
240 {
241     string dataPath = ts->get_data_path(), detectorFilename;
242     if( !detectorFilenames[detectorIdx].empty() )
243         detectorFilename = dataPath + detectorFilenames[detectorIdx];
244
245     for( int ii = 0; ii < (int)imageFilenames.size(); ++ii )
246     {
247         vector<Rect> imgObjects;
248         Mat image = images[ii];
249         if( image.empty() )
250         {
251             char msg[30];
252             sprintf( msg, "%s %d %s", "image ", ii, " can not be read" );
253             ts->printf( CvTS::LOG, msg );
254             return CvTS::FAIL_INVALID_TEST_DATA;
255         }
256         int code = detectMultiScale( detectorIdx, image, imgObjects );
257                 if( code != CvTS::OK )
258                         return code;
259
260         objects.push_back( imgObjects );
261
262 #ifdef GET_STAT
263         char buf[10];
264         sprintf( buf, "%s%d", "img_", ii );
265         string imageIdxStr = buf;
266         validationFS << imageIdxStr << "[:";
267         for( vector<Rect>::const_iterator it = imgObjects.begin();
268                 it != imgObjects.end(); ++it )
269         {
270             validationFS << it->x << it->y << it->width << it->height;
271         }
272         validationFS << "]"; // imageIdxStr
273 #endif
274     }
275     return CvTS::OK;
276 }
277
278
279 bool isZero( uchar i ) {return i == 0;}
280
281 int CV_DetectorTest::validate( int detectorIdx, vector<vector<Rect> >& objects )
282 {
283     assert( imageFilenames.size() == objects.size() );
284     int imageIdx = 0;
285     int totalNoPair = 0, totalValRectCount = 0;
286
287     for( vector<vector<Rect> >::const_iterator it = objects.begin();
288         it != objects.end(); ++it, imageIdx++ ) // for image
289     {
290         Size imgSize = images[imageIdx].size();
291         float dist = min(imgSize.height, imgSize.width) * eps.dist;
292         float wDiff = imgSize.width * eps.s;
293         float hDiff = imgSize.height * eps.s;
294
295         int noPair = 0;
296
297         // read validation rectangles
298         char buf[10];
299         sprintf( buf, "%s%d", "img_", imageIdx );
300         string imageIdxStr = buf;
301         FileNode node = validationFS.getFirstTopLevelNode()[VALIDATION][detectorNames[detectorIdx]][imageIdxStr];
302         vector<Rect> valRects;
303         if( node.node->data.seq != 0 )
304         {
305             for( FileNodeIterator it = node.begin(); it != node.end(); )
306             {
307                 Rect r;
308                 it >> r.x >> r.y >> r.width >> r.height;
309                 valRects.push_back(r);
310             }
311         }
312         totalValRectCount += (int)valRects.size();
313                 
314         // compare rectangles
315                 vector<uchar> map(valRects.size(), 0);
316         for( vector<Rect>::const_iterator cr = it->begin();
317             cr != it->end(); ++cr )
318         {
319             // find nearest rectangle
320             Point2f cp1 = Point2f( cr->x + (float)cr->width/2.0f, cr->y + (float)cr->height/2.0f );
321             int minIdx = -1, vi = 0;
322             float minDist = (float)norm( Point(imgSize.width, imgSize.height) );
323             for( vector<Rect>::const_iterator vr = valRects.begin();
324                 vr != valRects.end(); ++vr, vi++ )
325             {
326                 Point2f cp2 = Point2f( vr->x + (float)vr->width/2.0f, vr->y + (float)vr->height/2.0f );
327                 float curDist = (float)norm(cp1-cp2);
328                 if( curDist < minDist )
329                 {
330                     minIdx = vi;
331                     minDist = curDist;
332                 }
333             }
334             if( minIdx == -1 )
335             {
336                 noPair++;
337             }
338             else
339             {
340                 Rect vr = valRects[minIdx];
341                 if( map[minIdx] != 0 || (minDist > dist) || (abs(cr->width - vr.width) > wDiff) ||
342                                                                                                                 (abs(cr->height - vr.height) > hDiff) )
343                     noPair++;
344                                 else
345                                         map[minIdx] = 1;
346             }
347         }
348         noPair += (int)count_if( map.begin(), map.end(), isZero );
349         totalNoPair += noPair;
350         if( noPair > valRects.size()*eps.noPair+1 )
351             break;
352     }
353     if( imageIdx < (int)imageFilenames.size() )
354     {
355         char msg[500];
356         sprintf( msg, "detector %s has overrated count of rectangles without pair on %s-image",
357             detectorNames[detectorIdx].c_str(), imageFilenames[imageIdx].c_str() );
358         ts->printf( CvTS::LOG, msg );
359         return CvTS::FAIL_BAD_ACCURACY;
360     }
361     if ( totalNoPair > totalValRectCount*eps./*total*/noPair+1 )
362     {
363         ts->printf( CvTS::LOG, "overrated count of rectangles without pair on all images set" );
364         return CvTS::FAIL_BAD_ACCURACY;
365     }
366
367     return CvTS::OK;
368 }
369
370 //----------------------------------------------- CascadeDetectorTest -----------------------------------
371 class CV_CascadeDetectorTest : public CV_DetectorTest
372 {
373 public:
374     CV_CascadeDetectorTest( const char* test_name );
375 protected:
376         virtual void readDetector( const FileNode& fn );
377         virtual void writeDetector( FileStorage& fs, int di );
378     virtual int detectMultiScale( int di, const Mat& img, vector<Rect>& objects );
379         vector<int> flags;
380 };
381
382 CV_CascadeDetectorTest::CV_CascadeDetectorTest(const char *test_name)
383     : CV_DetectorTest( test_name )
384 {
385     validationFilename = "cascadeandhog/cascade.xml";
386 }
387
388 void CV_CascadeDetectorTest::readDetector( const FileNode& fn )
389 {
390         string filename;
391         int flag;
392         fn[FILENAME] >> filename;
393         detectorFilenames.push_back(filename);
394         fn[C_SCALE_CASCADE] >> flag;
395         if( flag )
396                 flags.push_back( 0 );
397         else
398                 flags.push_back( CV_HAAR_SCALE_IMAGE );
399 }
400
401 void CV_CascadeDetectorTest::writeDetector( FileStorage& fs, int di )
402 {
403         int sc = flags[di] & CV_HAAR_SCALE_IMAGE ? 0 : 1;
404         fs << FILENAME << detectorFilenames[di];
405         fs << C_SCALE_CASCADE << sc;
406 }
407
408 int CV_CascadeDetectorTest::detectMultiScale( int di, const Mat& img,
409                                               vector<Rect>& objects)
410 {
411         string dataPath = ts->get_data_path(), filename;
412         filename = dataPath + detectorFilenames[di];
413     CascadeClassifier cascade( filename );
414         if( cascade.empty() )
415         {
416                 ts->printf( CvTS::LOG, "cascade %s can not be opened");
417                 return CvTS::FAIL_INVALID_TEST_DATA;
418         }
419     Mat grayImg;
420     cvtColor( img, grayImg, CV_BGR2GRAY );
421     equalizeHist( grayImg, grayImg );
422     cascade.detectMultiScale( grayImg, objects, 1.1, 3, flags[di] );
423         return CvTS::OK;
424 }
425
426 //----------------------------------------------- HOGDetectorTest -----------------------------------
427 class CV_HOGDetectorTest : public CV_DetectorTest
428 {
429 public:
430     CV_HOGDetectorTest( const char* test_name );
431 protected:
432         virtual void readDetector( const FileNode& fn );
433         virtual void writeDetector( FileStorage& fs, int di );
434     virtual int detectMultiScale( int di, const Mat& img, vector<Rect>& objects );
435 };
436
437 CV_HOGDetectorTest::CV_HOGDetectorTest(const char *test_name)
438 : CV_DetectorTest( test_name )
439 {
440     validationFilename = "cascadeandhog/hog.xml";
441 }
442
443 void CV_HOGDetectorTest::readDetector( const FileNode& fn )
444 {
445         string filename;
446         if( fn[FILENAME].node->data.seq != 0 )
447                 fn[FILENAME] >> filename;
448         detectorFilenames.push_back( filename);
449 }
450
451 void CV_HOGDetectorTest::writeDetector( FileStorage& fs, int di )
452 {
453         fs << FILENAME << detectorFilenames[di];
454 }
455
456 int CV_HOGDetectorTest::detectMultiScale( int di, const Mat& img,
457                                               vector<Rect>& objects)
458 {
459     HOGDescriptor hog;
460     if( detectorFilenames[di].empty() )
461         hog.setSVMDetector(HOGDescriptor::getDefaultPeopleDetector());
462     else
463         assert(0);
464     hog.detectMultiScale(img, objects);
465         return CvTS::OK;
466 }
467
468 CV_CascadeDetectorTest cascadeTest("cascade-detector");
469 CV_HOGDetectorTest hogTest("hog-detector");