]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/fft_cufft.cpp
cufft: Move scaling from gaussian_correlation to inverse fft
[hercules2020/kcf.git] / src / fft_cufft.cpp
1 #include "fft_cufft.h"
2 #include <cublas_v2.h>
3
4 cuFFT::cuFFT()
5 {
6     CublasErrorCheck(cublasCreate(&cublas));
7 }
8
9 void cuFFT::init(unsigned width, unsigned height, unsigned num_of_feats, unsigned num_of_scales)
10 {
11     Fft::init(width, height, num_of_feats, num_of_scales);
12
13     std::cout << "FFT: cuFFT" << std::endl;
14
15     // FFT forward
16     {
17         int rank = 2;
18         int n[] = {int(m_height), int(m_width)};
19         int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
20         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
21         int istride = 1, ostride = 1;
22         int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
23
24         CufftErrorCheck(cufftPlanMany(&plan_f, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
25         CufftErrorCheck(cufftSetStream(plan_f, cudaStreamPerThread));
26     }
27
28     // FFT forward window
29     {
30         int rank = 2;
31         int n[] = {int(m_height), int(m_width)};
32         int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
33         int idist = m_height * m_width, odist = m_height * (m_width / 2 + 1);
34         int istride = 1, ostride = 1;
35         int *inembed = n, onembed[] = {int(m_height), int(m_width) / 2 + 1};
36
37         CufftErrorCheck(cufftPlanMany(&plan_fw, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_R2C, howmany));
38         CufftErrorCheck(cufftSetStream(plan_fw, cudaStreamPerThread));
39     }
40     // FFT inverse all channels
41     {
42         int rank = 2;
43         int n[] = {int(m_height), int(m_width)};
44         int howmany = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
45         int idist = int(m_height * (m_width / 2 + 1)), odist = 1;
46         int istride = 1, ostride = m_num_of_feats * IF_BIG_BATCH(m_num_of_scales, 1);
47         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
48
49         CufftErrorCheck(cufftPlanMany(&plan_i_features, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
50         CufftErrorCheck(cufftSetStream(plan_i_features, cudaStreamPerThread));
51     }
52     // FFT inverse one channel
53     {
54         int rank = 2;
55         int n[] = {int(m_height), int(m_width)};
56         int howmany = IF_BIG_BATCH(m_num_of_scales, 1);
57         int idist = m_height * (m_width / 2 + 1), odist = 1;
58         int istride = 1, ostride = IF_BIG_BATCH(m_num_of_scales, 1);
59         int inembed[] = {int(m_height), int(m_width / 2 + 1)}, *onembed = n;
60
61         CufftErrorCheck(cufftPlanMany(&plan_i_1ch, rank, n, inembed, istride, idist, onembed, ostride, odist, CUFFT_C2R, howmany));
62         CufftErrorCheck(cufftSetStream(plan_i_1ch, cudaStreamPerThread));
63     }
64 }
65
66 void cuFFT::set_window(const MatDynMem &window)
67 {
68     Fft::set_window(window);
69     m_window = window;
70 }
71
72 void cuFFT::forward(const MatDynMem &real_input, ComplexMat &complex_result)
73 {
74     Fft::forward(real_input, complex_result);
75     auto in = static_cast<cufftReal *>(const_cast<MatDynMem&>(real_input).deviceMem());
76
77     CufftErrorCheck(cufftExecR2C(plan_f, in, complex_result.get_p_data()));
78 }
79
80 void cuFFT::forward_window(MatDynMem &feat, ComplexMat &complex_result, MatDynMem &temp)
81 {
82     Fft::forward_window(feat, complex_result, temp);
83
84     uint n_channels = feat.size[0];
85     cufftReal *temp_data = temp.deviceMem();
86
87     assert(feat.dims == 3);
88     assert(n_channels == m_num_of_feats || n_channels == m_num_of_feats * m_num_of_scales);
89
90     for (uint i = 0; i < n_channels; ++i) {
91         cv::Mat feat_plane(feat.dims - 1, feat.size + 1, feat.cv::Mat::type(), feat.ptr<void>(i));
92         cv::Mat temp_plane(temp.dims - 1, temp.size + 1, temp.cv::Mat::type(), temp.ptr(i));
93         temp_plane = feat_plane.mul(m_window);
94     }
95     CufftErrorCheck(cufftExecR2C(plan_fw, temp_data, complex_result.get_p_data()));
96 }
97
98 void cuFFT::inverse(ComplexMat &complex_input, MatDynMem &real_result)
99 {
100     Fft::inverse(complex_input, real_result);
101
102     uint n_channels = complex_input.n_channels;
103     cufftComplex *in = reinterpret_cast<cufftComplex *>(complex_input.get_p_data());
104     cufftReal *out = real_result.deviceMem();
105     float alpha = 1.0 / (m_width * m_height);
106
107     if (n_channels == 1) {
108         CufftErrorCheck(cufftExecC2R(plan_i_1ch, in, out));
109     } else {
110         CufftErrorCheck(cufftExecC2R(plan_i_features, in, out));
111     }
112     // TODO: Investigate whether this scalling is needed or not
113     CublasErrorCheck(cublasSscal(cublas, real_result.total(), &alpha, out, 1));
114 }
115
116 cuFFT::~cuFFT()
117 {
118     CublasErrorCheck(cublasDestroy(cublas));
119
120     CufftErrorCheck(cufftDestroy(plan_f));
121     CufftErrorCheck(cufftDestroy(plan_fw));
122     CufftErrorCheck(cufftDestroy(plan_i_1ch));
123     CufftErrorCheck(cufftDestroy(plan_i_features));
124 }