3 #include <opencv2/opencv.hpp>
5 #include "cuda_runtime.h"
// Default constructor: an empty matrix with zero rows, columns and channels.
// No device buffer is allocated; p_data keeps its in-class initializer (nullptr).
// NOTE(review): n_scales is not set here (nor in the 3-argument constructor's
// initializer list), yet the move constructor copies it — confirm it has an
// in-class default value.
16 ComplexMat() : cols{0}, rows{0}, n_channels{0} {}
17 ComplexMat(int _rows, int _cols, int _n_channels) : cols(_cols), rows(_rows), n_channels(_n_channels)
19 cudaMalloc(&p_data, n_channels*cols*rows*sizeof(cufftComplex));
22 ComplexMat(int _rows, int _cols, int _n_channels, int _n_scales) : cols(_cols), rows(_rows), n_channels(_n_channels), n_scales(_n_scales)
24 cudaMalloc(&p_data, n_channels*cols*rows*sizeof(cufftComplex));
27 ComplexMat(ComplexMat &&other)
31 n_channels = other.n_channels;
32 n_scales = other.n_scales;
33 p_data = other.p_data;
35 other.p_data = nullptr;
40 if(p_data != nullptr) cudaFree(p_data);
43 void create(int _rows, int _cols, int _n_channels)
47 n_channels = _n_channels;
48 cudaMalloc(&p_data, n_channels*cols*rows*sizeof(cufftComplex));
51 void create(int _rows, int _cols, int _n_channels, int _n_scales)
55 n_channels = _n_channels;
57 cudaMalloc(&p_data, n_channels*cols*rows*sizeof(cufftComplex));
59 // cv::Mat API compatibility
// size(): matrix dimensions as cv::Size(width = cols, height = rows), mirroring
// cv::Mat::size(). Declared const so it can also be called on const ComplexMat objects.
60 cv::Size size() const { return cv::Size(cols, rows); }
// channels(): number of channels. A single const overload serves both const and
// non-const objects, so the former non-const duplicate is no longer needed.
62 int channels() const { return n_channels; }
// Presumably the squared norm of the matrix, written through `result` (no return value).
// NOTE(review): the definition is out of line — confirm how many floats `result`
// must hold (e.g. one per scale) and whether it must be a host or a device pointer.
64 void sqr_norm(float *result) const;
// Presumably the element-wise squared magnitude |z|^2 — confirm against the definition.
66 ComplexMat sqr_mag() const;
// Presumably the element-wise complex conjugate — confirm against the definition.
68 ComplexMat conj() const;
// Presumably sums corresponding elements across channels — confirm the resulting
// channel count (e.g. one per scale) against the definition.
70 ComplexMat sum_over_channels() const;
// Raw data pointer typed as cufftComplex*. The member itself is declared as
// float* p_data, so this presumably reinterprets the interleaved (re, im) floats — confirm.
72 cufftComplex* get_p_data() const;
74 // Element-wise, per-channel multiplication, division and addition of two mats.
75 ComplexMat operator*(const ComplexMat & rhs) const;
76 ComplexMat operator/(const ComplexMat & rhs) const;
77 ComplexMat operator+(const ComplexMat & rhs) const;
79 // Multiplying by, or adding, a scalar float constant.
80 ComplexMat operator*(const float & rhs) const;
81 ComplexMat operator+(const float & rhs) const;
83 // Element-wise multiplication of a multichannel mat by a single-channel rhs
// (presumably the one rhs channel is applied to every channel of this mat — confirm).
84 ComplexMat mul(const ComplexMat & rhs) const;
86 // Element-wise multiplication of a multichannel mat by a multichannel rhs.
// NOTE(review): the earlier wording ("by one channel mats") contradicted "rhs mat is
// with multiple channel"; confirm how rhs channels map onto this mat's channels.
87 ComplexMat mul2(const ComplexMat & rhs) const;
// Debug printer: writes the matrix as rows of "(re,im)" pairs, reading p_data as
// interleaved floats (element in row j, column c at p_data[j*2*cols + 2*c] and [+1]).
// NOTE(review): p_data is allocated with cudaMalloc (see the constructors and
// create()), so it is a device pointer; dereferencing it here on the host is invalid
// and will typically crash. Copy the data into a host buffer first (cudaMemcpy with
// cudaMemcpyDeviceToHost) and print from that buffer instead.
89 friend std::ostream & operator<<(std::ostream & os, const ComplexMat & mat)
// NOTE(review): the loop over all channels is disabled, so only channel 0 is printed.
// The indices below also ignore i (no per-channel offset), so re-enabling the
// commented-out loop as-is would print channel 0's data n_channels times.
91 //for (int i = 0; i < mat.n_channels; ++i){
92 for (int i = 0; i < 1; ++i){
93 os << "Channel " << i << std::endl;
94 for (int j = 0; j < mat.rows; ++j) {
// All elements of row j except the last, each followed by ", ".
95 for (int k = 0; k < 2*mat.cols-2; k+=2)
96 os << "(" << mat.p_data[j*2*mat.cols + k] << "," << mat.p_data[j*2*mat.cols + (k+1)] << ")" << ", ";
// Last element of row j, followed by a newline.
97 os << "(" << mat.p_data[j*2*mat.cols + 2*mat.cols-2] << "," << mat.p_data[j*2*mat.cols + 2*mat.cols-1] << ")" << std::endl;
// Copy assignment (defined out of line).
// NOTE(review): takes a non-const reference and returns void. As a result, a const
// ComplexMat cannot be assigned from, and `a = b = c` does not compile. The
// conventional form is `ComplexMat & operator=(const ComplexMat & rhs)`; changing it
// also requires updating the out-of-line definition.
103 void operator=(ComplexMat & rhs);
// Move assignment (defined out of line); also returns void.
104 void operator=(ComplexMat && rhs);
// Interleaved (re, im) float storage sized for n_channels*cols*rows cufftComplex
// elements, allocated with cudaMalloc (device memory) and freed with cudaFree when
// non-null. nullptr means "no buffer"; the move constructor leaves the moved-from
// object in that state.
// NOTE(review): `mutable` lets const member functions modify this pointer — confirm
// that this is actually needed.
108 mutable float *p_data = nullptr;