From 6c04a2a6535ae54b1d8e815066c72b5b56f8f414 Mon Sep 17 00:00:00 2001
From: Michal Sojka
Date: Sun, 16 Sep 2018 23:31:07 +0200
Subject: [PATCH] Add a method for calculation of FFT result size

Also remove CUFFT ifdef, because both branches are the same since
removal of explicit streams a few commits ago.
---
// Reviewer notes (text between the "---" line and the first "diff --git" is
// commentary and is not applied to the tree):
//
// * Fft::freq_size(space_size) returns a copy of space_size. When CUFFT or
//   FFTW is defined, its width becomes space_size.width / 2 + 1 (integer
//   division); height is never modified. Without either macro the size is
//   returned unchanged.
// * ThreadCtx's constructor now calls Fft::freq_size(roi) once and passes
//   freq_size.height / freq_size.width to the create() calls on zf, kzf and
//   kf, replacing the width_freq local that each preprocessor branch used to
//   define. Because the helper leaves height alone, freq_size.height equals
//   roi.height, so create() receives the same values as before.
// * The deleted "#ifdef CUFFT" / "#else" branches held the same three
//   create() calls, so collapsing them does not change what is executed.
// * NOTE(review): width / 2 + 1 presumably matches the half-spectrum layout
//   of a real-to-complex transform -- confirm against the FFT backends.
// * NOTE(review): assumes threadctx.hpp already sees the Fft declaration
//   (no #include change is part of this patch) -- TODO confirm.
// * NOTE(review): this copy of the patch looks lossy (no author address, no
//   element type on std::vector, indentation and blank-line positions
//   rebuilt from the hunk counts) -- check it against the repository before
//   applying.

 src/fft.h         |  9 +++++++++
 src/threadctx.hpp | 18 +++++-------------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/fft.h b/src/fft.h
index c7a8b8f..cc1393d 100644
--- a/src/fft.h
+++ b/src/fft.h
@@ -26,6 +26,15 @@ public:
     virtual void forward_window(std::vector patch_feats, ComplexMat & complex_result, cv::Mat & fw_all, float *real_input_arr) = 0;
     virtual void inverse(ComplexMat & complex_input, cv::Mat & real_result, float *real_result_arr) = 0;
     virtual ~Fft() = 0;
+
+    static cv::Size freq_size(cv::Size space_size)
+    {
+        cv::Size ret(space_size);
+#if defined(CUFFT) || defined(FFTW)
+        ret.width = space_size.width / 2 + 1;
+#endif
+        return ret;
+    }
 };
 
 #endif // FFT_H
diff --git a/src/threadctx.hpp b/src/threadctx.hpp
index 6c46037..fa95704 100644
--- a/src/threadctx.hpp
+++ b/src/threadctx.hpp
@@ -20,18 +20,15 @@ struct ThreadCtx {
         , gc(num_of_scales)
     {
         uint cells_size = roi.width * roi.height * sizeof(float);
+        cv::Size freq_size = Fft::freq_size(roi);
 
 #if defined(CUFFT) || defined(FFTW)
         this->gauss_corr_res = DynMem(cells_size * num_of_scales);
         this->data_features = DynMem(cells_size * num_channels);
 
-        uint width_freq = roi.width / 2 + 1;
-
         this->in_all = cv::Mat(roi.height * num_of_scales, roi.width, CV_32F, this->gauss_corr_res.hostMem());
         this->fw_all = cv::Mat(roi.height * num_channels, roi.width, CV_32F, this->data_features.hostMem());
 #else
-        uint width_freq = roi.width;
-
         this->in_all = cv::Mat(roi, CV_32F);
 #endif
 
@@ -41,15 +38,9 @@ struct ThreadCtx {
         this->ifft2_res = cv::Mat(roi, CV_32FC(num_channels), this->data_i_features.hostMem());
         this->response = cv::Mat(roi, CV_32FC(num_of_scales), this->data_i_1ch.hostMem());
 
-#ifdef CUFFT
-        this->zf.create(roi.height, width_freq, num_channels, num_of_scales);
-        this->kzf.create(roi.height, width_freq, num_of_scales);
-        this->kf.create(roi.height, width_freq, num_of_scales);
-#else
-        this->zf.create(roi.height, width_freq, num_channels, num_of_scales);
-        this->kzf.create(roi.height, width_freq, num_of_scales);
-        this->kf.create(roi.height, width_freq, num_of_scales);
-#endif
+        this->zf.create(freq_size.height, freq_size.width, num_channels, num_of_scales);
+        this->kzf.create(freq_size.height, freq_size.width, num_of_scales);
+        this->kf.create(freq_size.height, freq_size.width, num_of_scales);
 
 #ifdef BIG_BATCH
         if (num_of_scales > 1) {
@@ -59,6 +50,7 @@ struct ThreadCtx {
         }
 #endif
     }
+
     ThreadCtx(ThreadCtx &&) = default;
 
     const double scale;
-- 
2.39.2