1 #include "complexmat.cuh"
// Accumulates, per scale, the mean squared magnitude of interleaved (re, im) data.
//
// Launch layout: 2-D grid and 2-D block; each thread owns exactly one complex
// element (no bounds check). Dynamic shared memory must hold
// blockDim.x * blockDim.y floats.
//
// n     number of blocks that contribute to one output slot (blockId / n
//       selects the slot in out).
// out   accumulated with atomicAdd, so it must be zeroed before the launch.
// data  interleaved complex input, one (re, im) pair per thread.
// rows, cols  the per-block sum is divided by rows * cols before accumulation.
__global__ void sqr_norm_kernel(int n, float* out, float* data, float rows, float cols)
{
    extern __shared__ float sdata[];

    const int block_size = blockDim.x * blockDim.y;
    const int i = blockDim.x * threadIdx.y + threadIdx.x;
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * block_size + i);

    const float re = data[threadId];
    const float im = data[threadId + 1];
    sdata[i] = re * re + im * im;
    // Every partial value must be visible before any thread reads a neighbour's slot.
    __syncthreads();

    // Tree reduction that is valid for any block size, not only powers of two:
    // the upper part [half, active) is folded onto the lower part, and an odd
    // middle element simply stays in place for the next round.
    for (int active = block_size; active > 1; ) {
        const int half = (active + 1) / 2;
        if (i + half < active) {
            sdata[i] += sdata[i + half];
        }
        // Reached by all threads: the loop bounds are uniform across the block.
        __syncthreads();
        active = half;
    }

    // Exactly one thread per block publishes the block's contribution.
    if (i == 0) {
        atomicAdd(&out[blockId / n], sdata[0] / (rows * cols));
    }
}
// Launches sqr_norm_kernel over the whole matrix and leaves one accumulated
// value per scale in result.
//
// result  device buffer of n_scales floats; it is cleared with cudaMemset here
//         because the kernel accumulates into it with atomicAdd.
//
// One block of rows x cols threads is launched per channel; neither the memset
// nor the launch status is checked.
void ComplexMat::sqr_norm(float *result) const
{
    const auto channels_per_scale = n_channels / n_scales;
    const auto shared_bytes = rows * cols * sizeof(float);

    cudaMemset(result, 0, n_scales * sizeof(float));

    const dim3 grid(channels_per_scale, n_scales);
    const dim3 block(rows, cols);
    sqr_norm_kernel<<<grid, block, shared_bytes>>>(channels_per_scale, result, p_data, rows, cols);
}
// For every interleaved (re, im) pair of data, stores re^2 + im^2 in the real
// slot of result and 0 in the imaginary slot.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
__global__ void sqr_mag_kernel(float* data, float* result)
{
    const int block_linear = blockIdx.y * gridDim.x + blockIdx.x;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * (block_linear * (blockDim.x * blockDim.y) + lane);
    const int im = re + 1;

    result[re] = data[re] * data[re] + data[im] * data[im];
    result[im] = 0;
}
// Returns a new matrix of the same shape whose real parts hold the squared
// magnitude of each element and whose imaginary parts are zero.
// Launches one block of rows x cols threads per channel; the launch status is
// not checked.
ComplexMat ComplexMat::sqr_mag() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    sqr_mag_kernel<<<grid, block>>>(this->p_data, result.p_data);

    return result;
}
// Writes the complex conjugate of every interleaved (re, im) pair of data into
// result: the real part is copied, the imaginary part is negated.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
__global__ void conj_kernel(float* data, float* result)
{
    const int block_linear = blockIdx.y * gridDim.x + blockIdx.x;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * (block_linear * (blockDim.x * blockDim.y) + lane);
    const int im = re + 1;

    result[re] = data[re];
    result[im] = -data[im];
}
// Returns a new matrix of the same shape holding the element-wise complex
// conjugate of this one.
// Launches one block of rows x cols threads per channel; the launch status is
// not checked.
ComplexMat ComplexMat::conj() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    conj_kernel<<<grid, block>>>(this->p_data, result.p_data);

    return result;
}
// Constructs a ComplexMat with this matrix's rows and cols and a channel count of 1.
// NOTE(review): no kernel launch or accumulation over channels is visible in
// this view — confirm whether the summation itself is implemented.
81 ComplexMat ComplexMat::sum_over_channels() const
83 // assert(p_data.size() > 1);
84 ComplexMat result(this->rows, this->cols, 1);
// Exposes the underlying buffer reinterpreted as cufftComplex elements
// (the kernels in this file treat it as interleaved re/im floats).
// No copy is made; the returned pointer aliases p_data.
cufftComplex* ComplexMat::get_p_data() const
{
    return reinterpret_cast<cufftComplex*>(p_data);
}
// Element-wise complex product: result = data_l * data_r, all three buffers
// holding interleaved (re, im) floats with identical layout.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
//
// Both operands are read into registers before anything is stored, so the
// kernel is also correct when result aliases data_l or data_r (in-place use),
// and no operand has to be re-read from global memory after the first store.
__global__ void same_num_channels_mul_kernel(float* data_l, float* data_r, float* result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[threadId];
    const float r_im = data_r[threadId + 1];

    result[threadId]     = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise, per-channel operators: multiplication, division and addition.

// Element-wise complex product with a matrix of identical shape
// (same rows, cols and n_channels, enforced by assert).
// Returns a new matrix; launches one block of rows x cols threads per channel
// and does not check the launch status.
ComplexMat ComplexMat::operator*(const ComplexMat & rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    same_num_channels_mul_kernel<<<grid, block>>>(this->p_data, rhs.p_data, result.p_data);

    return result;
}
// Element-wise complex quotient: result = data_l / data_r, all three buffers
// holding interleaved (re, im) floats with identical layout.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
//
// Uses l / r = l * conj(r) / |r|^2. A zero divisor is not special-cased and
// yields the usual IEEE inf/NaN results.
// Both operands are read into registers before anything is stored, so the
// kernel is also correct when result aliases data_l or data_r, and |r|^2 is
// computed once instead of twice.
__global__ void same_num_channels_div_kernel(float* data_l, float* data_r, float* result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[threadId];
    const float r_im = data_r[threadId + 1];

    const float denom = r_re * r_re + r_im * r_im;

    result[threadId]     = (l_re * r_re + l_im * r_im) / denom;
    result[threadId + 1] = (l_im * r_re - l_re * r_im) / denom;
}
// Element-wise complex division by a matrix of identical shape
// (same rows, cols and n_channels, enforced by assert).
// Returns a new matrix; launches one block of rows x cols threads per channel
// and does not check the launch status.
ComplexMat ComplexMat::operator/(const ComplexMat & rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    same_num_channels_div_kernel<<<grid, block>>>(this->p_data, rhs.p_data, result.p_data);

    return result;
}
// Element-wise complex sum: result = data_l + data_r, with real and imaginary
// floats added independently.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
__global__ void same_num_channels_add_kernel(float* data_l, float* data_r, float* result)
{
    const int block_linear = blockIdx.y * gridDim.x + blockIdx.x;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * (block_linear * (blockDim.x * blockDim.y) + lane);
    const int im = re + 1;

    result[re] = data_l[re] + data_r[re];
    result[im] = data_l[im] + data_r[im];
}
// Element-wise complex addition with a matrix of identical shape
// (same rows, cols and n_channels, enforced by assert).
// Returns a new matrix; launches one block of rows x cols threads per channel
// and does not check the launch status.
ComplexMat ComplexMat::operator+(const ComplexMat & rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    same_num_channels_add_kernel<<<grid, block>>>(this->p_data, rhs.p_data, result.p_data);

    return result;
}
// Scales every complex element by a real constant: both the real and the
// imaginary float of each pair are multiplied by constant.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
__global__ void constant_mul_kernel(float* data_l, float constant, float* result)
{
    const int block_linear = blockIdx.y * gridDim.x + blockIdx.x;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * (block_linear * (blockDim.x * blockDim.y) + lane);
    const int im = re + 1;

    result[re] = data_l[re] * constant;
    result[im] = data_l[im] * constant;
}
// Multiplies every complex element by the real scalar rhs and returns the
// product as a new matrix of the same shape.
// Launches one block of rows x cols threads per channel; the launch status is
// not checked.
ComplexMat ComplexMat::operator*(const float & rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    constant_mul_kernel<<<grid, block>>>(this->p_data, rhs, result.p_data);

    return result;
}
// Adds a real constant to every complex element: the real float of each pair
// gets constant added, the imaginary float is copied unchanged.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
__global__ void constant_add_kernel(float* data_l, float constant, float* result)
{
    const int block_linear = blockIdx.y * gridDim.x + blockIdx.x;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * (block_linear * (blockDim.x * blockDim.y) + lane);
    const int im = re + 1;

    result[re] = data_l[re] + constant;
    result[im] = data_l[im];
}
// Adds the real scalar rhs to the real part of every element and returns the
// sum as a new matrix of the same shape; imaginary parts are carried over.
// Launches one block of rows x cols threads per channel; the launch status is
// not checked.
ComplexMat ComplexMat::operator+(const float & rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    constant_add_kernel<<<grid, block>>>(this->p_data, rhs, result.p_data);

    return result;
}
// Multiplies every block's plane of data_l element-wise by the same
// single-plane data_r: the rhs index depends only on the thread's position
// inside its block, so all blocks read the same rhs elements.
// All buffers hold interleaved (re, im) floats.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
//
// Both operands are read into registers before anything is stored, so the
// kernel is also correct when result aliases data_l (in-place use).
__global__ void one_channel_mul_kernel(float* data_l, float* data_r, float* result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    const int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[one_ch_index];
    const float r_im = data_r[one_ch_index + 1];

    result[threadId]     = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise product of this multi-channel matrix with a single-channel
// matrix: every channel of *this is multiplied by the one channel of rhs.
// rhs must have n_channels == 1 and the same rows and cols (enforced by assert).
// Returns a new matrix; launches one block of rows x cols threads per channel
// and does not check the launch status.
ComplexMat ComplexMat::mul(const ComplexMat & rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    one_channel_mul_kernel<<<grid, block>>>(this->p_data, rhs.p_data, result.p_data);

    return result;
}
// Multiplies data_l element-wise by data_r, where data_r holds one plane per
// blockIdx.x only: the rhs index uses blockIdx.x and the thread's position in
// its block but not blockIdx.y, so blocks that differ only in blockIdx.y read
// the same rhs plane.
// All buffers hold interleaved (re, im) floats.
// Launch layout: 2-D grid, 2-D block, one complex element per thread, no bounds check.
//
// Both operands are read into registers before anything is stored, so the
// kernel is also correct when result aliases data_l (in-place use).
__global__ void scales_channel_mul_kernel(float* data_l, float* data_r, float* result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    const int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[one_ch_index];
    const float r_im = data_r[one_ch_index + 1];

    result[threadId]     = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise product of this multi-scale matrix with a matrix that has one
// set of channels (n_channels / n_scales of them): the same rhs channels are
// applied to every scale of *this.
// rhs must have the same rows and cols (enforced by assert together with the
// channel count).
// Returns a new matrix; launches one block of rows x cols threads per channel
// and does not check the launch status.
ComplexMat ComplexMat::mul2(const ComplexMat & rhs) const
{
    assert(rhs.n_channels == n_channels/n_scales && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    const dim3 grid(n_channels / n_scales, n_scales);
    const dim3 block(rows, cols);
    scales_channel_mul_kernel<<<grid, block>>>(this->p_data, rhs.p_data, result.p_data);

    return result;
}
// Assignment from an lvalue ComplexMat; returns void, so assignments cannot be chained.
// Copies rhs's n_channels and n_scales into this object.
// NOTE(review): how rows, cols and p_data are transferred (deep copy vs. shared
// pointer) is not visible in this view — confirm ownership before relying on it.
250 void ComplexMat::operator=(ComplexMat & rhs)
254 n_channels = rhs.n_channels;
255 n_scales = rhs.n_scales;
// Assignment from an rvalue ComplexMat; returns void, so assignments cannot be chained.
// Copies rhs's n_channels and n_scales, then nulls rhs.p_data.
// NOTE(review): presumably p_data is taken over from rhs before it is nulled so
// that rhs no longer refers to the buffer; that transfer and the handling of
// any buffer previously held by this object are not visible here — confirm.
261 void ComplexMat::operator=(ComplexMat && rhs)
265 n_channels = rhs.n_channels;
266 n_scales = rhs.n_scales;
270 rhs.p_data = nullptr;