]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/complexmat.cuh
Merge branch 'master' of https://github.com/Shanigen/kcf
[hercules2020/kcf.git] / src / complexmat.cuh
1 #pragma once
2
3 #include <opencv2/opencv.hpp>
4
5 #include "cuda_runtime.h"
6 #include "cufft.h"
7
8 class ComplexMat
9 {
10 public:
11     int cols;
12     int rows;
13     int n_channels;
14     int n_scales = 1;
15     
16     ComplexMat() : cols(0), rows(0), n_channels(0) {}
17     ComplexMat(int _rows, int _cols, int _n_channels) : cols(_cols), rows(_rows), n_channels(_n_channels)
18     {
19         cudaMalloc(&p_data,  n_channels*cols*rows*sizeof(cufftComplex));
20     }
21     
22     ComplexMat(int _rows, int _cols, int _n_channels, int _n_scales) : cols(_cols), rows(_rows), n_channels(_n_channels), n_scales(_n_scales)
23     {
24         cudaMalloc(&p_data,  n_channels*cols*rows*sizeof(cufftComplex));
25     }
26     
27     ComplexMat(ComplexMat &&other)
28     {
29         cols = other.cols;
30         rows = other.rows;
31         n_channels = other.n_channels;
32         n_scales = other.n_scales;
33         p_data = other.p_data;
34         
35         other.p_data = nullptr;
36     }
37     
38     ~ComplexMat()
39     {
40         if(p_data != nullptr) cudaFree(p_data);
41     }
42
43     void create(int _rows, int _cols, int _n_channels)
44     {
45         rows = _rows;
46         cols = _cols;
47         n_channels = _n_channels;
48         cudaMalloc(&p_data,  n_channels*cols*rows*sizeof(cufftComplex));
49     }
50
51     void create(int _rows, int _cols, int _n_channels, int _n_scales)
52     {
53         rows = _rows;
54         cols = _cols;
55         n_channels = _n_channels;
56         n_scales = _n_scales;
57         cudaMalloc(&p_data,  n_channels*cols*rows*sizeof(cufftComplex));
58     }
59     // cv::Mat API compatibility
60     cv::Size size() { return cv::Size(cols, rows); }
61     int channels() { return n_channels; }
62     int channels() const { return n_channels; }
63
64     void sqr_norm(float *result) const;
65     
66     ComplexMat sqr_mag() const;
67
68     ComplexMat conj() const;
69
70     ComplexMat sum_over_channels() const;
71
72     cufftComplex* get_p_data() const;
73
74     //element-wise per channel multiplication, division and addition
75     ComplexMat operator*(const ComplexMat & rhs) const;
76     ComplexMat operator/(const ComplexMat & rhs) const;
77     ComplexMat operator+(const ComplexMat & rhs) const;
78     
79     //multiplying or adding constant
80     ComplexMat operator*(const float & rhs) const;
81     ComplexMat operator+(const float & rhs) const;
82
83     //multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
84     ComplexMat mul(const ComplexMat & rhs) const;
85
86     //multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
87     ComplexMat mul2(const ComplexMat & rhs) const;
88     //text output
89     friend std::ostream & operator<<(std::ostream & os, const ComplexMat & mat)
90     {
91         //for (int i = 0; i < mat.n_channels; ++i){
92         for (int i = 0; i < 1; ++i){
93             os << "Channel " << i << std::endl;
94             for (int j = 0; j < mat.rows; ++j) {
95                 for (int k = 0; k < 2*mat.cols-2; k+=2)
96                     os << "(" << mat.p_data[j*2*mat.cols + k] << "," << mat.p_data[j*2*mat.cols + (k+1)] << ")" << ", ";
97                 os << "(" << mat.p_data[j*2*mat.cols + 2*mat.cols-2] << "," << mat.p_data[j*2*mat.cols + 2*mat.cols-1] << ")" <<  std::endl;
98             }
99         }
100         return os;
101     }
102     
103     void operator=(ComplexMat & rhs);
104     void operator=(ComplexMat && rhs);
105
106
107 private:
108     mutable float *p_data = nullptr;
109 };