1 #include "complexmat.cuh"
// Sums |z|^2 over all complex elements handled by one block, scales the block
// total by 1 / (rows * cols) and atomically adds it to out[blockId / n].
//
// Launch layout (as used by ComplexMat::sqr_norm):
//   grid  = (n, n_scales)  -> blockIdx.x = channel within a scale,
//                             blockIdx.y = scale, so blockId / n == scale index;
//   block = (rows, cols)   -> one thread per complex element of a channel;
//   dynamic shared memory  =  blockDim.x * blockDim.y * sizeof(float).
// `data` holds interleaved (re, im) float pairs, one channel after another.
// `out` is accumulated with atomicAdd, so the caller must zero it first.
__global__ void sqr_norm_kernel(int n, float *out, const float *data, float rows, float cols)
{
    extern __shared__ float sdata[];
    const int i = blockDim.x * threadIdx.y + threadIdx.x;
    const int block_size = blockDim.x * blockDim.y;
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * block_size + i);

    const float re = data[threadId];
    const float im = data[threadId + 1];
    sdata[i] = re * re + im * im;
    __syncthreads();

    // Tree reduction that also works for non-power-of-two block sizes: at each
    // step the upper (active - half) partial sums are folded onto the lower
    // ones. `active` is block-uniform, so every thread reaches each barrier.
    for (int active = block_size; active > 1;) {
        const int half = (active + 1) / 2;
        if (i + half < active)
            sdata[i] += sdata[i + half];
        __syncthreads();
        active = half;
    }

    // Exactly one contribution per block.
    if (i == 0)
        atomicAdd(&out[blockId / n], sdata[0] / (rows * cols));
}
// Writes, for every scale s, the sum of |z|^2 over all elements of that
// scale's channels divided by rows * cols into result.deviceMem()[s].
// One block per channel and one thread per element, so rows * cols must not
// exceed the device's maximum threads per block; an invalid launch
// configuration is reported through CudaSafeCall.
void ComplexMat::sqr_norm(DynMem &result) const
{
    // The kernel accumulates with atomicAdd, so the output must start at zero.
    CudaSafeCall(cudaMemsetAsync(result.deviceMem(), 0, n_scales * sizeof(float)));

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);

    sqr_norm_kernel<<<numBlocks, threadsPerBlock, rows * cols * sizeof(float)>>>(
        n_channels / n_scales, result.deviceMem(), this->p_data, rows, cols);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());
}
// Element-wise squared magnitude: result = |data|^2 + 0i.
// Launch layout (as used by ComplexMat::sqr_mag): one block per channel
// (grid = (channels per scale, n_scales)), one thread per element
// (block = (rows, cols)). Buffers hold interleaved (re, im) float pairs.
__global__ void sqr_mag_kernel(const float *data, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data[threadId] * data[threadId] + data[threadId + 1] * data[threadId + 1];
    result[threadId + 1] = 0;
}
// Returns a matrix of the same shape whose elements are |z|^2 with zero
// imaginary part. One block per channel, one thread per element, so
// rows * cols must fit in a single thread block; an invalid launch
// configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::sqr_mag() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    sqr_mag_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex conjugate: result = conj(data).
// Launch layout (as used by ComplexMat::conj): one block per channel
// (grid = (channels per scale, n_scales)), one thread per element
// (block = (rows, cols)). Buffers hold interleaved (re, im) float pairs.
__global__ void conj_kernel(const float *data, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data[threadId];
    result[threadId + 1] = -data[threadId + 1];
}
// Returns the element-wise complex conjugate of this matrix (same shape).
// One block per channel, one thread per element; an invalid launch
// configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::conj() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    conj_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Builds a one-channel matrix with this matrix's rows/cols.
// Presumably meant to sum all channels into that single channel (inferred
// from the name and the one-channel result) -- TODO confirm whether the sum
// should instead be per scale.
// NOTE(review): in this view nothing is computed into `result` and no return
// statement is present; confirm the reduction is actually implemented,
// otherwise callers may receive a matrix whose contents were never written.
86 ComplexMat ComplexMat::sum_over_channels() const
88 // assert(p_data.size() > 1);
89 ComplexMat result(this->rows, this->cols, 1);
// Returns the underlying buffer (the same pointer the kernels in this file
// read and write) viewed as cuFFT complex values, i.e. interleaved (re, im)
// float pairs. Note: returns a non-const pointer even from a const object.
cufftComplex *ComplexMat::get_p_data() const
{
    return reinterpret_cast<cufftComplex *>(p_data);
}
// Element-wise complex product: result = data_l * data_r, for two buffers of
// identical layout. Launch layout (as used by ComplexMat::operator*): one
// block per channel (grid = (channels per scale, n_scales)), one thread per
// element (block = (rows, cols)). Buffers hold interleaved (re, im) pairs.
__global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data_l[threadId] * data_r[threadId] - data_l[threadId + 1] * data_r[threadId + 1];
    result[threadId + 1] = data_l[threadId] * data_r[threadId + 1] + data_l[threadId + 1] * data_r[threadId];
}
// Element-wise, per-channel multiplication, division and addition of two
// matrices with the same rows, cols and channel count.

