]> rtime.felk.cvut.cz Git - opencv.git/commitdiff
ticket 56 (SVM with set of selected features)
author: mdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 16 Feb 2010 10:54:40 +0000 (10:54 +0000)
committer: mdim <mdim@73c94f0f-984f-4a5f-82bc-2d8db8d8ee08>
Tue, 16 Feb 2010 10:54:40 +0000 (10:54 +0000)
git-svn-id: https://code.ros.org/svn/opencv/trunk@2691 73c94f0f-984f-4a5f-82bc-2d8db8d8ee08

opencv/include/opencv/ml.h
opencv/src/ml/ml_inner_functions.cpp
opencv/src/ml/mlsvm.cpp

index 1ddccaab34b72d8a9426a2f3585618ec6695e8c2..1e1249dd2ece869d00d28bb4dc68712016e4fe5b 100644 (file)
@@ -585,6 +585,8 @@ protected:
     virtual void create_kernel();
     virtual void create_solver();
 
+    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
+
     virtual void write_params( CvFileStorage* fs ) const;
     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
 
index 71a1c2ce84d29788b4514258b67f2861a8354f44..6bf898e220be6efb3114812c2cab1eb7f8129447 100644 (file)
@@ -1147,7 +1147,7 @@ cvPreparePredictData( const CvArr* _sample, int dims_all,
     if( CV_IS_MAT(sample) )
     {
         sample_data = sample->data.fl;
-        sample_step = sample->step / sizeof(row_sample[0]);
+        sample_step = CV_IS_MAT_CONT(sample->type) ? 1 : sample->step/sizeof(row_sample[0]);
 
         if( !comp_idx && CV_IS_MAT_CONT(sample->type) && !as_sparse )
             *_row_sample = sample_data;
@@ -1161,12 +1161,8 @@ cvPreparePredictData( const CvArr* _sample, int dims_all,
             else
             {
                 int* comp = comp_idx->data.i;
-                if( !sample_step )
-                    for( i = 0; i < dims_selected; i++ )
-                        row_sample[i] = sample_data[comp[i]];
-                else
-                    for( i = 0; i < dims_selected; i++ )
-                        row_sample[i] = sample_data[sample_step*comp[i]];
+                for( i = 0; i < dims_selected; i++ )
+                    row_sample[i] = sample_data[sample_step*comp[i]];
             }
 
             *_row_sample = row_sample;
index 40dc9ccb7a30e6a50aaaba983faffed297346b78..3530c03266182dd6e9949d3923670d57f85f8502 100644 (file)
@@ -1810,12 +1810,9 @@ bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
                         EXIT;
 
                     // Compute test set error on <test_size> samples
-                    CvMat s = cvMat( 1, var_count, CV_32FC1 );
                     for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
                     {
-                        float resp;
-                        s.data.fl = *test_samples_ptr;
-                        resp = predict( &s );
+                        float resp = predict( *test_samples_ptr, var_count );
                         error += is_regression ? powf( resp - *(float*)true_resp, 2 )
                             : ((int)resp != cls_lbls[*(int*)true_resp]);
                     }
@@ -1866,7 +1863,8 @@ bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
     delete solver;
     solver = 0;
     cvReleaseMemStorage( &temp_storage );
-    cvReleaseMat( &responses );
+    if( responses != _responses )
+        cvReleaseMat( &responses );
     cvReleaseMat( &responses_local );
     cvFree( &samples );
     cvFree( &samples_local );
@@ -1877,39 +1875,28 @@ bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
     return ok;
 }
 
-float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
+float CvSVM::predict( const float* row_sample, int row_len, bool returnDFVal ) const
 {
-    bool local_alloc = 0;
-    float result = 0;
-    float* row_sample = 0;
-    Qfloat* buffer = 0;
-
-    CV_FUNCNAME( "CvSVM::predict" );
-
-    __BEGIN__;
+    assert( kernel );
+    assert( row_sample );
 
-    int class_count;
-    int var_count, buf_sz;
-
-    if( !kernel )
-        CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
+    int var_count = get_var_count();
+    assert( row_len == var_count );
 
-    class_count = class_labels ? class_labels->cols :
+    int class_count = class_labels ? class_labels->cols :
                   params.svm_type == ONE_CLASS ? 1 : 0;
 
-    CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
-                                   class_count, 0, &row_sample ));
-
-    var_count = get_var_count();
-
-    buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
+    float result = 0;
+    bool local_alloc = 0;
+    float* buffer = 0;
+    int buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
     if( buf_sz <= CV_MAX_LOCAL_SIZE )
     {
-        CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
-        local_alloc = 1;
+        buffer = (float*)cvStackAlloc( buf_sz );
+        local_alloc = true;
     }
     else
-        CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
+        buffer = (float*)cvAlloc( buf_sz );
 
     if( params.svm_type == EPS_SVR ||
         params.svm_type == NU_SVR ||
@@ -1957,17 +1944,40 @@ float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
         result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
     }
     else
-        CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
+        CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                                 "the SVM structure is probably corrupted" );
+    if( !local_alloc )
+        cvFree( &buffer );
+
+    return result;
+}
+
+float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
+{
+    float result = 0;
+    float* row_sample = 0;
+
+    CV_FUNCNAME( "CvSVM::predict" );
+
+    __BEGIN__;
+
+    int class_count;
+    
+    if( !kernel )
+        CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
+
+    class_count = class_labels ? class_labels->cols :
+                  params.svm_type == ONE_CLASS ? 1 : 0;
+
+    CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
+                                   class_count, 0, &row_sample ));
+    result = predict( row_sample, get_var_count(), returnDFVal );
 
     __END__;
 
     if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
         cvFree( &row_sample );
 
-    if( !local_alloc )
-        cvFree( &buffer );
-
     return result;
 }