]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/tests/ml/src/aemknearestkmeans.cpp
some more fixed warnings
[opencv.git] / opencv / tests / ml / src / aemknearestkmeans.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "mltest.h"
43
44 using namespace std;
45 using namespace cv;
46
47 void defaultDistribs( vector<Mat>& means, vector<Mat>& covs )
48 {
49     float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
50     float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
51     float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
52     Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
53     Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
54     Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
55     means.resize(3), covs.resize(3);
56     m0.copyTo(means[0]), c0.copyTo(covs[0]);
57     m1.copyTo(means[1]), c1.copyTo(covs[1]);
58     m2.copyTo(means[2]), c2.copyTo(covs[2]);
59 }
60
61 // generate points sets by normal distributions
62 void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vector<Mat>& means, const vector<Mat>& covs, int labelType )
63 {
64     vector<int>::const_iterator sit = sizes.begin();
65     int total = 0;
66     for( ; sit != sizes.end(); ++sit )
67         total += *sit;
68     assert( means.size() == sizes.size() && covs.size() == sizes.size() );
69     assert( !data.empty() && data.rows == total );
70     assert( data.type() == CV_32FC1 );
71     
72     labels.create( data.rows, 1, labelType );
73
74     randn( data, Scalar::all(0.0), Scalar::all(1.0) );
75     vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
76     int bi, ei = 0;
77     sit = sizes.begin();
78     for( int p = 0, l = 0; sit != sizes.end(); ++sit, ++mit, ++cit, l++ )
79     {
80         bi = ei;
81         ei = bi + *sit;
82         assert( mit->rows == 1 && mit->cols == data.cols );
83         assert( cit->rows == data.cols && cit->cols == data.cols );
84         for( int i = bi; i < ei; i++, p++ )
85         {
86             Mat r(1, data.cols, CV_32FC1, data.ptr<float>(i));
87             r =  r * (*cit) + *mit; 
88             if( labelType == CV_32FC1 )
89                 labels.at<float>(p, 0) = (float)l;
90             else
91                 labels.at<int>(p, 0) = l;
92         }
93     }
94 }
95
96 int maxIdx( const vector<int>& count )
97 {
98     int idx = -1;
99     int maxVal = -1;
100     vector<int>::const_iterator it = count.begin();
101     for( int i = 0; it != count.end(); ++it, i++ )
102     {
103         if( *it > maxVal)
104         {
105             maxVal = *it;
106             idx = i;
107         }
108     }
109     assert( idx >= 0);
110     return idx;
111 }
112
113 bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
114 {
115     int total = 0, setCount = (int)sizes.size();
116     vector<int>::const_iterator sit = sizes.begin();
117     for( ; sit != sizes.end(); ++sit )
118         total += *sit;
119     assert( !labels.empty() );
120     assert( labels.rows == total && labels.cols == 1 );
121     assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
122
123     bool isFlt = labels.type() == CV_32FC1;
124     labelsMap.resize(setCount);
125     vector<int>::iterator lmit = labelsMap.begin();
126     vector<bool> buzy(setCount, false);
127     int bi, ei = 0;
128     for( sit = sizes.begin(); sit != sizes.end(); ++sit, ++lmit )
129     {
130         vector<int> count( setCount, 0 );
131         bi = ei;
132         ei = bi + *sit;
133         if( isFlt )
134         {
135             for( int i = bi; i < ei; i++ )
136                 count[(int)labels.at<float>(i, 0)]++;
137         }
138         else
139         {
140             for( int i = bi; i < ei; i++ )
141                 count[labels.at<int>(i, 0)]++;
142         }
143   
144         *lmit = maxIdx( count );
145         if( buzy[*lmit] )
146             return false;
147         buzy[*lmit] = true;
148     }
149     return true;    
150 }
151
152 float calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, bool labelsEquivalent = true )
153 {
154     int err = 0;
155     assert( !labels.empty() && !origLabels.empty() );
156     assert( labels.cols == 1 && origLabels.cols == 1 );
157     assert( labels.rows == origLabels.rows );
158     assert( labels.type() == origLabels.type() );
159     assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
160
161     vector<int> labelsMap;
162     bool isFlt = labels.type() == CV_32FC1;
163     if( !labelsEquivalent )
164     {
165         getLabelsMap( labels, sizes, labelsMap );
166         for( int i = 0; i < labels.rows; i++ )
167             if( isFlt )
168                 err += labels.at<float>(i, 0) != labelsMap[(int)origLabels.at<float>(i, 0)];
169             else
170                 err += labels.at<int>(i, 0) != labelsMap[origLabels.at<int>(i, 0)];
171     }
172     else
173     {
174         for( int i = 0; i < labels.rows; i++ )
175             if( isFlt )
176                 err += labels.at<float>(i, 0) != origLabels.at<float>(i, 0);
177             else
178                 err += labels.at<int>(i, 0) != origLabels.at<int>(i, 0);
179     }
180     return (float)err / (float)labels.rows;
181 }
182
183 //--------------------------------------------------------------------------------------------
184 class CV_KMeansTest : public CvTest {
185 public:
186     CV_KMeansTest() : CvTest( "kmeans", "kmeans" ) {}
187 protected:
188     virtual void run( int start_from );
189 };
190
191 void CV_KMeansTest::run( int /*start_from*/ )
192 {
193     const int iters = 100;
194     int sizesArr[] = { 5000, 7000, 8000 };
195     int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
196     
197     Mat data( pointsCount, 2, CV_32FC1 ), labels;
198     vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
199     vector<Mat> means, covs;
200     defaultDistribs( means, covs );
201     generateData( data, labels, sizes, means, covs, CV_32SC1 );
202     
203     int code = CvTS::OK;
204     Mat bestLabels;
205     // 1. flag==KMEANS_PP_CENTERS
206     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, 0 );
207     if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
208     {
209         ts->printf( CvTS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
210         code = CvTS::FAIL_BAD_ACCURACY;
211     }
212
213     // 2. flag==KMEANS_RANDOM_CENTERS
214     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, 0 );
215     if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
216     {
217         ts->printf( CvTS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
218         code = CvTS::FAIL_BAD_ACCURACY;
219     }
220
221     // 3. flag==KMEANS_USE_INITIAL_LABELS
222     labels.copyTo( bestLabels );
223     RNG rng;
224     for( int i = 0; i < 0.5f * pointsCount; i++ )
225         bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
226     kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, 0 );
227     if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
228     {
229         ts->printf( CvTS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
230         code = CvTS::FAIL_BAD_ACCURACY;
231     }
232
233     ts->set_failed_test_info( code );
234 }
235
236 //--------------------------------------------------------------------------------------------
237 class CV_KNearestTest : public CvTest {
238 public:
239     CV_KNearestTest() : CvTest( "knearest", "CvKNearest funcs" ) {}
240 protected:
241     virtual void run( int start_from );
242 };
243
244 void CV_KNearestTest::run( int /*start_from*/ )
245 {
246     int sizesArr[] = { 500, 700, 800 };
247     int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
248
249     // train data
250     Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
251     vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
252     vector<Mat> means, covs;
253     defaultDistribs( means, covs );
254     generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1 );
255
256     // test data
257     Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;
258     generateData( testData, testLabels, sizes, means, covs, CV_32FC1 );
259
260     int code = CvTS::OK;
261     KNearest knearest;
262     knearest.train( trainData, trainLabels );
263     knearest.find_nearest( testData, 4, &bestLabels );
264     if( calcErr( bestLabels, testLabels, sizes, true ) > 0.01f )
265     {
266         ts->printf( CvTS::LOG, "bad accuracy on test data" );
267         code = CvTS::FAIL_BAD_ACCURACY;
268     }
269     ts->set_failed_test_info( code );
270 }
271
272 //--------------------------------------------------------------------------------------------
273 class CV_EMTest : public CvTest {
274 public:
275     CV_EMTest() : CvTest( "em", "CvEM funcs" ) {}
276 protected:
277     virtual void run( int start_from );
278 };
279
280 void CV_EMTest::run( int /*start_from*/ )
281 {
282     int sizesArr[] = { 5000, 7000, 8000 };
283     int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
284
285     // train data
286     Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
287     vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
288     vector<Mat> means, covs;
289     defaultDistribs( means, covs );
290     generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
291
292     // test data
293     Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;
294     generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
295
296     int code = CvTS::OK;
297     ExpectationMaximization em;
298     CvEMParams params;
299     params.nclusters = 3;
300     em.train( trainData, Mat(), params, &bestLabels );
301
302     // check train error
303     if( calcErr( bestLabels, trainLabels, sizes, true ) > 0.002f )
304     {
305         ts->printf( CvTS::LOG, "bad accuracy on train data" );
306         code = CvTS::FAIL_BAD_ACCURACY;
307     }
308
309     // check test error
310     bestLabels.create( testData.rows, 1, CV_32SC1 );
311     for( int i = 0; i < testData.rows; i++ )
312     {
313         Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
314         bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
315     }
316     if( calcErr( bestLabels, testLabels, sizes, true ) > 0.005f )
317     {
318         ts->printf( CvTS::LOG, "bad accuracy on test data" );
319         code = CvTS::FAIL_BAD_ACCURACY;
320     }
321     
322     ts->set_failed_test_info( code );
323 }
324
325 CV_KMeansTest kmeans_test;
326 CV_KNearestTest knearest_test;
327 CV_EMTest em_test;