// Returns this * rhs element-wise (complex product). One block per channel,
// one thread per element; an invalid launch configuration is reported
// through CudaSafeCall.
ComplexMat ComplexMat::operator*(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data,
                                                                    result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex quotient: result = data_l / data_r, computed as
// data_l * conj(data_r) / |data_r|^2. Elements where data_r is 0 + 0i yield
// inf/NaN (not guarded). Launch layout (as used by ComplexMat::operator/):
// one block per channel, one thread per element; interleaved (re, im) pairs.
__global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[threadId];
    const float r_im = data_r[threadId + 1];
    // |data_r|^2 is shared by both components; compute it once.
    const float denom = r_re * r_re + r_im * r_im;

    result[threadId] = (l_re * r_re + l_im * r_im) / denom;
    result[threadId + 1] = (l_im * r_re - l_re * r_im) / denom;
}
// Returns this / rhs element-wise (complex quotient; zero divisors give
// inf/NaN). One block per channel, one thread per element; an invalid launch
// configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::operator/(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_div_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data,
                                                                    result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex sum: result = data_l + data_r, for two buffers of
// identical layout. Launch layout (as used by ComplexMat::operator+): one
// block per channel, one thread per element; interleaved (re, im) pairs.
__global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data_l[threadId] + data_r[threadId];
    result[threadId + 1] = data_l[threadId + 1] + data_r[threadId + 1];
}
// Returns this + rhs element-wise. One block per channel, one thread per
// element; an invalid launch configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::operator+(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_add_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data,
                                                                    result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise scaling by a real constant: result = data_l * constant
// (both real and imaginary parts are scaled). Launch layout (as used by
// ComplexMat::operator*(float)): one block per channel, one thread per
// element; interleaved (re, im) pairs.
__global__ void constant_mul_kernel(const float *data_l, float constant, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data_l[threadId] * constant;
    result[threadId + 1] = data_l[threadId + 1] * constant;
}
// Returns this matrix with every element scaled by the real constant rhs.
// One block per channel, one thread per element; an invalid launch
// configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::operator*(const float &rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    constant_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise addition of a real constant: result = data_l + constant.
// Only the real part changes; the imaginary part is copied through.
// Launch layout (as used by ComplexMat::operator+(float)): one block per
// channel, one thread per element; interleaved (re, im) pairs.
__global__ void constant_add_kernel(const float *data_l, float constant, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data_l[threadId] + constant;
    result[threadId + 1] = data_l[threadId + 1];
}
// Returns this matrix with the real constant rhs added to the real part of
// every element. One block per channel, one thread per element; an invalid
// launch configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::operator+(const float &rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    constant_add_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Multiplies every channel of data_l element-wise by the single channel in
// data_r: data_r is indexed by the in-block element only, so the same rhs
// channel is broadcast to every block (every channel of every scale).
// Launch layout (as used by ComplexMat::mul): one block per channel of
// data_l, one thread per element; interleaved (re, im) pairs.
__global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    const int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x);

    result[threadId] = data_l[threadId] * data_r[one_ch_index] - data_l[threadId + 1] * data_r[one_ch_index + 1];
    result[threadId + 1] = data_l[threadId] * data_r[one_ch_index + 1] + data_l[threadId + 1] * data_r[one_ch_index];
}
// Multiplies this multi-channel matrix element-wise by a single-channel
// matrix rhs: every channel of every scale is multiplied by rhs's channel.
// One block per channel, one thread per element; an invalid launch
// configuration is reported through CudaSafeCall.
ComplexMat ComplexMat::mul(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    one_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Multiplies data_l element-wise by data_r, where data_r holds one channel
// per channel-within-a-scale: data_r is indexed by blockIdx.x (channel) and
// the in-block element, ignoring blockIdx.y, so the same rhs channel set is
// reused for every scale. Launch layout (as used by ComplexMat::mul2):
// grid = (channels per scale, n_scales), block = (rows, cols);
// interleaved (re, im) pairs.
__global__ void scales_channel_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    const int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);

    result[threadId] = data_l[threadId] * data_r[one_ch_index] - data_l[threadId + 1] * data_r[one_ch_index + 1];
    result[threadId + 1] = data_l[threadId] * data_r[one_ch_index + 1] + data_l[threadId + 1] * data_r[one_ch_index];
}
// Multiplies this multi-scale matrix element-wise by rhs, where rhs has one
// channel per channel-of-a-scale (n_channels / n_scales channels); the same
// rhs channels are applied to every scale. One block per channel, one thread
// per element; an invalid launch configuration is reported through
// CudaSafeCall.
ComplexMat ComplexMat::mul2(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    scales_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data, result.p_data);
    // Kernel launches do not return errors; query them explicitly.
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Copy-assignment from rhs (takes a non-const reference and returns void, so
// it cannot be used with const sources or chained).
// Visible here: copies the channel and scale counts from rhs.
// NOTE(review): the handling of rows/cols and of the data buffer p_data is not
// visible in this view -- confirm the element data is copied (into a buffer
// large enough for rhs's size), not only the metadata.
265 void ComplexMat::operator=(ComplexMat &rhs)
269 n_channels = rhs.n_channels;
270 n_scales = rhs.n_scales;
// Move-assignment from rhs.
// Visible here: copies the channel and scale counts from rhs and clears
// rhs.p_data. Presumably *this takes over rhs's buffer (the line doing so is
// not visible here) -- TODO confirm, and confirm that any buffer previously
// held by *this is released rather than leaked.
275 void ComplexMat::operator=(ComplexMat &&rhs)
279 n_channels = rhs.n_channels;
280 n_scales = rhs.n_scales;
// Leave the moved-from object without a buffer pointer (presumably so its
// destructor does not free the buffer now owned by *this -- verify).
284 rhs.p_data = nullptr;