rtime.felk.cvut.cz Git - hercules2020/kcf.git/commitdiff
fft: Implement assertions in the base class
author Michal Sojka <michal.sojka@cvut.cz>
Wed, 19 Sep 2018 10:53:00 +0000 (12:53 +0200)
committer Michal Sojka <michal.sojka@cvut.cz>
Thu, 20 Sep 2018 13:51:03 +0000 (15:51 +0200)
src/fft.cpp
src/fft.h
src/fft_cufft.cpp
src/fft_cufft.h
src/fft_fftw.cpp
src/fft_fftw.h
src/fft_opencv.cpp
src/fft_opencv.h

index 570e5fd1c160e06132e52234a21977e680fc19da..a7974009a4aa4b217b3dd9cdec4324f5563f8544 100644 (file)
@@ -1,7 +1,83 @@
 
 #include "fft.h"
+#include <cassert>
 
 Fft::~Fft()
 {
 
 }
+
+// Records the transform geometry so that the other Fft methods can check the
+// shapes of their arguments. Backends call this from their own init().
+void Fft::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
+{
+    m_width = width;
+    m_height = height;
+    m_num_of_feats = num_of_feats;
+#ifdef BIG_BATCH
+    m_num_of_scales = num_of_scales;
+#else
+    (void)num_of_scales;
+#endif
+}
+
+// Checks that the window is a 2-D matrix of the size passed to init().
+void Fft::set_window(const MatDynMem &window)
+{
+    assert(window.dims == 2);
+    assert(window.size().width == int(m_width));
+    assert(window.size().height == int(m_height));
+    (void)window;
+}
+
+// Checks the real input of forward(): a 2-D matrix m_width columns wide and
+// either m_height rows tall or, in BIG_BATCH mode, m_height * m_num_of_scales
+// rows tall (all scales stacked), which Fftw::forward() also accepts.
+void Fft::forward(const MatDynMem &real_input, ComplexMat &complex_result)
+{
+    assert(real_input.dims == 2);
+    assert(real_input.size().width == int(m_width));
+    assert(real_input.size().height == int(m_height) ||
+           real_input.size().height == int(m_height * IF_BIG_BATCH(m_num_of_scales, 1)));
+    (void)real_input;
+    (void)complex_result;
+}
+
+// Checks the feature planes passed to forward_window(): a 3-D matrix with one
+// m_height x m_width plane per feature (per feature and scale in BIG_BATCH
+// mode). The scratch matrix tmp must have exactly the same shape.
+void Fft::forward_window(MatDynMem &patch_feats, ComplexMat &complex_result, MatDynMem &tmp)
+{
+    assert(patch_feats.dims == 3);
+    assert(patch_feats.size[0] == int(m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1)));
+    assert(patch_feats.size[1] == int(m_height));
+    assert(patch_feats.size[2] == int(m_width));
+
+    assert(tmp.dims == patch_feats.dims);
+    assert(tmp.size[0] == patch_feats.size[0]);
+    assert(tmp.size[1] == patch_feats.size[1]);
+    assert(tmp.size[2] == patch_feats.size[2]);
+
+    (void)patch_feats;
+    (void)complex_result;
+    (void)tmp;
+}
+
+// Checks the real output of inverse(): a 3-D matrix whose first dimension is
+// the number of channels of complex_input. Both a single channel per scale
+// and all features per scale are accepted, because the backends provide an
+// inverse for each (e.g. cuFFT's plan_i_1ch and plan_i_features).
+void Fft::inverse(ComplexMat &complex_input, MatDynMem &real_result)
+{
+    const unsigned per_scale = IF_BIG_BATCH(m_num_of_scales, 1);
+    const unsigned n_channels = complex_input.n_channels;
+    assert(real_result.dims == 3);
+    assert(real_result.size[0] == int(n_channels));
+    assert(n_channels == per_scale || n_channels == m_num_of_feats * per_scale);
+    assert(real_result.size[1] == int(m_height));
+    assert(real_result.size[2] == int(m_width));
+
+    (void)per_scale;
+    (void)n_channels;
+    (void)real_result;
+}
index 93dcea6145d3f848219afd855f8b6c5c0095bf02..a8c22bc215ae4aa281c0b26804fe23ca751e3fa8 100644 (file)
--- a/src/fft.h
+++ b/src/fft.h
 
 #ifdef BIG_BATCH
 #define BIG_BATCH_MODE 1
+#define IF_BIG_BATCH(if_true, if_false) (if_true)
 #else
 #define BIG_BATCH_MODE 0
