]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
Add cudaStreamSynchronize after FFT
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2
3 cuFFT::cuFFT()
4 {
5     CudaSafeCall(cudaSetDeviceFlags(cudaDeviceMapHost));
6     cudaErrorCheck(cublasCreate(&cublas));
7 }
8
9 cufftHandle cuFFT::create_plan_fwd(uint howmany) const
10 {
11     int rank = 2;
12     int n[] = {(int)m_height, (int)m_width};
13     int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
14     int istride = 1, ostride = 1;
15     int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
16
17     cufftHandle plan;
18     cudaErrorCheck(cufftPlanMany(&plan, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
19     cudaErrorCheck(cufftSetStream(plan, cudaStreamPerThread));
20     return plan;
21 }
22
23 cufftHandle cuFFT::create_plan_inv(uint howmany) const
24 {
25     int rank = 2;
26     int n[] = {(int)m_height, (int)m_width};
27     int idist = m_height * (m_width / 2 + 1), odist = m_height * m_width;
28     int istride = 1, ostride = 1;
29     int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
30
31     cufftHandle plan;
32     cudaErrorCheck(cufftPlanMany(&plan, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
33     cudaErrorCheck(cufftSetStream(plan, cudaStreamPerThread));
34     return plan;
35 }
36
37
38 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
39 {
40     Fft::init(width, height, num_of_feats, num_of_scales);
41
42     std::cout << "FFT: cuFFT" << std::endl;
43
44     plan_f = create_plan_fwd(1);
45     plan_fw = create_plan_fwd(m_num_of_feats);
46     plan_i_1ch = create_plan_inv(1);
47
48 #ifdef BIG_BATCH
49     plan_f_all_scales = create_plan_fwd(m_num_of_scales);
50     plan_fw_all_scales = create_plan_fwd(m_num_of_scales * m_num_of_feats);
51     plan_i_all_scales = create_plan_inv(m_num_of_scales);
52 #endif
53 }
54
55 void cuFFT::set_window(const MatDynMem &window)
56 {
57     Fft::set_window(window);
58     m_window = window;
59 }
60
61 void cuFFT::forward(const MatScales &real_input, ComplexMat &complex_result)
62 {
63     Fft::forward(real_input, complex_result);
64     auto in = static_cast<cufftReal *>(const_cast<MatScales&>(real_input).deviceMem());
65
66     if (real_input.size[0] == 1)
67         cudaErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_dev_data()));
68 #ifdef BIG_BATCH
69     else
70         cudaErrorCheck(cufftExecR2C(plan_f_all_scales, in, complex_result.get_dev_data()));
71 #endif
72 }
73
74 void cuFFT::forward_window(MatScaleFeats &feat, ComplexMat &complex_result, MatScaleFeats &temp)
75 {
76     Fft::forward_window(feat, complex_result, temp);
77
78     cufftReal *temp_data = temp.deviceMem();
79     uint n_scales = feat.size[0];
80
81     for (uint s = 0; s < n_scales; ++s) {
82         for (uint ch = 0; ch < uint(feat.size[1]); ++ch) {
83             cv::Mat feat_plane = feat.plane(s, ch);
84             cv::Mat temp_plane = temp.plane(s, ch);
85             temp_plane = feat_plane.mul(m_window);
86         }
87     }
88
89     if (n_scales == 1)
90         cudaErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_dev_data()));
91 #ifdef BIG_BATCH
92     else
93         cudaErrorCheck(cufftExecR2C(plan_fw_all_scales, temp_data, complex_result.get_dev_data()));
94 #endif
95     CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
96 }
97
98 void cuFFT::inverse(ComplexMat &complex_input, MatScales &real_result)
99 {
100     Fft::inverse(complex_input, real_result);
101
102     uint n_channels = complex_input.n_channels;
103     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
104     cufftReal *out = real_result.deviceMem();
105     float alpha = 1.0 / (m_width * m_height);
106
107     if (n_channels == 1)
108         cudaErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
109 #ifdef BIG_BATCH
110         cudaErrorCheck(cufftExecC2R(plan_i_all_scales, in, out));
111 #endif
112     // TODO: Investigate whether this scalling is needed or not
113     cudaErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
114     CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
115 }
116
117 cuFFT::~cuFFT()
118 {
119     cudaErrorCheck(cublasDestroy(cublas));
120
121     cudaErrorCheck(cufftDestroy(plan_f));
122     cudaErrorCheck(cufftDestroy(plan_fw));
123     cudaErrorCheck(cufftDestroy(plan_i_1ch));
124
125 #ifdef BIG_BATCH
126     cudaErrorCheck(cufftDestroy(plan_f_all_scales));
127     cudaErrorCheck(cufftDestroy(plan_fw_all_scales));
128     cudaErrorCheck(cufftDestroy(plan_i_all_scales));
129 #endif
130 }