]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_fftw.cpp
Merge branch 'master' of github.com:Shanigen/kcf
[hercules2020/kcf.git] / src / fft_fftw.cpp
1 #include "fft_fftw.h"
2
3 #include "fft.h"
4
5 #ifdef OPENMP
6   #include <omp.h>
7 #endif
8
9 Fftw::Fftw()
10 {
11 }
12
13 void Fftw::init(unsigned width, unsigned height)
14 {
15     m_width = width;
16     m_height = height;
17     plan_f = NULL;
18     plan_fw = NULL;
19     plan_if = NULL;
20     plan_ir = NULL;
21
22 #if defined(ASYNC) || defined(OPENMP)
23     fftw_init_threads();
24 #endif //OPENMP
25
26 #ifndef CUFFTW
27     std::cout << "FFT: FFTW" << std::endl;
28 #else
29     std::cout << "FFT: cuFFTW" << std::endl;
30 #endif
31 }
32
33 void Fftw::set_window(const cv::Mat &window)
34 {
35     m_window = window;
36 }
37
38 ComplexMat Fftw::forward(const cv::Mat &input)
39 {
40     cv::Mat complex_result(m_height, m_width / 2 + 1, CV_32FC2);
41
42     if(!plan_f){
43 #ifdef ASYNC
44         std::unique_lock<std::mutex> lock(fftw_mut);
45         fftw_plan_with_nthreads(2);
46 #elif OPENMP
47 #pragma omp critical
48         fftw_plan_with_nthreads(omp_get_max_threads());
49 #endif
50 #pragma omp critical
51         plan_f = fftwf_plan_dft_r2c_2d(m_height, m_width,
52                                                   reinterpret_cast<float*>(input.data),
53                                                   reinterpret_cast<fftwf_complex*>(complex_result.data),
54                                                   FFTW_ESTIMATE);
55         fftwf_execute(plan_f);
56     }else{fftwf_execute_dft_r2c(plan_f,reinterpret_cast<float*>(input.data),reinterpret_cast<fftwf_complex*>(complex_result.data));}
57
58     return ComplexMat(complex_result);
59 }
60
61 ComplexMat Fftw::forward_window(const std::vector<cv::Mat> &input)
62 {
63     int n_channels = input.size();
64     cv::Mat in_all(m_height * n_channels, m_width, CV_32F);
65     for (int i = 0; i < n_channels; ++i) {
66         cv::Mat in_roi(in_all, cv::Rect(0, i*m_height, m_width, m_height));
67         in_roi = input[i].mul(m_window);
68     }
69     cv::Mat complex_result(n_channels*m_height, m_width/2+1, CV_32FC2);
70
71     float *in = reinterpret_cast<float*>(in_all.data);
72     fftwf_complex *out = reinterpret_cast<fftwf_complex*>(complex_result.data);
73     if(!plan_fw){
74         int rank = 2;
75         int n[] = {(int)m_height, (int)m_width};
76         int howmany = n_channels;
77         int idist = m_height*m_width, odist = m_height*(m_width/2+1);
78         int istride = 1, ostride = 1;
79         int *inembed = NULL, *onembed = NULL;
80 #pragma omp critical
81 #ifdef ASYNC
82         std::unique_lock<std::mutex> lock(fftw_mut);
83         fftw_plan_with_nthreads(2);
84 #elif OPENMP
85 #pragma omp critical
86         fftw_plan_with_nthreads(omp_get_max_threads());
87 #endif
88         plan_fw = fftwf_plan_many_dft_r2c(rank, n, howmany,
89                                                      in,  inembed, istride, idist,
90                                                      out, onembed, ostride, odist,
91                                                      FFTW_ESTIMATE);
92         fftwf_execute(plan_fw);
93     }else{fftwf_execute_dft_r2c(plan_fw,in,out);}
94
95     ComplexMat result(m_height, m_width/2 + 1, n_channels);
96     for (int i = 0; i < n_channels; ++i)
97         result.set_channel(i, complex_result(cv::Rect(0, i*m_height, m_width/2+1, m_height)));
98
99     return result;
100 }
101
102 cv::Mat Fftw::inverse(const ComplexMat &inputf)
103 {
104     int n_channels = inputf.n_channels;
105     cv::Mat real_result(m_height, m_width, CV_32FC(n_channels));
106     cv::Mat complex_vconcat = inputf.to_vconcat_mat();
107
108     fftwf_complex *in = reinterpret_cast<fftwf_complex*>(complex_vconcat.data);
109     float *out = reinterpret_cast<float*>(real_result.data);
110
111     if(n_channels != 1){
112         if(!plan_if){
113             int rank = 2;
114             int n[] = {(int)m_height, (int)m_width};
115             int howmany = n_channels;
116             int idist = m_height*(m_width/2+1), odist = 1;
117             int istride = 1, ostride = n_channels;
118             int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
119
120 #ifdef ASYNC
121             std::unique_lock<std::mutex> lock(fftw_mut);
122             fftw_plan_with_nthreads(2);
123 #elif OPENMP
124 #pragma omp critical
125             fftw_plan_with_nthreads(omp_get_max_threads());
126 #endif
127 #pragma omp critical
128             plan_if = fftwf_plan_many_dft_c2r(rank, n, howmany,
129                                                          in,  inembed, istride, idist,
130                                                          out, onembed, ostride, odist,
131                                                          FFTW_ESTIMATE);
132             fftwf_execute(plan_if);
133         }else{fftwf_execute_dft_c2r(plan_if,in,out);}
134     }else{
135         if(!plan_ir){
136             int rank = 2;
137             int n[] = {(int)m_height, (int)m_width};
138             int howmany = n_channels;
139             int idist = m_height*(m_width/2+1), odist = 1;
140             int istride = 1, ostride = n_channels;
141 #ifndef CUFFTW
142             int *inembed = NULL, *onembed = NULL;
143 #else
144             int inembed[2];
145             int onembed[2];
146             inembed[1] = m_width/2+1, onembed[1] = m_width;
147 #endif
148
149 #ifdef ASYNC
150             std::unique_lock<std::mutex> lock(fftw_mut);
151             fftw_plan_with_nthreads(2);
152 #elif OPENMP
153 #pragma omp critical
154             fftw_plan_with_nthreads(omp_get_max_threads());
155 #endif
156 #pragma omp critical
157             plan_ir = fftwf_plan_many_dft_c2r(rank, n, howmany,
158                                                          in,  inembed, istride, idist,
159                                                          out, onembed, ostride, odist,
160                                                          FFTW_ESTIMATE);
161             fftwf_execute(plan_ir);
162     }else{fftwf_execute_dft_c2r(plan_ir,in,out);}
163   }
164
165     return real_result/(m_width*m_height);
166 }
167
168 Fftw::~Fftw()
169 {
170   fftwf_destroy_plan(plan_f);
171   fftwf_destroy_plan(plan_fw);
172   fftwf_destroy_plan(plan_if);
173   fftwf_destroy_plan(plan_ir);
174 }