]> rtime.felk.cvut.cz Git - hercules2020/kcf.git/blob - src/complexmat.cuh
Merge remote-tracking branch 'upstream/master' into rotation
[hercules2020/kcf.git] / src / complexmat.cuh
1 #ifndef COMPLEXMAT_H
2 #define COMPLEXMAT_H
3
4 #include <opencv2/opencv.hpp>
5
6 #include "dynmem.hpp"
7 #include "cuda_runtime.h"
8 #include "cufft.h"
9
10 #include "cuda/cuda_error_check.cuh"
11
12 class ComplexMat {
13   public:
14     uint cols;
15     uint rows;
16     uint n_channels;
17     uint n_scales = 1;
18     bool foreign_data = false;
19     cudaStream_t stream = nullptr;
20
21     ComplexMat() : cols(0), rows(0), n_channels(0) {}
22     ComplexMat(uint _rows, uint _cols, uint _n_channels, cudaStream_t _stream)
23         : cols(_cols), rows(_rows), n_channels(_n_channels), stream(_stream)
24     {
25         CudaSafeCall(cudaMalloc(&p_data, n_channels * cols * rows * sizeof(cufftComplex)));
26     }
27
28     ComplexMat(uint _rows, uint _cols, uint _n_channels, uint _n_scales, cudaStream_t _stream)
29         : cols(_cols), rows(_rows), n_channels(_n_channels), n_scales(_n_scales), stream(_stream)
30     {
31         CudaSafeCall(cudaMalloc(&p_data, n_channels * cols * rows * sizeof(cufftComplex)));
32     }
33
34     ComplexMat(ComplexMat &&other)
35     {
36         cols = other.cols;
37         rows = other.rows;
38         n_channels = other.n_channels;
39         n_scales = other.n_scales;
40         p_data = other.p_data;
41         stream = other.stream;
42
43         other.p_data = nullptr;
44     }
45
46     ~ComplexMat()
47     {
48         if (p_data != nullptr && !foreign_data) {
49             CudaSafeCall(cudaFree(p_data));
50             p_data = nullptr;
51         }
52     }
53
54     void create(uint _rows, uint _cols, uint _n_channels, cudaStream_t _stream = nullptr)
55     {
56         rows = _rows;
57         cols = _cols;
58         n_channels = _n_channels;
59         stream = _stream;
60         CudaSafeCall(cudaMalloc(&p_data, n_channels * cols * rows * sizeof(cufftComplex)));
61     }
62
63     void create(uint _rows, uint _cols, uint _n_channels, uint _n_scales, cudaStream_t _stream = nullptr)
64     {
65         rows = _rows;
66         cols = _cols;
67         n_channels = _n_channels;
68         n_scales = _n_scales;
69         stream = _stream;
70         CudaSafeCall(cudaMalloc(&p_data, n_channels * cols * rows * sizeof(cufftComplex)));
71     }
72     // cv::Mat API compatibility
73     cv::Size size() { return cv::Size(cols, rows); }
74     int channels() { return n_channels; }
75     int channels() const { return n_channels; }
76
77     void set_stream(cudaStream_t _stream)
78     {
79         stream = _stream;
80         return;
81     }
82
83     void sqr_norm(DynMem &result) const;
84
85     ComplexMat sqr_mag() const;
86
87     ComplexMat conj() const;
88
89     ComplexMat sum_over_channels() const;
90
91     cufftComplex *get_p_data() const;
92
93     // element-wise per channel multiplication, division and addition
94     ComplexMat operator*(const ComplexMat &rhs) const;
95     ComplexMat operator/(const ComplexMat &rhs) const;
96     ComplexMat operator+(const ComplexMat &rhs) const;
97
98     // multiplying or adding constant
99     ComplexMat operator*(const float &rhs) const;
100     ComplexMat operator+(const float &rhs) const;
101
102     // multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
103     ComplexMat mul(const ComplexMat &rhs) const;
104
105     // multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
106     ComplexMat mul2(const ComplexMat &rhs) const;
107     // text output
108     friend std::ostream &operator<<(std::ostream &os, const ComplexMat &mat)
109     {
110         float *data_cpu = reinterpret_cast<float*>(malloc(mat.rows * mat.cols * mat.n_channels * sizeof(cufftComplex)));
111         CudaSafeCall(cudaMemcpy(data_cpu, mat.p_data, mat.rows * mat.cols * mat.n_channels * sizeof(cufftComplex),
112                                 cudaMemcpyDeviceToHost));
113         // for (int i = 0; i < mat.n_channels; ++i){
114         for (int i = 0; i < 1; ++i) {
115             os << "Channel " << i << std::endl;
116             for (uint j = 0; j < mat.rows; ++j) {
117                 for (uint k = 0; k < 2 * mat.cols - 2; k += 2)
118                     os << "(" << data_cpu[j * 2 * mat.cols + k] << "," << data_cpu[j * 2 * mat.cols + (k + 1)] << ")"
119                        << ", ";
120                 os << "(" << data_cpu[j * 2 * mat.cols + 2 * mat.cols - 2] << ","
121                    << data_cpu[j * 2 * mat.cols + 2 * mat.cols - 1] << ")" << std::endl;
122             }
123         }
124         free(data_cpu);
125         return os;
126     }
127
128     void operator=(ComplexMat &rhs);
129     void operator=(ComplexMat &&rhs);
130
131   private:
132     mutable float *p_data = nullptr;
133 };
134
135 #endif // COMPLEXMAT_H