]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/complexmat.hpp
complexmat: Move CPU-only methods from .hpp to .cpp
[hercules2020/kcf.git] / src / complexmat.hpp
1 #ifndef COMPLEX_MAT_HPP_213123048309482094
2 #define COMPLEX_MAT_HPP_213123048309482094
3
4 #include <opencv2/opencv.hpp>
5 #include <vector>
6 #include <algorithm>
7 #include <functional>
8 #include "dynmem.hpp"
9 #include "pragmas.h"
10
11 #ifdef CUFFT
12 #include <cufft.h>
13 #endif
14
15 class ComplexMat_ {
16   public:
17     typedef float T;
18
19     uint cols;
20     uint rows;
21     uint n_channels;
22     uint n_scales;
23
24     ComplexMat_(uint _rows, uint _cols, uint _n_channels, uint _n_scales = 1)
25         : cols(_cols), rows(_rows), n_channels(_n_channels * _n_scales), n_scales(_n_scales),
26           p_data(n_channels * cols * rows) {}
27     ComplexMat_(cv::Size size, uint _n_channels, uint _n_scales = 1)
28         : cols(size.width), rows(size.height), n_channels(_n_channels * _n_scales), n_scales(_n_scales)
29         , p_data(n_channels * cols * rows) {}
30
31     // assuming that mat has 2 channels (real, img)
32     ComplexMat_(const cv::Mat &mat) : cols(uint(mat.cols)), rows(uint(mat.rows)), n_channels(1), n_scales(1)
33                                     , p_data(n_channels * cols * rows)
34     {
35         memcpy(p_data.hostMem(), mat.ptr<std::complex<T>>(), mat.total() * mat.elemSize());
36     }
37
38     static ComplexMat_ same_size(const ComplexMat_ &o)
39     {
40         return ComplexMat_(o.rows, o.cols, o.n_channels / o.n_scales, o.n_scales);
41     }
42
43     // cv::Mat API compatibility
44     cv::Size size() const { return cv::Size(cols, rows); }
45     uint channels() const { return n_channels; }
46
47     // assuming that mat has 2 channels (real, imag)
48     void set_channel(uint idx, const cv::Mat &mat)
49     {
50         assert(idx < n_channels);
51         for (uint i = 0; i < rows; ++i) {
52             const std::complex<T> *row = mat.ptr<std::complex<T>>(i);
53             for (uint j = 0; j < cols; ++j)
54                 p_data.hostMem()[idx * rows * cols + i * cols + j] = row[j];
55         }
56     }
57
58     T sqr_norm() const;
59
60     void sqr_norm(DynMem_<T> &result) const;
61
62     ComplexMat_ sqr_mag() const;
63
64     ComplexMat_ conj() const;
65
66     ComplexMat_ sum_over_channels() const;
67
68     // return 2 channels (real, imag) for first complex channel
69     cv::Mat to_cv_mat() const
70     {
71         assert(p_data.num_elem >= 1);
72         return channel_to_cv_mat(0);
73     }
74     // return a vector of 2 channels (real, imag) per one complex channel
75     std::vector<cv::Mat> to_cv_mat_vector() const
76     {
77         std::vector<cv::Mat> result;
78         result.reserve(n_channels);
79
80         for (uint i = 0; i < n_channels; ++i)
81             result.push_back(channel_to_cv_mat(i));
82
83         return result;
84     }
85
86     std::complex<T> *get_p_data() { return p_data.hostMem(); }
87     const std::complex<T> *get_p_data() const { return p_data.hostMem(); }
88
89 #ifdef CUFFT
90     cufftComplex *get_dev_data() { return (cufftComplex*)p_data.deviceMem(); }
91     const cufftComplex *get_dev_data() const { return (cufftComplex*)p_data.deviceMem(); }
92 #endif
93
94     // element-wise per channel multiplication, division and addition
95     ComplexMat_ operator*(const ComplexMat_ &rhs) const;
96     ComplexMat_ operator/(const ComplexMat_ &rhs) const;
97     ComplexMat_ operator+(const ComplexMat_ &rhs) const;
98
99     // multiplying or adding constant
100     ComplexMat_ operator*(const T &rhs) const;
101     ComplexMat_ operator+(const T &rhs) const;
102
103     // multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
104     ComplexMat_ mul(const ComplexMat_ &rhs) const;
105
106     // multiplying element-wise multichannel mats - same as operator*(ComplexMat), but without allocating memory for the result
107     ComplexMat_ muln(const ComplexMat_ &rhs) const
108     {
109         return mat_mat_operator([](std::complex<T> &c_lhs, const std::complex<T> &c_rhs) { c_lhs *= c_rhs; }, rhs);
110     }
111
112     // text output
113     friend std::ostream &operator<<(std::ostream &os, const ComplexMat_ &mat)
114     {
115         // for (int i = 0; i < mat.n_channels; ++i){
116         for (int i = 0; i < 1; ++i) {
117             os << "Channel " << i << std::endl;
118             for (uint j = 0; j < mat.rows; ++j) {
119                 for (uint k = 0; k < mat.cols - 1; ++k)
120                     os << mat.p_data[j * mat.cols + k] << ", ";
121                 os << mat.p_data[j * mat.cols + mat.cols - 1] << std::endl;
122             }
123         }
124         return os;
125     }
126
127   private:
128     DynMem_<std::complex<T>> p_data;
129
130     // convert 2 channel mat (real, imag) to vector row-by-row
131     std::vector<std::complex<T>> convert(const cv::Mat &mat)
132     {
133         std::vector<std::complex<T>> result;
134         result.reserve(mat.cols * mat.rows);
135         for (int y = 0; y < mat.rows; ++y) {
136             const T *row_ptr = mat.ptr<T>(y);
137             for (int x = 0; x < 2 * mat.cols; x += 2) {
138                 result.push_back(std::complex<T>(row_ptr[x], row_ptr[x + 1]));
139             }
140         }
141         return result;
142     }
143
144     ComplexMat_ mat_mat_operator(void (*op)(std::complex<T> &c_lhs, const std::complex<T> &c_rhs),
145                                  const ComplexMat_ &mat_rhs) const;
146     ComplexMat_ matn_mat1_operator(void (*op)(std::complex<T> &c_lhs, const std::complex<T> &c_rhs),
147                                    const ComplexMat_ &mat_rhs) const;
148     ComplexMat_ matn_mat2_operator(void (*op)(std::complex<T> &c_lhs, const std::complex<T> &c_rhs),
149                                    const ComplexMat_ &mat_rhs) const;
150     ComplexMat_ mat_const_operator(const std::function<void(std::complex<T> &c_rhs)> &op) const;
151
152     cv::Mat channel_to_cv_mat(int channel_id) const
153     {
154         cv::Mat result(rows, cols, CV_32FC2);
155         for (uint y = 0; y < rows; ++y) {
156             std::complex<T> *row_ptr = result.ptr<std::complex<T>>(y);
157             for (uint x = 0; x < cols; ++x) {
158                 row_ptr[x] = p_data[channel_id * rows * cols + y * cols + x];
159             }
160         }
161         return result;
162     }
163 };
164
165 typedef ComplexMat_ ComplexMat;
166
167 #endif // COMPLEX_MAT_HPP_213123048309482094