]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
da43fbf4bd06f92a606380d4dd91097212994e82
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2
3 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales, bool big_batch_mode)
4 {
5     m_width = width;
6     m_height = height;
7     m_num_of_feats = num_of_feats;
8     m_num_of_scales = num_of_scales;
9     m_big_batch_mode = big_batch_mode;
10
11     std::cout << "FFT: cuFFT" << std::endl;
12
13     //FFT forward one scale
14     {
15        CufftErrorCheck(cufftPlan2d(&plan_f, int(m_height), int(m_width), CUFFT_R2C));
16     }
17 #ifdef BIG_BATCH
18     //FFT forward all scales
19     if(m_num_of_scales > 1 && m_big_batch_mode)
20     {
21         int rank = 2;
22         int n[] = {(int)m_height, (int)m_width};
23         int howmany = m_num_of_scales;
24         int idist = m_height*m_width, odist = m_height*(m_width/2+1);
25         int istride = 1, ostride = 1;
26         int *inembed = n, onembed[] = {(int)m_height, (int)m_width/2+1};
27
28         CufftErrorCheck(cufftPlanMany(&plan_f_all_scales, rank, n,
29                   inembed, istride, idist,
30                   onembed, ostride, odist,
31                   CUFFT_R2C, howmany));
32     }
33 #endif
34     //FFT forward window one scale
35     {
36         int rank = 2;
37         int n[] = {int(m_height), int(m_width)};
38         int howmany = int(m_num_of_feats);
39         int idist = int(m_height*m_width), odist = int(m_height*(m_width/2+1));
40         int istride = 1, ostride = 1;
41         int *inembed = n, onembed[] = {int(m_height), int(m_width/2+1)};
42
43         CufftErrorCheck(cufftPlanMany(&plan_fw, rank, n,
44                   inembed, istride, idist,
45                   onembed, ostride, odist,
46                   CUFFT_R2C, howmany));
47     }
48 #ifdef BIG_BATCH
49     //FFT forward window all scales all feats
50     if(m_num_of_scales > 1 && m_big_batch_mode)
51     {
52         int rank = 2;
53         int n[] = {(int)m_height, (int)m_width};
54         int howmany = m_num_of_scales*m_num_of_feats;
55         int idist = m_height*m_width, odist = m_height*(m_width/2+1);
56         int istride = 1, ostride = 1;
57         int *inembed = n, onembed[] = {(int)m_height, (int)m_width/2+1};
58
59         CufftErrorCheck(cufftPlanMany(&plan_fw_all_scales, rank, n,
60                   inembed, istride, idist,
61                   onembed, ostride, odist,
62                   CUFFT_R2C, howmany));
63     }
64 #endif
65     //FFT inverse one scale
66     {
67         int rank = 2;
68         int n[] = {int(m_height), int(m_width)};
69         int howmany = int(m_num_of_feats);
70         int idist = int(m_height*(m_width/2+1)), odist = 1;
71         int istride = 1, ostride = int(m_num_of_feats);
72         int inembed[] = {int(m_height), int(m_width/2+1)}, *onembed = n;
73
74         CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n,
75                   inembed, istride, idist,
76                   onembed, ostride, odist,
77                   CUFFT_C2R, howmany));
78     }
79     //FFT inverse all scales
80 #ifdef BIG_BATCH
81     if(m_num_of_scales > 1 && m_big_batch_mode)
82     {
83         int rank = 2;
84         int n[] = {(int)m_height, (int)m_width};
85         int howmany = m_num_of_feats*m_num_of_scales;
86         int idist = m_height*(m_width/2+1), odist = 1;
87         int istride = 1, ostride = m_num_of_feats*m_num_of_scales;
88         int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
89
90         CufftErrorCheck(cufftPlanMany(&plan_i_features_all_scales, rank, n,
91                   inembed, istride, idist,
92                   onembed, ostride, odist,
93                   CUFFT_C2R, howmany));
94     }
95 #endif
96     //FFT inverse one channel one scale
97     {
98         int rank = 2;
99         int n[] = {int(m_height), int(m_width)};
100         int howmany = 1;
101         int idist = int(m_height*(m_width/2+1)), odist = 1;
102         int istride = 1, ostride = 1;
103         int inembed[] = {int(m_height), int(m_width/2+1)}, *onembed = n;
104
105         CufftErrorCheck(cufftPlanMany(&plan_i_1ch, rank, n,
106                   inembed, istride, idist,
107                   onembed, ostride, odist,
108                   CUFFT_C2R, howmany));
109     }
110 #ifdef BIG_BATCH
111     //FFT inverse one channel all scales
112     if(m_num_of_scales > 1 && m_big_batch_mode)
113     {
114         int rank = 2;
115         int n[] = {(int)m_height, (int)m_width};
116         int howmany = m_num_of_scales;
117         int idist = m_height*(m_width/2+1), odist = 1;
118         int istride = 1, ostride = m_num_of_scales;
119         int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
120
121         CufftErrorCheck(cufftPlanMany(&plan_i_1ch_all_scales, rank, n,
122                   inembed, istride, idist,
123                   onembed, ostride, odist,
124                   CUFFT_C2R, howmany));
125     }
126 #endif
127 }
128
129 void cuFFT::set_window(const cv::Mat & window)
130 {
131      m_window = window;
132 }
133
134 void cuFFT::forward(const cv::Mat & real_input, ComplexMat & complex_result, float *real_input_arr, cudaStream_t stream)
135 {
136     (void) real_input;
137
138     if(m_big_batch_mode && real_input.rows == int(m_height*m_num_of_scales)){
139         CufftErrorCheck(cufftExecR2C(plan_f_all_scales, reinterpret_cast<cufftReal*>(real_input_arr),
140                                 complex_result.get_p_data()));
141     } else {
142         CufftErrorCheck(cufftSetStream(plan_f, stream));
143         CufftErrorCheck(cufftExecR2C(plan_f, reinterpret_cast<cufftReal*>(real_input_arr),
144                                 complex_result.get_p_data()));
145     }
146     return;
147 }
148
149 void cuFFT::forward_window(std::vector<cv::Mat> patch_feats, ComplexMat & complex_result, cv::Mat & fw_all, float *real_input_arr, cudaStream_t stream)
150 {
151     int n_channels = int(patch_feats.size());
152
153     if(n_channels > int(m_num_of_feats)){
154         for (uint i = 0; i < uint(n_channels); ++i) {
155             cv::Mat in_roi(fw_all, cv::Rect(0, int(i*m_height), int(m_width), int(m_height)));
156             in_roi = patch_feats[i].mul(m_window);
157         }
158         CufftErrorCheck(cufftExecR2C(plan_fw_all_scales, reinterpret_cast<cufftReal*>(real_input_arr), complex_result.get_p_data()));
159     } else {
160         for (uint i = 0; i < uint(n_channels); ++i) {
161             cv::Mat in_roi(fw_all, cv::Rect(0, int(i*m_height), int(m_width), int(m_height)));
162             in_roi = patch_feats[i].mul(m_window);
163         }
164         CufftErrorCheck(cufftSetStream(plan_fw, stream));
165         CufftErrorCheck(cufftExecR2C(plan_fw, reinterpret_cast<cufftReal*>(real_input_arr), complex_result.get_p_data()));
166     }
167     return;
168 }
169
170 void cuFFT::inverse(ComplexMat &  complex_input, cv::Mat & real_result, float *real_result_arr, cudaStream_t stream)
171 {
172     int n_channels = complex_input.n_channels;
173     cufftComplex *in = reinterpret_cast<cufftComplex*>(complex_input.get_p_data());
174
175     if(n_channels == 1){
176         CufftErrorCheck(cufftSetStream(plan_i_1ch, stream));
177         CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, reinterpret_cast<cufftReal*>(real_result_arr)));
178         cudaStreamSynchronize(stream);
179         real_result = real_result/(m_width*m_height);
180         return;
181     } else if(n_channels == int(m_num_of_scales)){
182         CufftErrorCheck(cufftExecC2R(plan_i_1ch_all_scales, in, reinterpret_cast<cufftReal*>(real_result_arr)));
183         cudaStreamSynchronize(stream);
184
185         real_result = real_result/(m_width*m_height);
186         return;
187     } else if(n_channels == int(m_num_of_feats) * int(m_num_of_scales)){
188         CufftErrorCheck(cufftExecC2R(plan_i_features_all_scales, in, reinterpret_cast<cufftReal*>(real_result_arr)));
189         return;
190     }
191     CufftErrorCheck(cufftSetStream(plan_i_features, stream));
192     CufftErrorCheck(cufftExecC2R(plan_i_features, in, reinterpret_cast<cufftReal*>(real_result_arr)));
193     return;
194 }
195
196 cuFFT::~cuFFT()
197 {
198   CufftErrorCheck(cufftDestroy(plan_f));
199   CufftErrorCheck(cufftDestroy(plan_fw));
200   CufftErrorCheck(cufftDestroy(plan_i_1ch));
201   CufftErrorCheck(cufftDestroy(plan_i_features));
202   
203   if (m_big_batch_mode) {
204       CufftErrorCheck(cufftDestroy(plan_f_all_scales));
205       CufftErrorCheck(cufftDestroy(plan_fw_all_scales));
206       CufftErrorCheck(cufftDestroy(plan_i_1ch_all_scales));
207       CufftErrorCheck(cufftDestroy(plan_i_features_all_scales));
208   }
209 }