rtime.felk.cvut.cz Git - hercules2020/kcf.git/blobdiff - src/fft_cufft.cpp
Add cudaStreamSynchronize after FFT
[hercules2020/kcf.git] / src / fft_cufft.cpp
index 329a26edee41c0cbf73a7fb42218dba7b7361a94..e551eaa41726690a1bce8d42b1f415931e691ff6 100644 (file)
@@ -1,66 +1,55 @@
 #include "fft_cufft.h"
-#include <cublas_v2.h>
 
 cuFFT::cuFFT()
 {
-    CublasErrorCheck(cublasCreate(&cublas));
+    CudaSafeCall(cudaSetDeviceFlags(cudaDeviceMapHost)); // enable mapped (zero-copy) pinned host allocations; NOTE(review): older CUDA versions reject this once the device is initialized — confirm call order
+    cudaErrorCheck(cublasCreate(&cublas)); // cuBLAS handle used by inverse() for the 1/(width*height) scaling; destroyed in ~cuFFT()
 }
 
+// Build a batched 2-D real-to-complex plan of m_height x m_width transforms, bound to cudaStreamPerThread.
+cufftHandle cuFFT::create_plan_fwd(uint howmany) const
+{
+    const int h = int(m_height), w = int(m_width);
+    int dims[] = {h, w};                   // logical transform size (= input layout)
+    int out_dims[] = {h, w / 2 + 1};       // R2C output keeps only the non-redundant half of each row
+    const int in_dist = h * w, out_dist = h * (w / 2 + 1);
+
+    cufftHandle handle;
+    cudaErrorCheck(cufftPlanMany(&handle, 2, dims, dims, 1, in_dist, out_dims, 1, out_dist, CUFFT_R2C, howmany));
+    cudaErrorCheck(cufftSetStream(handle, cudaStreamPerThread));
+    return handle;
+}
+
+// Build a batched 2-D complex-to-real plan producing m_height x m_width outputs, bound to cudaStreamPerThread.
+cufftHandle cuFFT::create_plan_inv(uint howmany) const
+{
+    const int h = int(m_height), w = int(m_width);
+    int dims[] = {h, w};                   // logical transform size (= output layout)
+    int in_dims[] = {h, w / 2 + 1};        // C2R input holds only the non-redundant half of each row
+    const int in_dist = h * (w / 2 + 1), out_dist = h * w;
+
+    cufftHandle handle;
+    cudaErrorCheck(cufftPlanMany(&handle, 2, dims, in_dims, 1, in_dist, dims, 1, out_dist, CUFFT_C2R, howmany));
+    cudaErrorCheck(cufftSetStream(handle, cudaStreamPerThread));
+    return handle;
+}
+
+
 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
 {
     Fft::init(width, height, num_of_feats, num_of_scales);
 
     std::cout << "FFT: cuFFT" << std::endl;
 
-    // FFT forward
-    {
-        int rank = 2;
-        int n[] = {int(m_height), int(m_width)};
-        int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
-        int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
-        int istride = 1, ostride = 1;
-        int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
-
-        CufftErrorCheck(cufftPlanMany(&plan_f, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
-        CufftErrorCheck(cufftSetStream(plan_f, cudaStreamPerThread));
-    }
+    plan_f = create_plan_fwd(1); // forward R2C over a single plane (used by forward())
+    plan_fw = create_plan_fwd(m_num_of_feats); // forward R2C over all feature planes of one scale (used by forward_window())
+    plan_i_1ch = create_plan_inv(1); // inverse C2R over a single plane (used by inverse())
 
-    // FFT forward window
-    {
-        int rank = 2;
-        int n[] = {int(m_height), int(m_width)};
-        int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
-        int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
-        int istride = 1, ostride = 1;
-        int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
-
-        CufftErrorCheck(cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
-        CufftErrorCheck(cufftSetStream(plan_fw, cudaStreamPerThread));
-    }
-    // FFT inverse all channels
-    {
-        int rank = 2;
-        int n[] = {int(m_height), int(m_width)};
-        int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
-        int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
-        int istride = 1, ostride = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
-        int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
-
-        CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
-        CufftErrorCheck(cufftSetStream(plan_i_features, cudaStreamPerThread));
-    }
-    // FFT inverse one channel
-    {
-        int rank = 2;
-        int n[] = {int(m_height), int(m_width)};
-        int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
-        int idist = m_height * (m_width / 2 + 1), odist = 1;
-        int istride = 1, ostride = IF_BIG_BATCH(m_num_of_scales, 1);
-        int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
-
-        CufftErrorCheck(cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
-        CufftErrorCheck(cufftSetStream(plan_i_1ch, cudaStreamPerThread));
-    }
+#ifdef BIG_BATCH // extra plans that transform every scale in one batched cuFFT call
+    plan_f_all_scales = create_plan_fwd(m_num_of_scales);
+    plan_fw_all_scales = create_plan_fwd(m_num_of_scales * m_num_of_feats);
+    plan_i_all_scales = create_plan_inv(m_num_of_scales);
+#endif
 }
 
 void cuFFT::set_window(const MatDynMem &window)
@@ -69,30 +58,44 @@ void cuFFT::set_window(const MatDynMem &window)
     m_window = window;
 }
 
-void cuFFT::forward(const MatDynMem &real_input, ComplexMat &complex_result)
+void cuFFT::forward(const MatScales &real_input, ComplexMat &complex_result) // NOTE(review): unlike forward_window()/inverse() there is no cudaStreamSynchronize here — confirm callers sync before reading complex_result on the host
 {
     Fft::forward(real_input, complex_result);
-    auto in = static_cast<cufftReal *>(const_cast<MatDynMem&>(real_input).deviceMem());
-
-    CufftErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_p_data()));
+    auto in = static_cast<cufftReal *>(const_cast<MatScales&>(real_input).deviceMem()); // const_cast: cufftExecR2C takes a non-const input pointer
+    // size[0] (presumably the scale count) selects the single-plane plan or, in BIG_BATCH builds, the all-scales plan.
+    if (real_input.size[0] == 1)
+        cudaErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_dev_data()));
+#ifdef BIG_BATCH
+    else
+        cudaErrorCheck(cufftExecR2C(plan_f_all_scales, in, complex_result.get_dev_data()));
+#endif // NOTE(review): without BIG_BATCH, size[0] != 1 executes no transform here
 }
 
