// Stores the transform configuration and creates the cuFFT plans used by the
// forward()/forward_window()/inverse() methods.
//
// Parameters:
//   width, height   - transform size; the plans below are built from
//                     m_width/m_height (their assignment is not on the lines
//                     shown here - presumably m_width = width and
//                     m_height = height, TODO confirm).
//   num_of_feats    - batch count for the per-feature plans.
//   num_of_scales   - batch count for the "all scales" plans.
//   big_batch_mode  - together with num_of_scales > 1, enables creation of the
//                     *_all_scales plans.
//
// NOTE(review): `rank` (and `howmany` for plan_i_1ch) are used below but their
// declarations are not on the visible lines; several cufftPlanMany calls also
// end before their transform-type/batch arguments. Confirm against the full file.
//
// cufftPlanMany layout reminder: istride/ostride is the distance between two
// successive elements of one transform, idist/odist the distance between the
// first elements of two consecutive batch items (all counted in elements).
3 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales, bool big_batch_mode)
7 m_num_of_feats = num_of_feats;
8 m_num_of_scales = num_of_scales;
9 m_big_batch_mode = big_batch_mode;
// Announce on stdout which FFT backend is in use.
11 std::cout << "FFT: cuFFT" << std::endl;
13 //FFT forward one scale
// Single 2D real-to-complex plan of m_height rows by m_width columns.
15 CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
18 //FFT forward all scales
// Only created for multi-scale big-batch operation; any code that uses or
// destroys plan_f_all_scales has to respect the same condition.
19 if(m_num_of_scales > 1 && m_big_batch_mode)
22 int n[] = {(int)m_height, (int)m_width};
// One transform per scale.
23 int howmany = m_num_of_scales;
// Batch items are packed back to back: a full real image on input and
// m_height x (m_width/2+1) elements on output.
24 int idist = m_height*m_width, odist = m_height*(m_width/2+1);
25 int istride = 1, ostride = 1;
// Note: in `(int)m_width/2+1` the cast binds to m_width before the division.
26 int *inembed = n, onembed[] = {(int)m_height, (int)m_width/2+1};
// NOTE(review): the transform type and batch arguments of this call are not on
// the visible lines; forward() executes this plan with cufftExecR2C.
28 CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n,
29 inembed, istride, idist,
30 onembed, ostride, odist,
34 //FFT forward window one scale
37 int n[] = {int(m_height), int(m_width)};
// One transform per feature channel.
38 int howmany = int(m_num_of_feats);
39 int idist = int(m_height*m_width), odist = int(m_height*(m_width/2+1));
40 int istride = 1, ostride = 1;
41 int *inembed = n, onembed[] = {int(m_height), int(m_width/2+1)};
// NOTE(review): call tail not visible; forward_window() executes this plan
// with cufftExecR2C.
43 CufftErrorCheck(cufftPlanMany(&plan_fw, rank, n,
44 inembed, istride, idist,
45 onembed, ostride, odist,
49 //FFT forward window all scales all feats
50 if(m_num_of_scales > 1 && m_big_batch_mode)
53 int n[] = {(int)m_height, (int)m_width};
// One transform per (scale, feature) pair.
54 int howmany = m_num_of_scales*m_num_of_feats;
55 int idist = m_height*m_width, odist = m_height*(m_width/2+1);
56 int istride = 1, ostride = 1;
57 int *inembed = n, onembed[] = {(int)m_height, (int)m_width/2+1};
// NOTE(review): call tail not visible; forward_window() executes this plan
// with cufftExecR2C.
59 CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n,
60 inembed, istride, idist,
61 onembed, ostride, odist,
65 //FFT inverse one scale
68 int n[] = {int(m_height), int(m_width)};
69 int howmany = int(m_num_of_feats);
// Input: one m_height x (m_width/2+1) block per feature, packed back to back.
// Output: odist = 1 with ostride = m_num_of_feats, so element k of transform b
// lands at index b + k*m_num_of_feats, i.e. the per-feature results are
// interleaved element by element in the output buffer.
70 int idist = int(m_height*(m_width/2+1)), odist = 1;
71 int istride = 1, ostride = int(m_num_of_feats);
72 int inembed[] = {int(m_height), int(m_width/2+1)}, *onembed = n;
// NOTE(review): call tail not visible; inverse() executes this plan with
// cufftExecC2R.
74 CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n,
75 inembed, istride, idist,
76 onembed, ostride, odist,
79 //FFT inverse all scales
81 if(m_num_of_scales > 1 && m_big_batch_mode)
84 int n[] = {(int)m_height, (int)m_width};
85 int howmany = m_num_of_feats*m_num_of_scales;
// Same interleaved output layout as plan_i_features, with the interleave
// factor extended to feats*scales.
86 int idist = m_height*(m_width/2+1), odist = 1;
87 int istride = 1, ostride = m_num_of_feats*m_num_of_scales;
88 int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
// NOTE(review): call tail not visible; inverse() executes this plan with
// cufftExecC2R.
90 CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n,
91 inembed, istride, idist,
92 onembed, ostride, odist,
96 //FFT inverse one channel one scale
99 int n[] = {int(m_height), int(m_width)};
// NOTE(review): `howmany` for this plan is declared on a line not shown here.
// With ostride = 1 the output of a transform is contiguous.
101 int idist = int(m_height*(m_width/2+1)), odist = 1;
102 int istride = 1, ostride = 1;
103 int inembed[] = {int(m_height), int(m_width/2+1)}, *onembed = n;
105 CufftErrorCheck(cufftPlanMany(&plan_i_1ch, rank, n,
106 inembed, istride, idist,
107 onembed, ostride, odist,
108 CUFFT_C2R, howmany));
111 //FFT inverse one channel all scales
112 if(m_num_of_scales > 1 && m_big_batch_mode)
115 int n[] = {(int)m_height, (int)m_width};
// One complex-to-real transform per scale; results interleaved per scale
// (ostride = m_num_of_scales, odist = 1).
116 int howmany = m_num_of_scales;
117 int idist = m_height*(m_width/2+1), odist = 1;
118 int istride = 1, ostride = m_num_of_scales;
119 int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
121 CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n,
122 inembed, istride, idist,
123 onembed, ostride, odist,
124 CUFFT_C2R, howmany));
// Receives the window matrix. The body is not part of the visible lines;
// presumably it stores `window` in m_window, which forward_window() multiplies
// into every feature channel - TODO confirm against the full file.
129 void cuFFT::set_window(const cv::Mat & window)
// Runs a real-to-complex forward FFT from real_input_arr into
// complex_result.get_p_data().
//
// Parameters:
//   real_input      - only its row count is read on the visible lines, to
//                     decide which plan applies.
//   complex_result  - destination; its buffer is obtained via get_p_data().
//   real_input_arr  - source buffer handed to cuFFT (presumably the data
//                     behind real_input - TODO confirm with the caller).
//   stream          - stream attached to plan_f before execution.
134 void cuFFT::forward(const cv::Mat & real_input, ComplexMat & complex_result, float *real_input_arr, cudaStream_t stream)
// Batched path: the input stacks one m_height-row image per scale.
// NOTE(review): no cufftSetStream call for plan_f_all_scales is visible, so
// `stream` does not appear to be applied on this path - confirm.
138 if(m_big_batch_mode && real_input.rows == int(m_height*m_num_of_scales)){
139 CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal*>(real_input_arr),
140 complex_result.get_p_data()));
// Single-image path (the branch separator is on a line not shown here):
// bind the plan to the caller's stream, then execute.
142 CufftErrorCheck(cufftSetStream(plan_f, stream));
143 CufftErrorCheck(cufftExecR2C(plan_f, reinterpret_cast<cufftReal*>(real_input_arr),
144 complex_result.get_p_data()));
// Multiplies every feature channel by m_window, stacks the products vertically
// in fw_all, and runs a batched real-to-complex FFT into
// complex_result.get_p_data().
//
// Parameters:
//   patch_feats     - feature channels; taken by value, so the vector of
//                     cv::Mat headers is copied on every call.
//   complex_result  - destination; its buffer is obtained via get_p_data().
//   fw_all          - staging matrix; channel i occupies rows
//                     [i*m_height, (i+1)*m_height) and m_width columns.
//   real_input_arr  - buffer handed to cuFFT (presumably the memory behind
//                     fw_all - TODO confirm with the caller).
//   stream          - stream attached to plan_fw before execution.
149 void cuFFT::forward_window(std::vector<cv::Mat> patch_feats, ComplexMat & complex_result, cv::Mat & fw_all, float *real_input_arr, cudaStream_t stream)
151 int n_channels = int(patch_feats.size());
// More channels than one scale provides: use the plan built for all scales
// and all features.
// NOTE(review): that plan only exists when init() saw num_of_scales > 1 and
// big_batch_mode; this test does not check either - confirm callers.
153 if(n_channels > int(m_num_of_feats)){
154 for (uint i = 0; i < uint(n_channels); ++i) {
// View onto the i-th m_height-row slab of fw_all.
155 cv::Mat in_roi(fw_all, cv::Rect(0, int(i*m_height), int(m_width), int(m_height)));
// NOTE(review): writing through the ROI relies on the product having the
// same size and type as in_roi; otherwise the assignment would reallocate
// in_roi and leave fw_all untouched - confirm.
156 in_roi = patch_feats[i].mul(m_window);
// NOTE(review): no cufftSetStream call for plan_fw_all_scales is visible.
158 CufftErrorCheck(cufftExecR2C(plan_fw_all_scales, reinterpret_cast<cufftReal*>(real_input_arr), complex_result.get_p_data()));
// Per-scale path (the branch separator is on a line not shown here): same
// windowing, then the per-feature plan on the caller's stream.
160 for (uint i = 0; i < uint(n_channels); ++i) {
161 cv::Mat in_roi(fw_all, cv::Rect(0, int(i*m_height), int(m_width), int(m_height)));
162 in_roi = patch_feats[i].mul(m_window);
164 CufftErrorCheck(cufftSetStream(plan_fw, stream));
165 CufftErrorCheck(cufftExecR2C(plan_fw, reinterpret_cast<cufftReal*>(real_input_arr), complex_result.get_p_data()));
// Runs a complex-to-real inverse FFT from complex_input into real_result_arr,
// choosing the plan from complex_input.n_channels.
//
// Parameters:
//   complex_input   - source; n_channels selects the plan, get_p_data()
//                     supplies the buffer.
//   real_result     - divided by m_width*m_height on the one-channel paths
//                     (presumably a cv::Mat over the memory behind
//                     real_result_arr - TODO confirm with the caller).
//   real_result_arr - destination buffer handed to cuFFT.
//   stream          - stream attached to plan_i_1ch / plan_i_features and
//                     synchronised on the one-channel paths.
//
// NOTE(review): the header of the first branch and the lines that end each
// branch are not visible here; the `} else if` chain below implies a leading
// `if` - confirm against the full file.
170 void cuFFT::inverse(ComplexMat & complex_input, cv::Mat & real_result, float *real_result_arr, cudaStream_t stream)
172 int n_channels = complex_input.n_channels;
173 cufftComplex *in = reinterpret_cast<cufftComplex*>(complex_input.get_p_data());
// One channel, one scale: run on the caller's stream, wait for completion,
// then scale on the host side. cuFFT transforms are unnormalised, hence the
// division by the number of elements.
// NOTE(review): the return value of cudaStreamSynchronize is not checked.
176 CufftErrorCheck(cufftSetStream(plan_i_1ch, stream));
177 CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, reinterpret_cast<cufftReal*>(real_result_arr)));
178 cudaStreamSynchronize(stream);
179 real_result = real_result/(m_width*m_height);
// One channel per scale.
// NOTE(review): no cufftSetStream call for plan_i_1ch_all_scales is visible,
// so the stream synchronised below may not be the one the transform was
// issued on - confirm.
181 } else if(n_channels == int(m_num_of_scales)){
182 CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, reinterpret_cast<cufftReal*>(real_result_arr)));
183 cudaStreamSynchronize(stream);
185 real_result = real_result/(m_width*m_height);
// All features of all scales; no synchronisation or scaling is visible on
// this path.
187 } else if(n_channels == int(m_num_of_feats) * int(m_num_of_scales)){
188 CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, reinterpret_cast<cufftReal*>(real_result_arr)));
// Remaining case (its branch header is not visible): per-feature plan on the
// caller's stream; no synchronisation or scaling is visible on this path.
191 CufftErrorCheck(cufftSetStream(plan_i_features, stream));
192 CufftErrorCheck(cufftExecC2R(plan_i_features, in, reinterpret_cast<cufftReal*>(real_result_arr)));
198 CufftErrorCheck(cufftDestroy(plan_f));
199 CufftErrorCheck(cufftDestroy(plan_fw));
200 CufftErrorCheck(cufftDestroy(plan_i_1ch));
201 CufftErrorCheck(cufftDestroy(plan_i_features));
203 if (m_big_batch_mode) {
204 CufftErrorCheck(cufftDestroy(plan_f_all_scales));
205 CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
206 CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
207 CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));