+#define IF_BIG_BATCH(if_true, if_false) (if_false)
 #endif
 
 class Fft
 {
 public:
-    virtual void init(unsigned width, unsigned height,unsigned num_of_feats, unsigned num_of_scales) = 0;
-    virtual void set_window(const MatDynMem &window) = 0;
-    virtual void forward(MatDynMem & real_input, ComplexMat & complex_result) = 0;
-    virtual void forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_result, MatDynMem &tmp) = 0;
-    virtual void inverse(ComplexMat &  complex_input, MatDynMem & real_result) = 0;
+    virtual void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales);
+    virtual void set_window(const MatDynMem &window);
+    virtual void forward(const MatDynMem &real_input, ComplexMat &complex_result);
+    virtual void forward_window(MatDynMem &patch_feats_in, ComplexMat &complex_result, MatDynMem &tmp);
+    virtual void inverse(ComplexMat &complex_input, MatDynMem &real_result);
     virtual ~Fft() = 0;
 
     static cv::Size freq_size(cv::Size space_size)
@@ -38,12 +40,11 @@ public:
     }
 
 protected:
-    bool is_patch_feats_valid(const MatDynMem &patch_feats)
-    {
-        return patch_feats.dims == 3;
-               // && patch_feats.size[1] == width
-               // && patch_feats.size[2] == height
-    }
+    // Geometry recorded by Fft::init(); the base-class methods assert against it.
+    unsigned m_width = 0, m_height = 0, m_num_of_feats = 0;
+#ifdef BIG_BATCH
+    unsigned m_num_of_scales = 0;
+#endif
 };
 
 #endif // FFT_H
index 63f2558621c58857f3bfd1532686f7ca0d0c98b3..25023b92fe5f465ae5153f489c002b311bc5e370 100644 (file)
@@ -8,141 +8,85 @@ cuFFT::cuFFT()
 
 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
 {
+    Fft::init(width, height, num_of_feats, num_of_scales);
+    // cuFFT still declares its own m_width, m_height, m_num_of_feats and m_num_of_scales
+    // (fft_cufft.h); Fft::init() only sets the base-class copies, so assign these as well.
     m_width = width;
     m_height = height;
     m_num_of_feats = num_of_feats;
     m_num_of_scales = num_of_scales;
 
     std::cout << "FFT: cuFFT" << std::endl;
 
-    // FFT forward one scale
+    // FFT forward
     {
-        CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
-    }
-#ifdef BIG_BATCH
-    // FFT forward all scales
-    if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
         int rank = 2;
-        int n[] = {(int)m_height, (int)m_width};
-        int howmany = m_num_of_scales;
+        int n[] = {int(m_height), int(m_width)};
+        int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
         int istride = 1, ostride = 1;
-        int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
+        int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
 
-        CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
-                                      CUFFT_R2C, howmany));
-        CufftErrorCheck(cufftSetStream(plan_f_all_scales, cudaStreamPerThread));
+        CufftErrorCheck(cufftPlanMany(&plan_f, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
+        CufftErrorCheck(cufftSetStream(plan_f, cudaStreamPerThread));
     }
-#endif
-    // FFT forward window one scale
+
+    // FFT forward window
     {
         int rank = 2;
         int n[] = {int(m_height), int(m_width)};
-        int howmany = int(m_num_of_feats);
-        int idist = int(m_height * m_width), odist = int(m_height * (m_width / 2 + 1));
-        int istride = 1, ostride = 1;
-        int *inembed = n, onembed[] = {int(m_height), int(m_width / 2 + 1)};
-
-        CufftErrorCheck(
-            cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
-        CufftErrorCheck(cufftSetStream(plan_fw, cudaStreamPerThread));
-    }
-#ifdef BIG_BATCH
-    // FFT forward window all scales all feats
-    if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
-        int rank = 2;
-        int n[] = {(int)m_height, (int)m_width};
-        int howmany = m_num_of_scales * m_num_of_feats;
+        int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
         int istride = 1, ostride = 1;
-        int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
+        int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
 
-        CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
-                                      CUFFT_R2C, howmany));
-        CufftErrorCheck(cufftSetStream(plan_fw_all_scales, cudaStreamPerThread));
+        CufftErrorCheck(cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
+        CufftErrorCheck(cufftSetStream(plan_fw, cudaStreamPerThread));
     }