-void cuFFT::forward_window(MatDynMem &feat, ComplexMat &complex_result, MatDynMem &temp)
+void cuFFT::forward_window(MatScaleFeats &feat, ComplexMat &complex_result, MatScaleFeats &temp) // window every (scale, feature) plane into temp, then R2C it into complex_result and wait
 {
     Fft::forward_window(feat, complex_result, temp);
 
-    uint n_channels = feat.size[0];
     cufftReal *temp_data = temp.deviceMem();
-
-    for (uint i = 0; i < n_channels; ++i) {
-        cv::Mat feat_plane = feat.plane(i);
-        cv::Mat temp_plane = temp.plane(i);
-        temp_plane = feat_plane.mul(m_window);
+    const uint scales = feat.size[0];
+    const uint feats_per_scale = uint(feat.size[1]);
+    for (uint si = 0; si < scales; ++si) {
+        for (uint fi = 0; fi < feats_per_scale; ++fi) {
+            cv::Mat src = feat.plane(si, fi);
+            cv::Mat dst = temp.plane(si, fi); // presumably a view into temp's storage, so the write below lands in temp_data
+            dst = src.mul(m_window);          // element-wise windowing on the host (cv::Mat)
+        }
     }
-    CufftErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_p_data()));
+    // One scale: per-scale plan; several scales: one batched plan (BIG_BATCH builds only).
+    if (scales == 1)
+        cudaErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_dev_data()));
+#ifdef BIG_BATCH
+    else
+        cudaErrorCheck(cufftExecR2C(plan_fw_all_scales, temp_data, complex_result.get_dev_data()));
+#endif
+    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
 }
 
-void cuFFT::inverse(ComplexMat &complex_input, MatDynMem &real_result)
+void cuFFT::inverse(ComplexMat &complex_input, MatScales &real_result) // C2R of complex_input into real_result, scaled by 1/(width*height), then waits on cudaStreamPerThread
 {
     Fft::inverse(complex_input, real_result);
 
@@ -101,21 +104,29 @@ void cuFFT::inverse(ComplexMat &complex_input, MatDynMem &real_result)
     cufftReal *out = real_result.deviceMem();
     float alpha = 1.0 / (m_width * m_height);
 
-    if (n_channels == 1) {
-        CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
-    } else {
-        CufftErrorCheck(cufftExecC2R(plan_i_features, in, out));
-    }
+    // Exactly one plan must run: out-of-place C2R overwrites its input, so a second pass would read garbage.
+    if (n_channels == 1)
+        cudaErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
+#ifdef BIG_BATCH
+    else // all-scales plan covers m_num_of_scales planes; it must not run on 1-channel buffers
+        cudaErrorCheck(cufftExecC2R(plan_i_all_scales, in, out));
+#endif // NOTE(review): without BIG_BATCH, n_channels != 1 runs no transform and the scaling below acts on stale data
     // TODO: Investigate whether this scalling is needed or not
-    CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
+    cudaErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
+    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
 }
 
 cuFFT::~cuFFT()
 {
-    CublasErrorCheck(cublasDestroy(cublas));
+    cudaErrorCheck(cublasDestroy(cublas)); // handle created in the constructor
+    // Release the cuFFT plans created in init(). NOTE(review): assumes init() ran — confirm the handles are otherwise safe to destroy
+    cudaErrorCheck(cufftDestroy(plan_f));
+    cudaErrorCheck(cufftDestroy(plan_fw));
+    cudaErrorCheck(cufftDestroy(plan_i_1ch));
 
-    CufftErrorCheck(cufftDestroy(plan_f));
-    CufftErrorCheck(cufftDestroy(plan_fw));
-    CufftErrorCheck(cufftDestroy(plan_i_1ch));
-    CufftErrorCheck(cufftDestroy(plan_i_features));
+#ifdef BIG_BATCH // batched all-scales plans are only created in BIG_BATCH builds
+    cudaErrorCheck(cufftDestroy(plan_f_all_scales));
+    cudaErrorCheck(cufftDestroy(plan_fw_all_scales));
+    cudaErrorCheck(cufftDestroy(plan_i_all_scales));
+#endif
 }