]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_fftw.cpp
Simplify cufftw fix
[hercules2020/kcf.git] / src / fft_fftw.cpp
1 #include "fft_fftw.h"
2
3 #include "fft.h"
4 #ifndef CUFFTW
5   #include <fftw3.h>
6 #else
7   #include <cufftw.h>
8 #endif //CUFFTW
9
10 #if defined(OPENMP)
11   #include <omp.h>
12 #endif
13
14 Fftw::Fftw()
15 {
16 }
17
18 void Fftw::init(unsigned width, unsigned height)
19 {
20     m_width = width;
21     m_height = height;
22
23 #if defined(ASYNC) || defined(OPENMP)
24     fftw_init_threads();
25 #endif //OPENMP
26
27 #ifndef CUFFTW
28     std::cout << "FFT: FFTW" << std::endl;
29 #else
30     std::cout << "FFT: cuFFTW" << std::endl;
31 #endif
32 }
33
34 void Fftw::set_window(const cv::Mat &window)
35 {
36     m_window = window;
37 }
38
39 ComplexMat Fftw::forward(const cv::Mat &input)
40 {
41     cv::Mat complex_result(m_height, m_width / 2 + 1, CV_32FC2);
42 #ifdef ASYNC
43     std::unique_lock<std::mutex> lock(fftw_mut);
44     fftw_plan_with_nthreads(2);
45 #endif
46     fftwf_plan plan = fftwf_plan_dft_r2c_2d(m_height, m_width,
47                                             reinterpret_cast<float*>(input.data),
48                                             reinterpret_cast<fftwf_complex*>(complex_result.data),
49                                             FFTW_ESTIMATE);
50 #ifdef ASYNC
51     lock.unlock();
52 #endif
53     fftwf_execute(plan);
54 #ifdef ASYNC
55     lock.lock();
56 #endif
57     fftwf_destroy_plan(plan);
58 #ifdef ASYNC
59     lock.unlock();
60 #endif
61     return ComplexMat(complex_result);
62 }
63
64 ComplexMat Fftw::forward_window(const std::vector<cv::Mat> &input)
65 {
66     int n_channels = input.size();
67     cv::Mat in_all(m_height * n_channels, m_width, CV_32F);
68     for (int i = 0; i < n_channels; ++i) {
69         cv::Mat in_roi(in_all, cv::Rect(0, i*m_height, m_width, m_height));
70         in_roi = input[i].mul(m_window);
71     }
72     cv::Mat complex_result(n_channels*m_height, m_width/2+1, CV_32FC2);
73
74     int rank = 2;
75     int n[] = {(int)m_height, (int)m_width};
76     int howmany = n_channels;
77     int idist = m_height*m_width, odist = m_height*(m_width/2+1);
78     int istride = 1, ostride = 1;
79     int *inembed = NULL, *onembed = NULL;
80     float *in = reinterpret_cast<float*>(in_all.data);
81     fftwf_complex *out = reinterpret_cast<fftwf_complex*>(complex_result.data);
82 #ifdef ASYNC
83     std::unique_lock<std::mutex> lock(fftw_mut);
84     fftw_plan_with_nthreads(2);
85 #endif
86     fftwf_plan plan = fftwf_plan_many_dft_r2c(rank, n, howmany,
87                                               in,  inembed, istride, idist,
88                                               out, onembed, ostride, odist,
89                                               FFTW_ESTIMATE);
90 #ifdef ASYNC
91     lock.unlock();
92 #endif
93     fftwf_execute(plan);
94 #ifdef ASYNC
95     lock.lock();
96 #endif
97     fftwf_destroy_plan(plan);
98 #ifdef ASYNC
99     lock.unlock();
100 #endif
101
102     ComplexMat result(m_height, m_width/2 + 1, n_channels);
103     for (int i = 0; i < n_channels; ++i)
104         result.set_channel(i, complex_result(cv::Rect(0, i*m_height, m_width/2+1, m_height)));
105
106     return result;
107 }
108
109 cv::Mat Fftw::inverse(const ComplexMat &inputf)
110 {
111     int n_channels = inputf.n_channels;
112     cv::Mat real_result(m_height, m_width, CV_32FC(n_channels));
113     cv::Mat complex_vconcat = inputf.to_vconcat_mat();
114
115     int rank = 2;
116     int n[] = {(int)m_height, (int)m_width};
117     int howmany = n_channels;
118     int idist = m_height*(m_width/2+1), odist = 1;
119     int istride = 1, ostride = n_channels;
120     int inembed[] = {(int)m_height, (int)m_width/2+1}, *onembed = n;
121     fftwf_complex *in = reinterpret_cast<fftwf_complex*>(complex_vconcat.data);
122     float *out = reinterpret_cast<float*>(real_result.data);
123 #if defined(ASYNC) || defined(OPENMP)
124     std::unique_lock<std::mutex> lock(fftw_mut);
125     fftw_plan_with_nthreads(2);
126 #endif
127     fftwf_plan plan = fftwf_plan_many_dft_c2r(rank, n, howmany,
128                                               in,  inembed, istride, idist,
129                                               out, onembed, ostride, odist,
130                                               FFTW_ESTIMATE);
131 #ifdef ASYNC
132     lock.unlock();
133 #endif
134     fftwf_execute(plan);
135 #ifdef ASYNC
136     lock.lock();
137 #endif
138     fftwf_destroy_plan(plan);
139 #ifdef ASYNC
140     lock.unlock();
141 #endif
142     return real_result/(m_width*m_height);
143 }
144
145 Fftw::~Fftw()
146 {
147 }