rtime.felk.cvut.cz Git - hercules2020/kcf.git/blobdiff - src/fft_cufft.h
Do not use virtual methods in FFT class
[hercules2020/kcf.git] / src / fft_cufft.h
index 367b95a6e11876293dcbe574a2156d5cfe537000..4241c0664f7cbe86bf4a836e305a92250110b34f 100644 (file)
@@ -1,12 +1,12 @@
 #ifndef FFT_CUDA_H
 #define FFT_CUDA_H
 
-
 #include <cufft.h>
 #include <cuda_runtime.h>
+#include <cublas_v2.h>
 
 #include "fft.h"
-#include "cuda/cuda_error_check.cuh"
+#include "cuda_error_check.hpp"
 #include "pragmas.h"
 
 struct ThreadCtx;
@@ -14,17 +14,33 @@ struct ThreadCtx;
+// FFT backend built on cuFFT (with a cuBLAS handle), implementing the Fft interface.
 class cuFFT : public Fft
 {
 public:
-    void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales) override;
-    void set_window(const cv::Mat & window) override;
-    void forward(const cv::Mat & real_input, ComplexMat & complex_result, float *real_input_arr) override;
-    void forward_window(std::vector<cv::Mat> patch_feats, ComplexMat & complex_result, cv::Mat & fw_all, float *real_input_arr) override;
-    void inverse(ComplexMat &  complex_input, cv::Mat & real_result, float *real_result_arr) override;
-    ~cuFFT() override;
+    cuFFT();
+    void init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales);
+    void set_window(const MatDynMem &window);
+    void forward(const MatScales &real_input, ComplexMat &complex_result);
+    void forward_window(MatScaleFeats &patch_feats_in, ComplexMat &complex_result, MatScaleFeats &tmp);
+    void inverse(ComplexMat &complex_input, MatScales &real_result);
+    ~cuFFT();
+
+protected:
+    // Factories for forward/inverse cuFFT plans; `howmany` is presumably the
+    // number of batched transforms (TODO confirm against the implementation).
+    cufftHandle create_plan_fwd(unsigned howmany) const;
+    cufftHandle create_plan_inv(unsigned howmany) const;
+
 private:
     cv::Mat m_window;
-    unsigned m_width, m_height, m_num_of_feats, m_num_of_scales;
-    cufftHandle plan_f, plan_f_all_scales, plan_fw, plan_fw_all_scales, plan_i_features,
-     plan_i_features_all_scales, plan_i_1ch, plan_i_1ch_all_scales;
+    cufftHandle plan_f, plan_fw, plan_i_1ch;
+    // The *_all_scales plans are only declared when built with BIG_BATCH.
+#ifdef BIG_BATCH
+    cufftHandle plan_f_all_scales, plan_fw_all_scales, plan_i_all_scales;
+#endif
+    // cuBLAS library handle. NOTE(review): this class holds raw cuFFT/cuBLAS
+    // handles and declares a destructor but no copy control; if ~cuFFT()
+    // releases them, copying an instance would release them twice --
+    // consider deleting copy construction/assignment.
+    cublasHandle_t cublas;
 };
 
 #endif // FFT_CUDA_H