-#endif
-    // FFT inverse one scale
+    // FFT inverse all channels
     {
         int rank = 2;
         int n[] = {int(m_height), int(m_width)};
-        int howmany = int(m_num_of_feats);
+        int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
-        int istride = 1, ostride = int(m_num_of_feats);
+        int istride = 1, ostride = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
 
-        CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist,
-                                      CUFFT_C2R, howmany));
+        CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
         CufftErrorCheck(cufftSetStream(plan_i_features, cudaStreamPerThread));
     }
-    // FFT inverse all scales
-#ifdef BIG_BATCH
-    if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
-        int rank = 2;
-        int n[] = {(int)m_height, (int)m_width};
-        int howmany = m_num_of_feats * m_num_of_scales;
-        int idist = m_height * (m_width / 2 + 1), odist = 1;
-        int istride = 1, ostride = m_num_of_feats * m_num_of_scales;
-        int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
-
-        CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n, inembed, istride, idist, onembed, ostride,
-                                      odist, CUFFT_C2R, howmany));
-        CufftErrorCheck(cufftSetStream(plan_i_features_all_scales, cudaStreamPerThread));
-    }
-#endif
-    // FFT inverse one channel one scale
+    // FFT inverse one channel
     {
         int rank = 2;
         int n[] = {int(m_height), int(m_width)};
-        int howmany = 1;
-        int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
-        int istride = 1, ostride = 1;
+        int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
+        int idist = m_height * (m_width / 2 + 1), odist = 1;
+        int istride = 1, ostride = IF_BIG_BATCH(m_num_of_scales, 1);
         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
 
-        CufftErrorCheck(
-            cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
+        CufftErrorCheck(cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
         CufftErrorCheck(cufftSetStream(plan_i_1ch, cudaStreamPerThread));
     }
-#ifdef BIG_BATCH
-    // FFT inverse one channel all scales
-    if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
-        int rank = 2;
-        int n[] = {(int)m_height, (int)m_width};
-        int howmany = m_num_of_scales;
-        int idist = m_height * (m_width / 2 + 1), odist = 1;
-        int istride = 1, ostride = m_num_of_scales;
-        int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
-
-        CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
-                                      CUFFT_C2R, howmany));
-        CufftErrorCheck(cufftSetStream(plan_i_1ch_all_scales, cudaStreamPerThread));
-    }
-#endif
 }
 
 void cuFFT::set_window(const MatDynMem &window)
 {
+    Fft::set_window(window);
     m_window = window;
 }
 
-void cuFFT::forward(MatDynMem & real_input, ComplexMat & complex_result)
+void cuFFT::forward(const MatDynMem &real_input, ComplexMat &complex_result)
 {
-    if (BIG_BATCH_MODE && real_input.rows == int(m_height * m_num_of_scales)) {
-        CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal *>(real_input.deviceMem()),
-                                     complex_result.get_p_data()));
-    } else {
-        NORMAL_OMP_CRITICAL
-        {
-            CufftErrorCheck(
-                cufftExecR2C(plan_f, reinterpret_cast<cufftReal *>(real_input.deviceMem()), complex_result.get_p_data()));
-            cudaStreamSynchronize(cudaStreamPerThread);
-        }
-    }
-    return;
+    Fft::forward(real_input, complex_result);
+    auto in = static_cast<cufftReal *>(const_cast<MatDynMem&>(real_input).deviceMem());
+
+    CufftErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_p_data()));
 }
 
-void cuFFT::forward_window(MatDynMem &feat, ComplexMat & complex_result, MatDynMem &temp)
+void cuFFT::forward_window(MatDynMem &feat, ComplexMat &complex_result, MatDynMem &temp)
 {
+    Fft::forward_window(feat, complex_result, temp);
+
     uint n_channels = feat.size[0];
     cufftReal *temp_data = temp.deviceMem();
 
@@ -154,33 +98,25 @@ void cuFFT::forward_window(MatDynMem &feat, ComplexMat & complex_result, MatDynM
         cv::Mat temp_plane(temp.dims - 1, temp.size + 1, temp.cv::Mat::type(), temp.ptr(i));
         temp_plane = feat_plane.mul(m_window);
     }
-    CufftErrorCheck(cufftExecR2C((n_channels == m_num_of_feats) ? plan_fw : plan_fw_all_scales,
-                                 temp_data, complex_result.get_p_data()));
+    CufftErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_p_data()));
 }
 
 void cuFFT::inverse(ComplexMat &complex_input, MatDynMem &real_result)
 {
+    Fft::inverse(complex_input, real_result);
+
     uint n_channels = complex_input.n_channels;
     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
     cufftReal *out = real_result.deviceMem();
     float alpha = 1.0 / (m_width * m_height);
-    cufftHandle plan;
 
-    if (n_channels == 1) {
+    // plan_i_1ch batches IF_BIG_BATCH(m_num_of_scales, 1) single-channel transforms (see init()).
+    if (n_channels == IF_BIG_BATCH(m_num_of_scales, 1)) {
         CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
         CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
-        return;
-    } else if (n_channels == m_num_of_scales) {
-        CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, out));
-        CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
-        return;
-    } else if (n_channels == m_num_of_feats * m_num_of_scales) {
-        CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, out));
-        cudaStreamSynchronize(cudaStreamPerThread);
-        return;
+    } else {
+        CufftErrorCheck(cufftExecC2R(plan_i_features, in, out));
     }
