]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
Fix indentation
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2
3 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
4 {
5     m_width = width;
6     m_height = height;
7     m_num_of_feats = num_of_feats;
8     m_num_of_scales = num_of_scales;
9
10     std::cout << "FFT: cuFFT" << std::endl;
11
12     // FFT forward one scale
13     {
14         CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
15     }
16 #ifdef BIG_BATCH
17     // FFT forward all scales
18     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
19         int rank = 2;
20         int n[] = {(int)m_height, (int)m_width};
21         int howmany = m_num_of_scales;
22         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
23         int istride = 1, ostride = 1;
24         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
25
26         CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
27                                       CUFFT_R2C, howmany));
28     }
29 #endif
30     // FFT forward window one scale
31     {
32         int rank = 2;
33         int n[] = {int(m_height), int(m_width)};
34         int howmany = int(m_num_of_feats);
35         int idist = int(m_height * m_width), odist = int(m_height * (m_width / 2 + 1));
36         int istride = 1, ostride = 1;
37         int *inembed = n, onembed[] = {int(m_height), int(m_width / 2 + 1)};
38
39         CufftErrorCheck(
40             cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
41     }
42 #ifdef BIG_BATCH
43     // FFT forward window all scales all feats
44     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
45         int rank = 2;
46         int n[] = {(int)m_height, (int)m_width};
47         int howmany = m_num_of_scales * m_num_of_feats;
48         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
49         int istride = 1, ostride = 1;
50         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
51
52         CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
53                                       CUFFT_R2C, howmany));
54     }
55 #endif
56     // FFT inverse one scale
57     {
58         int rank = 2;
59         int n[] = {int(m_height), int(m_width)};
60         int howmany = int(m_num_of_feats);
61         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
62         int istride = 1, ostride = int(m_num_of_feats);
63         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
64
65         CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist,
66                                       CUFFT_C2R, howmany));
67     }
68     // FFT inverse all scales
69 #ifdef BIG_BATCH
70     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
71         int rank = 2;
72         int n[] = {(int)m_height, (int)m_width};
73         int howmany = m_num_of_feats * m_num_of_scales;
74         int idist = m_height * (m_width / 2 + 1), odist = 1;
75         int istride = 1, ostride = m_num_of_feats * m_num_of_scales;
76         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
77
78         CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n, inembed, istride, idist, onembed, ostride,
79                                       odist, CUFFT_C2R, howmany));
80     }
81 #endif
82     // FFT inverse one channel one scale
83     {
84         int rank = 2;
85         int n[] = {int(m_height), int(m_width)};
86         int howmany = 1;
87         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
88         int istride = 1, ostride = 1;
89         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
90
91         CufftErrorCheck(
92             cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
93     }
94 #ifdef BIG_BATCH
95     // FFT inverse one channel all scales
96     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
97         int rank = 2;
98         int n[] = {(int)m_height, (int)m_width};
99         int howmany = m_num_of_scales;
100         int idist = m_height * (m_width / 2 + 1), odist = 1;
101         int istride = 1, ostride = m_num_of_scales;
102         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
103
104         CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
105                                       CUFFT_C2R, howmany));
106     }
107 #endif
108 }
109
110 void cuFFT::set_window(const cv::Mat &window)
111 {
112     m_window = window;
113 }
114
115 void cuFFT::forward(const cv::Mat &real_input, ComplexMat &complex_result, float *real_input_arr, cudaStream_t stream)
116 {
117     if (BIG_BATCH_MODE && real_input.rows == int(m_height * m_num_of_scales)) {
118         CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal *>(real_input_arr),
119                                      complex_result.get_p_data()));
120     } else {
121         NORMAL_OMP_CRITICAL
122         {
123             CufftErrorCheck(cufftSetStream(plan_f, stream));
124             CufftErrorCheck(
125                 cufftExecR2C(plan_f, reinterpret_cast<cufftReal *>(real_input_arr), complex_result.get_p_data()));
126             cudaStreamSynchronize(stream);
127         }
128     }
129     return;
130 }
131
132 void cuFFT::forward_window(std::vector<cv::Mat> patch_feats, ComplexMat &complex_result, cv::Mat &fw_all,
133                            float *real_input_arr, cudaStream_t stream)
134 {
135     int n_channels = int(patch_feats.size());
136
137     if (n_channels > int(m_num_of_feats)) {
138         for (uint i = 0; i < uint(n_channels); ++i) {
139             cv::Mat in_roi(fw_all, cv::Rect(0, int(i * m_height), int(m_width), int(m_height)));
140             in_roi = patch_feats[i].mul(m_window);
141         }
142         CufftErrorCheck(cufftExecR2C(plan_fw_all_scales, reinterpret_cast<cufftReal *>(real_input_arr),
143                                      complex_result.get_p_data()));
144     } else {
145         for (uint i = 0; i < uint(n_channels); ++i) {
146             cv::Mat in_roi(fw_all, cv::Rect(0, int(i * m_height), int(m_width), int(m_height)));
147             in_roi = patch_feats[i].mul(m_window);
148         }
149         NORMAL_OMP_CRITICAL
150         {
151             CufftErrorCheck(cufftSetStream(plan_fw, stream));
152             CufftErrorCheck(
153                 cufftExecR2C(plan_fw, reinterpret_cast<cufftReal *>(real_input_arr), complex_result.get_p_data()));
154             cudaStreamSynchronize(stream);
155         }
156     }
157     return;
158 }
159
160 void cuFFT::inverse(ComplexMat &complex_input, cv::Mat &real_result, float *real_result_arr, cudaStream_t stream)
161 {
162     int n_channels = complex_input.n_channels;
163     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
164
165     if (n_channels == 1) {
166         NORMAL_OMP_CRITICAL
167         {
168             CufftErrorCheck(cufftSetStream(plan_i_1ch, stream));
169             CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, reinterpret_cast<cufftReal *>(real_result_arr)));
170             cudaStreamSynchronize(stream);
171         }
172         real_result = real_result / (m_width * m_height);
173         return;
174     } else if (n_channels == int(m_num_of_scales)) {
175         CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, reinterpret_cast<cufftReal *>(real_result_arr)));
176         cudaStreamSynchronize(stream);
177
178         real_result = real_result / (m_width * m_height);
179         return;
180     } else if (n_channels == int(m_num_of_feats) * int(m_num_of_scales)) {
181         CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, reinterpret_cast<cufftReal *>(real_result_arr)));
182         return;
183     }
184     NORMAL_OMP_CRITICAL
185     {
186         CufftErrorCheck(cufftSetStream(plan_i_features, stream));
187         CufftErrorCheck(cufftExecC2R(plan_i_features, in, reinterpret_cast<cufftReal *>(real_result_arr)));
188 #if defined(OPENMP) && !defined(BIG_BATCH)
189         CudaSafeCall(cudaStreamSynchronize(stream));
190 #endif
191     }
192     return;
193 }
194
195 cuFFT::~cuFFT()
196 {
197     CufftErrorCheck(cufftDestroy(plan_f));
198     CufftErrorCheck(cufftDestroy(plan_fw));
199     CufftErrorCheck(cufftDestroy(plan_i_1ch));
200     CufftErrorCheck(cufftDestroy(plan_i_features));
201
202     if (BIG_BATCH_MODE) {
203         CufftErrorCheck(cufftDestroy(plan_f_all_scales));
204         CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
205         CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
206         CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));
207     }
208 }