rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_fftw.cpp
Repaired the big batch mode for FFTW.
[hercules2020/kcf.git] / src / fft_fftw.cpp
1 #include "fft_fftw.h"
2
3 #include "fft.h"
4
5 #ifdef OPENMP
6 #include <omp.h>
7 #endif
8
9 #if !defined(ASYNC) && !defined(OPENMP) && !defined(CUFFTW)
10 #define FFTW_PLAN_WITH_THREADS() fftw_plan_with_nthreads(4);
11 #else
12 #define FFTW_PLAN_WITH_THREADS()
13 #endif
14
15 Fftw::Fftw(){}
16
17 void Fftw::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
18 {
19     Fft::init(width, height, num_of_feats, num_of_scales);
20
21 #if (!defined(ASYNC) && !defined(CUFFTW)) && defined(OPENMP)
22     fftw_init_threads();
23 #endif // OPENMP
24
25 #ifndef CUFFTW
26     std::cout << "FFT: FFTW" << std::endl;
27 #else
28     std::cout << "FFT: cuFFTW" << std::endl;
29 #endif
30     fftwf_cleanup();
31     // FFT forward one scale
32     {
33         cv::Mat in_f = cv::Mat::zeros(int(m_height), int(m_width), CV_32FC1);
34         ComplexMat out_f(int(m_height), m_width / 2 + 1, 1);
35         plan_f = fftwf_plan_dft_r2c_2d(int(m_height), int(m_width), reinterpret_cast<float *>(in_f.data),
36                                        reinterpret_cast<fftwf_complex *>(out_f.get_p_data()), FFTW_PATIENT);
37     }
38 #ifdef BIG_BATCH
39     // FFT forward all scales
40     if (m_num_of_scales > 1) {
41         cv::Mat in_f_all = cv::Mat::zeros(m_height * m_num_of_scales, m_width, CV_32F);
42         ComplexMat out_f_all(m_height, m_width / 2 + 1, m_num_of_scales);
43         float *in = reinterpret_cast<float *>(in_f_all.data);
44         fftwf_complex *out = reinterpret_cast<fftwf_complex *>(out_f_all.get_p_data());
45         int rank = 2;
46         int n[] = {(int)m_height, (int)m_width};
47         int howmany = m_num_of_scales;
48         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
49         int istride = 1, ostride = 1;
50         int *inembed = NULL, *onembed = NULL;
51
52         FFTW_PLAN_WITH_THREADS();
53         plan_f_all_scales = fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, out, onembed,
54                                                     ostride, odist, FFTW_PATIENT);
55     }
56 #endif
57     // FFT forward window one scale
58     {
59         cv::Mat in_fw = cv::Mat::zeros(int(m_height * m_num_of_feats), int(m_width), CV_32F);
60         ComplexMat out_fw(int(m_height), m_width / 2 + 1, int(m_num_of_feats));
61         float *in = reinterpret_cast<float *>(in_fw.data);
62         fftwf_complex *out = reinterpret_cast<fftwf_complex *>(out_fw.get_p_data());
63         int rank = 2;
64         int n[] = {int(m_height), int(m_width)};
65         int howmany = int(m_num_of_feats);
66         int idist = int(m_height * m_width), odist = int(m_height * (m_width / 2 + 1));
67         int istride = 1, ostride = 1;
68         int *inembed = nullptr, *onembed = nullptr;
69
70         FFTW_PLAN_WITH_THREADS();
71         plan_fw = fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride, odist,
72                                           FFTW_PATIENT);
73     }
74 #ifdef BIG_BATCH
75     // FFT forward window all scales all feats
76     if (m_num_of_scales > 1) {
77         cv::Mat in_all = cv::Mat::zeros(m_height * (m_num_of_scales * m_num_of_feats), m_width, CV_32F);
78         ComplexMat out_all(m_height, m_width / 2 + 1, m_num_of_scales * m_num_of_feats);
79         float *in = reinterpret_cast<float *>(in_all.data);
80         fftwf_complex *out = reinterpret_cast<fftwf_complex *>(out_all.get_p_data());
81         int rank = 2;
82         int n[] = {(int)m_height, (int)m_width};
83         int howmany = m_num_of_scales * m_num_of_feats;
84         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
85         int istride = 1, ostride = 1;
86         int *inembed = NULL, *onembed = NULL;
87
88         FFTW_PLAN_WITH_THREADS();
89         plan_fw_all_scales = fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, out, onembed,
90                                                      ostride, odist, FFTW_PATIENT);
91     }
92 #endif
93 #ifdef BIG_BATCH
94     // FFT inverse all scales
95     {
96         ComplexMat in_i_all(m_height, m_width / 2 + 1, m_num_of_scales);
97         cv::Mat out_i_all = cv::Mat::zeros(m_height, m_width, CV_32FC(m_num_of_scales));
98         fftwf_complex *in = reinterpret_cast<fftwf_complex *>(in_i_all.get_p_data());
99         float *out = reinterpret_cast<float *>(out_i_all.data);
100         int rank = 2;
101         int n[] = {(int)m_height, (int)m_width};
102         int howmany = m_num_of_scales;
103         int idist = m_height * (m_width / 2 + 1), odist = m_height * m_width;
104         int istride = 1, ostride = 1;
105         int inembed[] = {(int)m_height, (int)m_width / 2 + 1}, *onembed = n;
106
107         FFTW_PLAN_WITH_THREADS();
108         plan_i_all_scales = fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, out,
109                                                     onembed, ostride, odist, FFTW_PATIENT);
110     }
111 #endif
112     // FFT inverse one channel
113     {
114         ComplexMat in_i1(int(m_height), int(m_width), 1);
115         cv::Mat out_i1 = cv::Mat::zeros(int(m_height), int(m_width), CV_32FC1);
116         fftwf_complex *in = reinterpret_cast<fftwf_complex *>(in_i1.get_p_data());
117         float *out = reinterpret_cast<float *>(out_i1.data);
118         int rank = 2;
119         int n[] = {int(m_height), int(m_width)};
120         int howmany = 1;
121         int idist = m_height * (m_width / 2 + 1), odist = 1;
122         int istride = 1, ostride = 1;
123         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
124
125         FFTW_PLAN_WITH_THREADS();
126         plan_i_1ch = fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride,
127                                              odist, FFTW_PATIENT);
128     }
129 }
130
131 void Fftw::set_window(const MatDynMem &window)
132 {
133     Fft::set_window(window);
134     m_window = window;
135 }
136
137 void Fftw::forward(const MatScales &real_input, ComplexMat &complex_result)
138 {
139     Fft::forward(real_input, complex_result);
140
141     if (real_input.size[0] == 1)
142         fftwf_execute_dft_r2c(plan_f, reinterpret_cast<float *>(real_input.data),
143                               reinterpret_cast<fftwf_complex *>(complex_result.get_p_data()));
144 #ifdef BIG_BATCH
145     else
146         fftwf_execute_dft_r2c(plan_f_all_scales, reinterpret_cast<float *>(real_input.data),
147                               reinterpret_cast<fftwf_complex *>(complex_result.get_p_data()));
148 #endif
149 }
150
151 void Fftw::forward_window(MatScaleFeats  &feat, ComplexMat & complex_result, MatScaleFeats &temp)
152 {
153     Fft::forward_window(feat, complex_result, temp);
154
155     uint n_scales = feat.size[0];
156     for (uint s = 0; s < n_scales; ++s) {
157         for (uint ch = 0; ch < uint(feat.size[1]); ++ch) {
158             cv::Mat feat_plane = feat.plane(s, ch);
159             cv::Mat temp_plane = temp.plane(s, ch);
160             temp_plane = feat_plane.mul(m_window);
161         }
162     }
163
164     float *in = temp.ptr<float>();
165     fftwf_complex *out = reinterpret_cast<fftwf_complex *>(complex_result.get_p_data());
166
167     if (n_scales == 1)
168         fftwf_execute_dft_r2c(plan_fw, in, out);
169     else
170         fftwf_execute_dft_r2c(plan_fw_all_scales, in, out);
171     return;
172 }
173
174 void Fftw::inverse(ComplexMat &complex_input, MatScales &real_result)
175 {
176     Fft::inverse(complex_input, real_result);
177
178     int n_channels = complex_input.n_channels;
179     fftwf_complex *in = reinterpret_cast<fftwf_complex *>(complex_input.get_p_data());
180     float *out = real_result.ptr<float>();
181
182     if (n_channels == 1)
183         fftwf_execute_dft_c2r(plan_i_1ch, in, out);
184 #ifdef BIG_BATCH
185     else
186         fftwf_execute_dft_c2r(plan_i_all_scales, in, out);
187 #endif
188     real_result *= 1.0 / (m_width * m_height);
189 }
190
191 Fftw::~Fftw()
192 {
193     fftwf_destroy_plan(plan_f);
194     fftwf_destroy_plan(plan_fw);
195     fftwf_destroy_plan(plan_i_1ch);
196
197     if (BIG_BATCH_MODE) {
198         fftwf_destroy_plan(plan_f_all_scales);
199         fftwf_destroy_plan(plan_i_all_scales);
200         fftwf_destroy_plan(plan_fw_all_scales);
201     }
202 }