-    CufftErrorCheck(cufftExecC2R(plan_i_features, in, out));
-    return;
 }
 
 cuFFT::~cuFFT()
@@ -191,11 +127,4 @@ cuFFT::~cuFFT()
     CufftErrorCheck(cufftDestroy(plan_fw));
     CufftErrorCheck(cufftDestroy(plan_i_1ch));
     CufftErrorCheck(cufftDestroy(plan_i_features));
-
-    if (BIG_BATCH_MODE) {
-        CufftErrorCheck(cufftDestroy(plan_f_all_scales));
-        CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
-        CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
-        CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));
-    }
 }
index 4dd0f1042e1d025ddb17f9118383a0bc10da6b5d..ffbac57635d1ca310f79d766f3cbb6a91faea89b 100644 (file)
@@ -18,10 +18,13 @@ public:
     cuFFT();
     void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales) override;
     void set_window(const MatDynMem &window) override;
-    void forward(MatDynMem & real_input, ComplexMat & complex_result) override;
-    void forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_result, MatDynMem &tmp) override;
-    void inverse(ComplexMat &  complex_input, MatDynMem & real_result) override;
+    void forward(const MatDynMem &real_input, ComplexMat &complex_result) override;
+    void forward_window(MatDynMem &patch_feats_in, ComplexMat &complex_result, MatDynMem &tmp) override;
+    void inverse(ComplexMat &complex_input, MatDynMem &real_result) override;
     ~cuFFT() override;
+
 private:
     cv::Mat m_window;
+    // NOTE(review): m_width, m_height and m_num_of_feats (and m_num_of_scales in
+    // BIG_BATCH builds) hide Fft's protected members; Fft::init() does not set these.
     unsigned m_width, m_height, m_num_of_feats, m_num_of_scales;
