]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_fftw.cpp
FFTW OpenMP support restored.
[hercules2020/kcf.git] / src / fft_fftw.cpp
1 #include "fft_fftw.h"
2
3 #include "fft.h"
4 #ifndef CUFFTW
5   #include <fftw3.h>
6 #else
7   #include <cufftw.h>
8 #endif //CUFFTW
9
10 #ifdef OPENMP
11   #include <omp.h>
12 #endif
13
14 Fftw::Fftw()
15 {
16 }
17
18 void Fftw::init(unsigned width, unsigned height)
19 {
20     m_width = width;
21     m_height = height;
22
23 #if defined(ASYNC) || defined(OPENMP)
24     fftw_init_threads();
25 #endif //OPENMP
26
27 #ifndef CUFFTW
28     std::cout << "FFT: FFTW" << std::endl;
29 #else
30     std::cout << "FFT: cuFFTW" << std::endl;
31 #endif
32 }
33
34 void Fftw::set_window(const cv::Mat &window)
35 {
36     m_window = window;
37 }
38
39 ComplexMat Fftw::forward(const cv::Mat &input)
40 {
41     cv::Mat complex_result(m_height, m_width / 2 + 1, CV_32FC2);
42 #ifdef ASYNC
43     std::unique_lock<std::mutex> lock(fftw_mut);
44     fftw_plan_with_nthreads(2);
45 #elif OPENMP
46 #pragma omp critical
47     fftw_plan_with_nthreads(omp_get_max_threads());
48 #endif
49 fftwf_plan plan;
50 #pragma omp critical
51     plan = fftwf_plan_dft_r2c_2d(m_height, m_width,
52                                             reinterpret_cast<float*>(input.data),
53                                             reinterpret_cast<fftwf_complex*>(complex_result.data),
54                                             FFTW_ESTIMATE);
55 #ifdef ASYNC
56     lock.unlock();
57 #endif
58     fftwf_execute(plan);
59 #ifdef ASYNC
60     lock.lock();
61 #endif
62 #pragma omp critical
63     fftwf_destroy_plan(plan);
64 #ifdef ASYNC
65     lock.unlock();
66 #endif
67     return ComplexMat(complex_result);
68 }
69
70 ComplexMat Fftw::forward_window(const std::vector<cv::Mat> &input)
71 {
72     int n_channels = input.size();
73     cv::Mat in_all(m_height * n_channels, m_width, CV_32F);
74     for (int i = 0; i < n_channels; ++i) {
75         cv::Mat in_roi(in_all, cv::Rect(0, i*m_height, m_width, m_height));
76         in_roi = input[i].mul(m_window);
77     }
78     cv::Mat complex_result(n_channels*m_height, m_width/2+1, CV_32FC2);
79
80     int rank = 2;
81     int n[] = {(int)m_height, (int)m_width};
82     int howmany = n_channels;
83     int idist = m_height*m_width, odist = m_height*(m_width/2+1);
84     int istride = 1, ostride = 1;
85     int *inembed = NULL, *onembed = NULL;
86     float *in = reinterpret_cast<float*>(in_all.data);
87     fftwf_complex *out = reinterpret_cast<fftwf_complex*>(complex_result.data);
88 #ifdef ASYNC
89     std::unique_lock<std::mutex> lock(fftw_mut);
90     fftw_plan_with_nthreads(2);
91 #elif OPENMP
92 #pragma omp critical
93     fftw_plan_with_nthreads(omp_get_max_threads());
94 #endif
95 fftwf_plan plan;
96 #pragma omp critical
97     plan = fftwf_plan_many_dft_r2c(rank, n, howmany,
98                                               in,  inembed, istride, idist,
99                                               out, onembed, ostride, odist,
100                                               FFTW_ESTIMATE);
101 #ifdef ASYNC
102     lock.unlock();
103 #endif
104     fftwf_execute(plan);
105 #ifdef ASYNC
106     lock.lock();
107 #endif
108 #pragma omp critical
109     fftwf_destroy_plan(plan);
110 #ifdef ASYNC
111     lock.unlock();
112 #endif
113
114     ComplexMat result(m_height, m_width/2 + 1, n_channels);
115     for (int i = 0; i < n_channels; ++i)
116         result.set_channel(i, complex_result(cv::Rect(0, i*m_height, m_width/2+1, m_height)));
117
118     return result;
119 }
120
121 cv::Mat Fftw::inverse(const ComplexMat &inputf)
122 {
123     int n_channels = inputf.n_channels;
124     cv::Mat real_result(m_height, m_width, CV_32FC(n_channels));
125     cv::Mat complex_vconcat = inputf.to_vconcat_mat();
126
127     int rank = 2;
128     int n[] = {(int)m_height, (int)m_width};
129     int howmany = n_channels;
130     int idist = m_height*(m_width/2+1), odist = 1;
131     int istride = 1, ostride = n_channels;
132 #ifndef CUFFTW
133     int *inembed = NULL, *onembed = NULL;
134 #else
135     int inembed[2];
136     int onembed[2];
137     inembed[1] = m_width/2+1, onembed[1] = m_width;
138 #endif
139     fftwf_complex *in = reinterpret_cast<fftwf_complex*>(complex_vconcat.data);
140     float *out = reinterpret_cast<float*>(real_result.data);
141 #ifdef ASYNC
142     std::unique_lock<std::mutex> lock(fftw_mut);
143     fftw_plan_with_nthreads(2);
144 #elif OPENMP
145 #pragma omp critical
146     fftw_plan_with_nthreads(omp_get_max_threads());
147 #endif
148 fftwf_plan plan;
149 #pragma omp critical
150     plan = fftwf_plan_many_dft_c2r(rank, n, howmany,
151                                               in,  inembed, istride, idist,
152                                               out, onembed, ostride, odist,
153                                               FFTW_ESTIMATE);
154 #ifdef ASYNC
155     lock.unlock();
156 #endif
157     fftwf_execute(plan);
158 #ifdef ASYNC
159     lock.lock();
160 #endif
161 #pragma omp critical
162     fftwf_destroy_plan(plan);
163 #ifdef ASYNC
164     lock.unlock();
165 #endif
166     return real_result/(m_width*m_height);
167 }
168
169 Fftw::~Fftw()
170 {
171 }