]> rtime.felk.cvut.cz Git - opencv.git/blob - opencv/tests/cv/src/anearestneighbors.cpp
added tests for flann( kmeans index )
[opencv.git] / opencv / tests / cv / src / anearestneighbors.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14 // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "cvtest.h"
44
45 #include <algorithm>
46 #include <vector>
47 #include <iostream>
48
49 using namespace cv;
50 using namespace cv::flann;
51
52 //--------------------------------------------------------------------------------
53 class NearestNeighborTest : public CvTest
54 {
55 public:
56     NearestNeighborTest( const char* test_name, const char* test_funcs ) 
57         : CvTest( test_name, test_funcs ) {}
58 protected:
59     virtual void run( int start_from );
60     virtual void createModel( const Mat& data ) = 0;
61     virtual void searchNeighbors( Mat& points, Mat& neighbors ) = 0;
62     virtual void releaseModel() = 0;
63 };
64
65 void NearestNeighborTest::run( int /*start_from*/ ) {
66     int dims = 64;
67     int featuresCount = 2000;
68     int K = 1; // * should also test 2nd nn etc.?
69     float noise = 0.2f;
70     int pointsCount = 1000;
71
72     RNG rng;
73     Mat desc( featuresCount, dims, CV_32FC1 );
74     rng.fill( desc, RNG::UNIFORM, Scalar(0.0f), Scalar(1.0f) );
75
76     createModel( desc );
77     
78     Mat points( pointsCount, dims, CV_32FC1 );
79     Mat results( pointsCount, K, CV_32SC1 );
80
81     std::vector<int> fmap( pointsCount );
82     for( int pi = 0; pi < pointsCount; pi++ )
83     {
84         int fi = rng.next() % featuresCount;
85         fmap[pi] = fi;
86         for( int d = 0; d < dims; d++ )
87             points.at<float>(pi, d) = desc.at<float>(fi, d) + rng.uniform(0.0f, 1.0f) * noise;
88     }
89     searchNeighbors( points, results );
90
91     releaseModel();
92
93     int correctMatches = 0;
94     for( int pi = 0; pi < pointsCount; pi++ )
95     {
96         if( fmap[pi] == results.at<int>(pi, 0) )
97             correctMatches++;
98     }
99
100     double correctPerc = correctMatches / (double)pointsCount;
101     ts->printf( CvTS::LOG, "correct_perc = %d\n", correctPerc );
102     if (correctPerc < .8)
103         ts->set_failed_test_info(CvTS::FAIL_INVALID_OUTPUT);
104 }
105
106 //--------------------------------------------------------------------------------
107 class CV_LSHTest : public NearestNeighborTest
108 {
109 public:
110     CV_LSHTest() : NearestNeighborTest( "lsh", "cvLSHQuery" ) {}
111 protected:
112     virtual void createModel( const Mat& data );
113     virtual void searchNeighbors( Mat& points, Mat& neighbors );
114     virtual void releaseModel();
115     struct CvLSH* lsh;
116     CvMat desc;
117 };
118
119 void CV_LSHTest::createModel( const Mat& data )
120 {
121     desc = data;
122     lsh = cvCreateMemoryLSH( data.cols, data.rows, 70, 20, CV_32FC1 );
123     cvLSHAdd( lsh, &desc );
124 }
125
126 void CV_LSHTest::searchNeighbors( Mat& points, Mat& neighbors )
127 {
128     const int emax = 20;
129     Mat dist( points.rows, neighbors.cols, CV_64FC1);
130     CvMat _dist = dist, _points = points, _neighbors = neighbors;
131     cvLSHQuery( lsh, &_points, &_neighbors, &_dist, neighbors.cols, emax );
132 }
133
134 void CV_LSHTest::releaseModel()
135 {
136     cvReleaseLSH( &lsh );
137 }
138
139 //--------------------------------------------------------------------------------
140 class CV_FeatureTreeTest_C : public NearestNeighborTest
141 {
142 public:
143     CV_FeatureTreeTest_C( const char* test_name, const char* test_funcs ) 
144         : NearestNeighborTest( test_name, test_funcs ) {}
145 protected:
146     virtual void searchNeighbors( Mat& points, Mat& neighbors );
147     virtual void releaseModel();
148     CvFeatureTree* tr;
149     CvMat desc;
150 };
151
152 void CV_FeatureTreeTest_C::searchNeighbors( Mat& points, Mat& neighbors )
153 {
154     const int emax = 20;
155     Mat dist( points.rows, neighbors.cols, CV_64FC1);
156     CvMat _dist = dist, _points = points, _neighbors = neighbors;
157     cvFindFeatures( tr, &_points, &_neighbors, &_dist, neighbors.cols, emax );
158 }
159
160 void CV_FeatureTreeTest_C::releaseModel()
161 {
162     cvReleaseFeatureTree( tr );
163 }
164
165 //--------------------------------------
166 class CV_SpillTreeTest_C : public CV_FeatureTreeTest_C
167 {
168 public:
169     CV_SpillTreeTest_C(): CV_FeatureTreeTest_C( "spilltree_c", "cvFindFeatures-spill" ) {}
170 protected:
171     virtual void createModel( const Mat& data );
172 };
173
174 void CV_SpillTreeTest_C::createModel( const Mat& data )
175 {
176     desc = data;
177     tr = cvCreateSpillTree( &desc );
178 }
179
180 //--------------------------------------
181 class CV_KDTreeTest_C : public CV_FeatureTreeTest_C
182 {
183 public:
184     CV_KDTreeTest_C(): CV_FeatureTreeTest_C( "kdtree_c", "cvFindFeatures-kd" ) {}
185 protected:
186     virtual void createModel( const Mat& data );
187 };
188
189 void CV_KDTreeTest_C::createModel( const Mat& data )
190 {
191     desc = data;
192     tr = cvCreateKDTree( &desc );
193 }
194
195 //--------------------------------------------------------------------------------
196 class CV_KDTreeTest_CPP : public NearestNeighborTest
197 {
198 public:
199     CV_KDTreeTest_CPP() : NearestNeighborTest( "kdtree_cpp", "cv::KDTree funcs" ) {}
200 protected:
201     virtual void createModel( const Mat& data );
202     virtual void searchNeighbors( Mat& points, Mat& neighbors );
203     virtual void releaseModel();
204     KDTree* tr;
205 };
206
207 void CV_KDTreeTest_CPP::createModel( const Mat& data )
208 {
209     tr = new KDTree( data );
210 }
211
212 void CV_KDTreeTest_CPP::searchNeighbors( Mat& points, Mat& neighbors )
213 {
214     const int emax = 20;
215     for( int pi = 0; pi < points.rows; pi++ )
216         tr->findNearest( points.ptr<float>(pi), neighbors.cols, emax, neighbors.ptr<int>(pi) );
217 }
218
219 void CV_KDTreeTest_CPP::releaseModel()
220 {
221     delete tr;
222 }
223
224 //--------------------------------------------------------------------------------
225 class CV_FlannTest : public NearestNeighborTest
226 {
227 public:
228     CV_FlannTest( const char* test_name, const char* test_funcs ) 
229         : NearestNeighborTest( test_name, test_funcs ) {}
230 protected:
231     void createIndex( const Mat& data, const IndexParams& params );
232     void knnSearch( Mat& points, Mat& neighbors );
233     void radiusSearch( Mat& points, Mat& neighbors );
234     virtual void releaseModel();
235     Index* index;
236 };
237
238 void CV_FlannTest::createIndex( const Mat& data, const IndexParams& params )
239 {
240     index = new Index( data, params );
241 }
242
243 void CV_FlannTest::knnSearch( Mat& points, Mat& neighbors )
244 {
245     Mat dist( points.rows, neighbors.cols, CV_32FC1);
246     index->knnSearch( points, neighbors, dist, 1, SearchParams() );
247 }
248
249 void CV_FlannTest::radiusSearch( Mat& points, Mat& neighbors )
250 {
251     Mat dist( 1, neighbors.cols, CV_32FC1);
252     // radiusSearch can only search one feature at a time for range search
253     for( int i = 0; i < points.rows; i++ )
254     {
255         Mat p( 1, points.cols, CV_32FC1, points.ptr<float>(i) ),
256             n( 1, neighbors.cols, CV_32SC1, neighbors.ptr<int>(i) );
257         index->radiusSearch( p, n, dist, 10.0f, SearchParams() );
258     }
259 }
260
261 void CV_FlannTest::releaseModel()
262 {
263     delete index;
264 }
265
266 //---------------------------------------
267 class CV_FlannLinearIndexTest : public CV_FlannTest
268 {
269 public:
270     CV_FlannLinearIndexTest() : CV_FlannTest( "flann_linear", "LinearIndex" ) {}
271 protected:
272     virtual void createModel( const Mat& data ) { createIndex( data, LinearIndexParams() ); }
273     virtual void searchNeighbors( Mat& points, Mat& neighbors ) { knnSearch( points, neighbors ); }
274 };
275
276 //---------------------------------------
277 class CV_FlannKMeansIndexTest : public CV_FlannTest
278 {
279 public:
280     CV_FlannKMeansIndexTest() : CV_FlannTest( "flann_kmeans", "KMeansIndex" ) {}
281 protected:
282     virtual void createModel( const Mat& data ) { createIndex( data, KMeansIndexParams() ); }
283     virtual void searchNeighbors( Mat& points, Mat& neighbors ) { radiusSearch( points, neighbors ); }
284 };
285
286 //---------------------------------------
287 class CV_FlannKDTreeIndexTest : public CV_FlannTest
288 {
289 public:
290     CV_FlannKDTreeIndexTest() : CV_FlannTest( "flann_kdtree", "KDTreeIndex" ) {}
291 protected:
292     virtual void createModel( const Mat& data ) { createIndex( data, KDTreeIndexParams() ); }
293     virtual void searchNeighbors( Mat& points, Mat& neighbors ) { radiusSearch( points, neighbors ); }
294 };
295
296 //----------------------------------------
297 class CV_FlannAutotunedIndexTest : public CV_FlannTest
298 {
299 public:
300     CV_FlannAutotunedIndexTest() : CV_FlannTest( "flann_autotuned", "AutotunedIndex" ) {}
301 protected:
302     virtual void createModel( const Mat& data ) { createIndex( data, AutotunedIndexParams() ); }
303     virtual void searchNeighbors( Mat& points, Mat& neighbors ) { knnSearch( points, neighbors ); }
304 };
305
306 CV_LSHTest lsh_test;
307 CV_SpillTreeTest_C spilltree_test_c;
308 CV_KDTreeTest_C kdtree_test_c;
309 CV_KDTreeTest_CPP kdtree_test_cpp;
310 CV_FlannLinearIndexTest flann_linear_index;
311 CV_FlannKMeansIndexTest flann_kmeans_index;
312 CV_FlannKDTreeIndexTest flann_kdtree_index;
313 CV_FlannAutotunedIndexTest flann_autotuned_index;