index 3754036e43b065bf0b487f213ab235c4a720e079..8425e6be921053eaff9d22b02fc4f42c949b6a93 100644 (file)
@@ -16,10 +16,13 @@ Fftw::Fftw(){}
 
 void Fftw::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
 {
+    Fft::init(width, height, num_of_feats, num_of_scales);
+    // Fftw still declares its own m_width, m_height, m_num_of_feats and m_num_of_scales
+    // (fft_fftw.h); Fft::init() only sets the base-class copies, so assign these as well.
     m_width = width;
     m_height = height;
     m_num_of_feats = num_of_feats;
     m_num_of_scales = num_of_scales;
 
 #if (!defined(ASYNC) && !defined(CUFFTW)) && defined(OPENMP)
     fftw_init_threads();
@@ -169,11 +172,14 @@ void Fftw::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned
 
 void Fftw::set_window(const MatDynMem &window)
 {
+    Fft::set_window(window);
     m_window = window;
 }
 
-void Fftw::forward(MatDynMem & real_input, ComplexMat & complex_result)
+void Fftw::forward(const MatDynMem &real_input, ComplexMat &complex_result)
 {
+    Fft::forward(real_input, complex_result);
+
     if (BIG_BATCH_MODE && real_input.rows == int(m_height * m_num_of_scales)) {
         fftwf_execute_dft_r2c(plan_f_all_scales, reinterpret_cast<float *>(real_input.data),
                               reinterpret_cast<fftwf_complex *>(complex_result.get_p_data()));
@@ -186,7 +192,7 @@ void Fftw::forward(MatDynMem & real_input, ComplexMat & complex_result)
 
 void Fftw::forward_window(MatDynMem &feat, ComplexMat & complex_result, MatDynMem &temp)
 {
-    assert(is_patch_feats_valid(feat));
+    Fft::forward_window(feat, complex_result, temp);
 
     int n_channels = feat.size[0];
     for (int i = 0; i < n_channels; ++i) {
@@ -207,6 +213,8 @@ void Fftw::forward_window(MatDynMem &feat, ComplexMat & complex_result, MatDynMe
 
 void Fftw::inverse(ComplexMat &  complex_input, MatDynMem & real_result)
 {
+    Fft::inverse(complex_input, real_result);
+
     int n_channels = complex_input.n_channels;
     fftwf_complex *in = reinterpret_cast<fftwf_complex *>(complex_input.get_p_data());
     float *out = real_result.ptr<float>();
index cb4a901d41d2f4f8494f2ae63fb9f66daecd2edc..d3a8e5493df52caf6222577e135bc13f2fa8aa32 100644 (file)
@@ -20,9 +20,11 @@ public:
     Fftw();
     void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales) override;
     void set_window(const MatDynMem &window) override;
-    void forward(MatDynMem & real_input, ComplexMat & complex_result) override;
-    void forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_result, MatDynMem &tmp) override;
-    void inverse(ComplexMat &  complex_input, MatDynMem & real_result) override;
+    void forward(const MatDynMem &real_input, ComplexMat &complex_result) override;
+    void forward_window(MatDynMem &patch_feats_in, ComplexMat &complex_result, MatDynMem &tmp) override;
+    void inverse(ComplexMat &complex_input, MatDynMem &real_result) override;
     ~Fftw() override;
 private:
+    // NOTE(review): m_width, m_height and m_num_of_feats (and m_num_of_scales in
+    // BIG_BATCH builds) hide Fft's protected members; Fft::init() does not set these.
     unsigned m_width, m_height, m_num_of_feats, m_num_of_scales;
index bb7675734177ce109125b428253021ba92d83c05..c72eb556db1f545245c52426c92be0c7d07fe904 100644 (file)
@@ -2,10 +2,7 @@
 
 void FftOpencv::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
 {
-    (void)width;
-    (void)height;
-    (void)num_of_feats;
-    (void)num_of_scales;
+    Fft::init(width, height, num_of_feats, num_of_scales);
     std::cout << "FFT: OpenCV" << std::endl;
 }
 
@@ -16,7 +13,8 @@ void FftOpencv::set_window(const MatDynMem &window)
 
-void FftOpencv::forward(const cv::Mat &real_input, ComplexMat &complex_result, float *real_input_arr)
+// Signature must match Fft::forward(const MatDynMem &, ComplexMat &) to override it.
+void FftOpencv::forward(const MatDynMem &real_input, ComplexMat &complex_result)
 {
-    (void)real_input_arr;
+    Fft::forward(real_input, complex_result);
 
     cv::Mat tmp;
     cv::dft(real_input, tmp, cv::DFT_COMPLEX_OUTPUT);
@@ -26,6 +24,5 @@ void FftOpencv::forward(const cv::Mat &real_input, ComplexMat &complex_result, f
 
 void FftOpencv::forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_result, MatDynMem &tmp)
 {
-    (void)real_input_arr;
-    (void)fw_all;
+    Fft::forward_window(patch_feats_in, complex_result, tmp);
 
@@ -40,6 +37,6 @@ void FftOpencv::forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_r
 
 void FftOpencv::inverse(ComplexMat &  complex_input, MatDynMem & real_result)
 {
-    (void)real_result_arr;
+    Fft::inverse(complex_input, real_result);
 
     if (complex_input.n_channels == 1) {
index 5e016db3b39ecf66404ab611288a96b7f75a8fbb..f57906997a1a06a78b7356907b8b21664b03be68 100644 (file)
@@ -9,9 +9,11 @@ class FftOpencv : public Fft
 public:
+    // init(), forward(), forward_window() and inverse() first call the matching Fft::
+    // method, which records (init) or asserts (the others) the expected matrix shapes.
     void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales) override;
     void set_window(const MatDynMem &window) override;
-    void forward(MatDynMem & real_input, ComplexMat & complex_result) override;
-    void forward_window(MatDynMem &patch_feats_in, ComplexMat & complex_result, MatDynMem &tmp) override;
-    void inverse(ComplexMat &  complex_input, MatDynMem & real_result) override;
+    void forward(const MatDynMem &real_input, ComplexMat &complex_result) override;
+    void forward_window(MatDynMem &patch_feats_in, ComplexMat &complex_result, MatDynMem &tmp) override;
+    void inverse(ComplexMat &complex_input, MatDynMem &real_result) override;
     ~FftOpencv() override;
 private:
     cv::Mat m_window;