rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
63f2558621c58857f3bfd1532686f7ca0d0c98b3
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2 #include <cublas_v2.h>
3
4 cuFFT::cuFFT()
5 {
6     CublasErrorCheck(cublasCreate(&cublas));
7 }
8
9 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
10 {
11     m_width = width;
12     m_height = height;
13     m_num_of_feats = num_of_feats;
14     m_num_of_scales = num_of_scales;
15
16     std::cout << "FFT: cuFFT" << std::endl;
17
18     // FFT forward one scale
19     {
20         CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
21     }
22 #ifdef BIG_BATCH
23     // FFT forward all scales
24     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
25         int rank = 2;
26         int n[] = {(int)m_height, (int)m_width};
27         int howmany = m_num_of_scales;
28         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
29         int istride = 1, ostride = 1;
30         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
31
32         CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
33                                       CUFFT_R2C, howmany));
34         CufftErrorCheck(cufftSetStream(plan_f_all_scales, cudaStreamPerThread));
35     }
36 #endif
37     // FFT forward window one scale
38     {
39         int rank = 2;
40         int n[] = {int(m_height), int(m_width)};
41         int howmany = int(m_num_of_feats);
42         int idist = int(m_height * m_width), odist = int(m_height * (m_width / 2 + 1));
43         int istride = 1, ostride = 1;
44         int *inembed = n, onembed[] = {int(m_height), int(m_width / 2 + 1)};
45
46         CufftErrorCheck(
47             cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
48         CufftErrorCheck(cufftSetStream(plan_fw, cudaStreamPerThread));
49     }
50 #ifdef BIG_BATCH
51     // FFT forward window all scales all feats
52     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
53         int rank = 2;
54         int n[] = {(int)m_height, (int)m_width};
55         int howmany = m_num_of_scales * m_num_of_feats;
56         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
57         int istride = 1, ostride = 1;
58         int *inembed = n, onembed[] = {(int)m_height, (int)m_width / 2 + 1};
59
60         CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
61                                       CUFFT_R2C, howmany));
62         CufftErrorCheck(cufftSetStream(plan_fw_all_scales, cudaStreamPerThread));
63     }
64 #endif
65     // FFT inverse one scale
66     {
67         int rank = 2;
68         int n[] = {int(m_height), int(m_width)};
69         int howmany = int(m_num_of_feats);
70         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
71         int istride = 1, ostride = int(m_num_of_feats);
72         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
73
74         CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist,
75                                       CUFFT_C2R, howmany));
76         CufftErrorCheck(cufftSetStream(plan_i_features, cudaStreamPerThread));
77     }
78     // FFT inverse all scales
79 #ifdef BIG_BATCH
80     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
81         int rank = 2;
82         int n[] = {(int)m_height, (int)m_width};
83         int howmany = m_num_of_feats * m_num_of_scales;
84         int idist = m_height * (m_width / 2 + 1), odist = 1;
85         int istride = 1, ostride = m_num_of_feats * m_num_of_scales;
86         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
87
88         CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n, inembed, istride, idist, onembed, ostride,
89                                       odist, CUFFT_C2R, howmany));
90         CufftErrorCheck(cufftSetStream(plan_i_features_all_scales, cudaStreamPerThread));
91     }
92 #endif
93     // FFT inverse one channel one scale
94     {
95         int rank = 2;
96         int n[] = {int(m_height), int(m_width)};
97         int howmany = 1;
98         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
99         int istride = 1, ostride = 1;
100         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
101
102         CufftErrorCheck(
103             cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
104         CufftErrorCheck(cufftSetStream(plan_i_1ch, cudaStreamPerThread));
105     }
106 #ifdef BIG_BATCH
107     // FFT inverse one channel all scales
108     if (m_num_of_scales > 1 && BIG_BATCH_MODE) {
109         int rank = 2;
110         int n[] = {(int)m_height, (int)m_width};
111         int howmany = m_num_of_scales;
112         int idist = m_height * (m_width / 2 + 1), odist = 1;
113         int istride = 1, ostride = m_num_of_scales;
114         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
115
116         CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n, inembed, istride, idist, onembed, ostride, odist,
117                                       CUFFT_C2R, howmany));
118         CufftErrorCheck(cufftSetStream(plan_i_1ch_all_scales, cudaStreamPerThread));
119     }
120 #endif
121 }
122
123 void cuFFT::set_window(const MatDynMem &window)
124 {
125     m_window = window;
126 }
127
128 void cuFFT::forward(MatDynMem & real_input, ComplexMat & complex_result)
129 {
130     if (BIG_BATCH_MODE && real_input.rows == int(m_height * m_num_of_scales)) {
131         CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal *>(real_input.deviceMem()),
132                                      complex_result.get_p_data()));
133     } else {
134         NORMAL_OMP_CRITICAL
135         {
136             CufftErrorCheck(
137                 cufftExecR2C(plan_f, reinterpret_cast<cufftReal *>(real_input.deviceMem()), complex_result.get_p_data()));
138             cudaStreamSynchronize(cudaStreamPerThread);
139         }
140     }
141     return;
142 }
143
144 void cuFFT::forward_window(MatDynMem &feat, ComplexMat & complex_result, MatDynMem &temp)
145 {
146     uint n_channels = feat.size[0];
147     cufftReal *temp_data = temp.deviceMem();
148
149     assert(feat.dims == 3);
150     assert(n_channels == m_num_of_feats || n_channels == m_num_of_feats * m_num_of_scales);
151
152     for (uint i = 0; i < n_channels; ++i) {
153         cv::Mat feat_plane(feat.dims - 1, feat.size + 1, feat.cv::Mat::type(), feat.ptr<void>(i));
154         cv::Mat temp_plane(temp.dims - 1, temp.size + 1, temp.cv::Mat::type(), temp.ptr(i));
155         temp_plane = feat_plane.mul(m_window);
156     }
157     CufftErrorCheck(cufftExecR2C((n_channels == m_num_of_feats) ? plan_fw : plan_fw_all_scales,
158                                  temp_data, complex_result.get_p_data()));
159 }
160
161 void cuFFT::inverse(ComplexMat &complex_input, MatDynMem &real_result)
162 {
163     uint n_channels = complex_input.n_channels;
164     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
165     cufftReal *out = real_result.deviceMem();
166     float alpha = 1.0 / (m_width * m_height);
167     cufftHandle plan;
168
169     if (n_channels == 1) {
170         CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
171         CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
172         return;
173     } else if (n_channels == m_num_of_scales) {
174         CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, out));
175         CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
176         return;
177     } else if (n_channels == m_num_of_feats * m_num_of_scales) {
178         CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, out));
179         cudaStreamSynchronize(cudaStreamPerThread);
180         return;
181     }
182     CufftErrorCheck(cufftExecC2R(plan_i_features, in, out));
183     return;
184 }
185
186 cuFFT::~cuFFT()
187 {
188     CublasErrorCheck(cublasDestroy(cublas));
189
190     CufftErrorCheck(cufftDestroy(plan_f));
191     CufftErrorCheck(cufftDestroy(plan_fw));
192     CufftErrorCheck(cufftDestroy(plan_i_1ch));
193     CufftErrorCheck(cufftDestroy(plan_i_features));
194
195     if (BIG_BATCH_MODE) {
196         CufftErrorCheck(cufftDestroy(plan_f_all_scales));
197         CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
198         CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
199         CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));
200     }
201 }