rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
4b45b99c72b76f1b0443058d8be0ddff738dc086
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2
3 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales, bool big_batch_mode)
4 {
5     m_width = width;
6     m_height = height;
7     m_num_of_feats = num_of_feats;
8     m_num_of_scales = num_of_scales;
9     m_big_batch_mode = big_batch_mode;
10
11     std::cout << "FFT: cuFFT" << std::endl;
12
13     // FFT forward one scale
14     {
15         CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
16     }
17 #ifdef BIG_BATCH
18     // FFT forward all scales
19     if (m_num_of_scales > 1 && m_big_batch_mode) {
20         int rank = 2;
21         int n[] = {(int)m_height, (int)m_width};
22         int howmany = m_num_of_scales;
23         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
24         int istride = 1, ostride = 1;
25         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
26
27         CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
28                                       CUFFT_R2C, howmany));
29     }
30 #endif
31     // FFT forward window one scale
32     {
33         int rank = 2;
34         int n[] = {int(m_height), int(m_width)};
35         int howmany = int(m_num_of_feats);
36         int idist = int(m_height * m_width), odist = int(m_height * (m_width / 2 + 1));
37         int istride = 1, ostride = 1;
38         int *inembed = n, onembed[] = {int(m_height), int(m_width / 2 + 1)};
39
40         CufftErrorCheck(
41             cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
42     }
43 #ifdef BIG_BATCH
44     // FFT forward window all scales all feats
45     if (m_num_of_scales > 1 && m_big_batch_mode) {
46         int rank = 2;
47         int n[] = {(int)m_height, (int)m_width};
48         int howmany = m_num_of_scales * m_num_of_feats;
49         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
50         int istride = 1, ostride = 1;
51         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
52
53         CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
54                                       CUFFT_R2C, howmany));
55     }
56 #endif
57     // FFT inverse one scale
58     {
59         int rank = 2;
60         int n[] = {int(m_height), int(m_width)};
61         int howmany = int(m_num_of_feats);
62         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
63         int istride = 1, ostride = int(m_num_of_feats);
64         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
65
66         CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist,
67                                       CUFFT_C2R, howmany));
68     }
69     // FFT inverse all scales
70 #ifdef BIG_BATCH
71     if (m_num_of_scales > 1 && m_big_batch_mode) {
72         int rank = 2;
73         int n[] = {(int)m_height, (int)m_width};
74         int howmany = m_num_of_feats * m_num_of_scales;
75         int idist = m_height * (m_width / 2 + 1), odist = 1;
76         int istride = 1, ostride = m_num_of_feats * m_num_of_scales;
77         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
78
79         CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n, inembed, istride, idist, onembed, ostride,
80                                       odist, CUFFT_C2R, howmany));
81     }
82 #endif
83     // FFT inverse one channel one scale
84     {
85         int rank = 2;
86         int n[] = {int(m_height), int(m_width)};
87         int howmany = 1;
88         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
89         int istride = 1, ostride = 1;
90         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
91
92         CufftErrorCheck(
93             cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
94     }
95 #ifdef BIG_BATCH
96     // FFT inverse one channel all scales
97     if (m_num_of_scales > 1 && m_big_batch_mode) {
98         int rank = 2;
99         int n[] = {(int)m_height, (int)m_width};
100         int howmany = m_num_of_scales;
101         int idist = m_height * (m_width / 2 + 1), odist = 1;
102         int istride = 1, ostride = m_num_of_scales;
103         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
104
105         CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
106                                       CUFFT_C2R, howmany));
107     }
108 #endif
109 }
110
111 void cuFFT::set_window(const cv::Mat &window)
112 {
113     m_window = window;
114 }
115
116 void cuFFT::forward(const cv::Mat &real_input, ComplexMat &complex_result, float *real_input_arr, cudaStream_t stream)
117 {
118     (void)real_input;
119
120     if (m_big_batch_mode && real_input.rows == int(m_height * m_num_of_scales)) {
121         CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal *>(real_input_arr),
122                                      complex_result.get_p_data()));
123     } else {
124 #pragma omp critical
125         {
126             CufftErrorCheck(cufftSetStream(plan_f, stream));
127             CufftErrorCheck(
128                 cufftExecR2C(plan_f, reinterpret_cast<cufftReal *>(real_input_arr), complex_result.get_p_data()));
129             cudaStreamSynchronize(stream);
130         }
131     }
132     return;
133 }
134
135 void cuFFT::forward_window(std::vector<cv::Mat> patch_feats, ComplexMat &complex_result, cv::Mat &fw_all,
136                            float *real_input_arr, cudaStream_t stream)
137 {
138     int n_channels = int(patch_feats.size());
139
140     if (n_channels > int(m_num_of_feats)) {
141         for (uint i = 0; i < uint(n_channels); ++i) {
142             cv::Mat in_roi(fw_all, cv::Rect(0, int(i * m_height), int(m_width), int(m_height)));
143             in_roi = patch_feats[i].mul(m_window);
144         }
145         CufftErrorCheck(cufftExecR2C(plan_fw_all_scales, reinterpret_cast<cufftReal *>(real_input_arr),
146                                      complex_result.get_p_data()));
147     } else {
148         for (uint i = 0; i < uint(n_channels); ++i) {
149             cv::Mat in_roi(fw_all, cv::Rect(0, int(i * m_height), int(m_width), int(m_height)));
150             in_roi = patch_feats[i].mul(m_window);
151         }
152 #pragma omp critical
153         {
154             CufftErrorCheck(cufftSetStream(plan_fw, stream));
155             CufftErrorCheck(
156                 cufftExecR2C(plan_fw, reinterpret_cast<cufftReal *>(real_input_arr), complex_result.get_p_data()));
157             cudaStreamSynchronize(stream);
158         }
159     }
160     return;
161 }
162
163 void cuFFT::inverse(ComplexMat &complex_input, cv::Mat &real_result, float *real_result_arr, cudaStream_t stream)
164 {
165     int n_channels = complex_input.n_channels;
166     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
167
168     if (n_channels == 1) {
169 #pragma omp critical
170         {
171             CufftErrorCheck(cufftSetStream(plan_i_1ch, stream));
172             CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, reinterpret_cast<cufftReal *>(real_result_arr)));
173             cudaStreamSynchronize(stream);
174         }
175         real_result = real_result / (m_width * m_height);
176         return;
177     } else if (n_channels == int(m_num_of_scales)) {
178         CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, reinterpret_cast<cufftReal *>(real_result_arr)));
179         cudaStreamSynchronize(stream);
180
181         real_result = real_result / (m_width * m_height);
182         return;
183     } else if (n_channels == int(m_num_of_feats) * int(m_num_of_scales)) {
184         CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, reinterpret_cast<cufftReal *>(real_result_arr)));
185         return;
186     }
187 #pragma omp critical
188     {
189         CufftErrorCheck(cufftSetStream(plan_i_features, stream));
190         CufftErrorCheck(cufftExecC2R(plan_i_features, in, reinterpret_cast<cufftReal *>(real_result_arr)));
191 #if defined(OPENMP) && !defined(BIG_BATCH)
192         cudaStreamSynchronize(stream);
193 #endif
194     }
195     return;
196 }
197
198 cuFFT::~cuFFT()
199 {
200     CufftErrorCheck(cufftDestroy(plan_f));
201     CufftErrorCheck(cufftDestroy(plan_fw));
202     CufftErrorCheck(cufftDestroy(plan_i_1ch));
203     CufftErrorCheck(cufftDestroy(plan_i_features));
204
205     if (m_big_batch_mode) {
206         CufftErrorCheck(cufftDestroy(plan_f_all_scales));
207         CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
208         CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
209         CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));
210     }
211 }