]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
Simplify cuFFT class
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2
3 cuFFT::cuFFT()
4 {
5     cudaErrorCheck(cublasCreate(&cublas));
6 }
7
8 cufftHandle cuFFT::create_plan_fwd(uint howmany) const
9 {
10     int rank = 2;
11     int n[] = {(int)m_height, (int)m_width};
12     int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
13     int istride = 1, ostride = 1;
14     int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
15
16     cufftHandle plan;
17     cudaErrorCheck(cufftPlanMany(&plan, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
18     cudaErrorCheck(cufftSetStream(plan, cudaStreamPerThread));
19     return plan;
20 }
21
22 cufftHandle cuFFT::create_plan_inv(uint howmany) const
23 {
24     int rank = 2;
25     int n[] = {(int)m_height, (int)m_width};
26     int idist = m_height * (m_width / 2 + 1), odist = m_height * m_width;
27     int istride = 1, ostride = 1;
28     int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
29
30     cufftHandle plan;
31     cudaErrorCheck(cufftPlanMany(&plan, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
32     cudaErrorCheck(cufftSetStream(plan, cudaStreamPerThread));
33     return plan;
34 }
35
36
37 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
38 {
39     Fft::init(width, height, num_of_feats, num_of_scales);
40
41     std::cout << "FFT: cuFFT" << std::endl;
42
43     plan_f = create_plan_fwd(1);
44     plan_fw = create_plan_fwd(m_num_of_feats);
45     plan_i_1ch = create_plan_inv(1);
46
47 #ifdef BIG_BATCH
48     plan_f_all_scales = create_plan_fwd(m_num_of_scales);
49     plan_fw_all_scales = create_plan_fwd(m_num_of_scales * m_num_of_feats);
50     plan_i_all_scales = create_plan_inv(m_num_of_scales);
51 #endif
52 }
53
54 void cuFFT::set_window(const MatDynMem &window)
55 {
56     Fft::set_window(window);
57     m_window = window;
58 }
59
60 void cuFFT::forward(const MatScales &real_input, ComplexMat &complex_result)
61 {
62     Fft::forward(real_input, complex_result);
63     auto in = static_cast<cufftReal *>(const_cast<MatScales&>(real_input).deviceMem());
64
65     if (real_input.size[0] == 1)
66         cudaErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_p_data()));
67 #ifdef BIG_BATCH
68     else
69         cudaErrorCheck(cufftExecR2C(plan_f_all_scales, in, complex_result.get_p_data()));
70 #endif
71 }
72
73 void cuFFT::forward_window(MatScaleFeats &feat, ComplexMat &complex_result, MatScaleFeats &temp)
74 {
75     Fft::forward_window(feat, complex_result, temp);
76
77     cufftReal *temp_data = temp.deviceMem();
78     uint n_scales = feat.size[0];
79
80     for (uint s = 0; s < n_scales; ++s) {
81         for (uint ch = 0; ch < uint(feat.size[1]); ++ch) {
82             cv::Mat feat_plane = feat.plane(s, ch);
83             cv::Mat temp_plane = temp.plane(s, ch);
84             temp_plane = feat_plane.mul(m_window);
85         }
86     }
87
88     if (n_scales == 1)
89         cudaErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_p_data()));
90 #ifdef BIG_BATCH
91     else
92         cudaErrorCheck(cufftExecR2C(plan_fw_all_scales, temp_data, complex_result.get_p_data()));
93 #endif
94 }
95
96 void cuFFT::inverse(ComplexMat &complex_input, MatScales &real_result)
97 {
98     Fft::inverse(complex_input, real_result);
99
100     uint n_channels = complex_input.n_channels;
101     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
102     cufftReal *out = real_result.deviceMem();
103     float alpha = 1.0 / (m_width * m_height);
104
105     if (n_channels == 1)
106         cudaErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
107 #ifdef BIG_BATCH
108         cudaErrorCheck(cufftExecC2R(plan_i_all_scales, in, out));
109 #endif
110     // TODO: Investigate whether this scalling is needed or not
111     cudaErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
112 }
113
114 cuFFT::~cuFFT()
115 {
116     cudaErrorCheck(cublasDestroy(cublas));
117
118     cudaErrorCheck(cufftDestroy(plan_f));
119     cudaErrorCheck(cufftDestroy(plan_fw));
120     cudaErrorCheck(cufftDestroy(plan_i_1ch));
121
122 #ifdef BIG_BATCH
123     cudaErrorCheck(cufftDestroy(plan_f_all_scales));
124     cudaErrorCheck(cufftDestroy(plan_fw_all_scales));
125     cudaErrorCheck(cufftDestroy(plan_i_all_scales));
126 #endif
127 }