rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_fftw.cpp
Added New-array Execute function support for FFTW.
[hercules2020/kcf.git] / src / fft_fftw.cpp
1 #include "fft_fftw.h"
2
3 #include "fft.h"
4
5 #ifdef OPENMP
6   #include <omp.h>
7 #endif
8
9 Fftw::Fftw()
10 {
11 }
12
13 void Fftw::init(unsigned width, unsigned height)
14 {
15     m_width = width;
16     m_height = height;
17     plan_f = NULL;
18     plan_fw = NULL;
19     plan_if = NULL;
20     plan_ir = NULL;
21
22 -#if defined(ASYNC) || defined(OPENMP)
23     fftw_init_threads();
24 #endif //OPENMP
25
26 #ifndef CUFFTW
27     std::cout << "FFT: FFTW" << std::endl;
28 #else
29     std::cout << "FFT: cuFFTW" << std::endl;
30 #endif
31 }
32
33 void Fftw::set_window(const cv::Mat &window)
34 {
35     m_window = window;
36 }
37
38 ComplexMat Fftw::forward(const cv::Mat &input)
39 {
40     cv::Mat complex_result(m_height, m_width / 2 + 1, CV_32FC2);
41
42     if(!plan_f){
43 #ifdef ASYNC
44         std::unique_lock<std::mutex> lock(fftw_mut);
45         fftw_plan_with_nthreads(2);
46 #elif OPENMP
47 #pragma omp critical
48         fftw_plan_with_nthreads(omp_get_max_threads());
49 #endif
50 #pragma omp critical
51         plan_f = fftwf_plan_dft_r2c_2d(m_height, m_width,
52                                                   reinterpret_cast<float*>(input.data),
53                                                   reinterpret_cast<fftwf_complex*>(complex_result.data),
54                                                   FFTW_ESTIMATE);
55         fftwf_execute(plan_f);
56     }else{fftwf_execute_dft_r2c(plan_f,reinterpret_cast<float*>(input.data),reinterpret_cast<fftwf_complex*>(complex_result.data));}
57
58     return ComplexMat(complex_result);
59 }
60
61 ComplexMat Fftw::forward_window(const std::vector<cv::Mat> &input)
62 {
63     int n_channels = input.size();
64     cv::Mat in_all(m_height * n_channels, m_width, CV_32F);
65     for (int i = 0; i < n_channels; ++i) {
66         cv::Mat in_roi(in_all, cv::Rect(0, i*m_height, m_width, m_height));
67         in_roi = input[i].mul(m_window);
68     }
69     cv::Mat complex_result(n_channels*m_height, m_width/2+1, CV_32FC2);
70
71     float *in = reinterpret_cast<float*>(in_all.data);
72     fftwf_complex *out = reinterpret_cast<fftwf_complex*>(complex_result.data);
73     if(!plan_fw){
74         int rank = 2;
75         int n[] = {(int)m_height, (int)m_width};
76         int howmany = n_channels;
77         int idist = m_height*m_width, odist = m_height*(m_width/2+1);
78         int istride = 1, ostride = 1;
79         int *inembed = NULL, *onembed = NULL;
80 #pragma omp critical
81 #ifdef ASYNC
82         std::unique_lock<std::mutex> lock(fftw_mut);
83         fftw_plan_with_nthreads(2);
84 #elif OPENMP
85 #pragma omp critical
86         fftw_plan_with_nthreads(omp_get_max_threads());
87 #endif
88         plan_fw = fftwf_plan_many_dft_r2c(rank, n, howmany,
89                                                      in,  inembed, istride, idist,
90                                                      out, onembed, ostride, odist,
91                                                      FFTW_ESTIMATE);
92         fftwf_execute(plan_fw);
93     }else{fftwf_execute_dft_r2c(plan_fw,in,out);}
94
95     ComplexMat result(m_height, m_width/2 + 1, n_channels);
96     for (int i = 0; i < n_channels; ++i)
97         result.set_channel(i, complex_result(cv::Rect(0, i*m_height, m_width/2+1, m_height)));
98
99     return result;
100 }
101
102 cv::Mat Fftw::inverse(const ComplexMat &inputf)
103 {
104     int n_channels = inputf.n_channels;
105     cv::Mat real_result(m_height, m_width, CV_32FC(n_channels));
106     cv::Mat complex_vconcat = inputf.to_vconcat_mat();
107
108     fftwf_complex *in = reinterpret_cast<fftwf_complex*>(complex_vconcat.data);
109     float *out = reinterpret_cast<float*>(real_result.data);
110
111     if(n_channels != 1){
112         if(!plan_if){
113             int rank = 2;
114             int n[] = {(int)m_height, (int)m_width};
115             int howmany = n_channels;
116             int idist = m_height*(m_width/2+1), odist = 1;
117             int istride = 1, ostride = n_channels;
118 #ifndef CUFFTW
119             int *inembed = NULL, *onembed = NULL;
120 #else
121             int inembed[2];
122             int onembed[2];
123             inembed[1] = m_width/2+1, onembed[1] = m_width;
124 #endif
125
126 #ifdef ASYNC
127             std::unique_lock<std::mutex> lock(fftw_mut);
128             fftw_plan_with_nthreads(2);
129 #elif OPENMP
130 #pragma omp critical
131             fftw_plan_with_nthreads(omp_get_max_threads());
132 #endif
133 #pragma omp critical
134             plan_if = fftwf_plan_many_dft_c2r(rank, n, howmany,
135                                                          in,  inembed, istride, idist,
136                                                          out, onembed, ostride, odist,
137                                                          FFTW_ESTIMATE);
138             fftwf_execute(plan_if);
139         }else{fftwf_execute_dft_c2r(plan_if,in,out);}
140     }else{
141         if(!plan_ir){
142             int rank = 2;
143             int n[] = {(int)m_height, (int)m_width};
144             int howmany = n_channels;
145             int idist = m_height*(m_width/2+1), odist = 1;
146             int istride = 1, ostride = n_channels;
147 #ifndef CUFFTW
148             int *inembed = NULL, *onembed = NULL;
149 #else
150             int inembed[2];
151             int onembed[2];
152             inembed[1] = m_width/2+1, onembed[1] = m_width;
153 #endif
154
155 #ifdef ASYNC
156             std::unique_lock<std::mutex> lock(fftw_mut);
157             fftw_plan_with_nthreads(2);
158 #elif OPENMP
159 #pragma omp critical
160             fftw_plan_with_nthreads(omp_get_max_threads());
161 #endif
162 #pragma omp critical
163             plan_ir = fftwf_plan_many_dft_c2r(rank, n, howmany,
164                                                          in,  inembed, istride, idist,
165                                                          out, onembed, ostride, odist,
166                                                          FFTW_ESTIMATE);
167             fftwf_execute(plan_ir);
168     }else{fftwf_execute_dft_c2r(plan_ir,in,out);}
169   }
170
171     return real_result/(m_width*m_height);
172 }
173
174 Fftw::~Fftw()
175 {
176   fftwf_destroy_plan(plan_f);
177   fftwf_destroy_plan(plan_fw);
178   fftwf_destroy_plan(plan_if);
179   fftwf_destroy_plan(plan_ir);
180 }