1 #include "complexmat.hpp"
// Accumulates the (normalised) squared magnitude of one channel into out[scale].
//
// Layout: `data` holds interleaved (re, im) float pairs. Each block covers one
// channel of blockDim.x * blockDim.y complex elements (one per thread) and the
// 2D grid is expected to be (n = channels per scale) x (number of scales), so
// blockId / n selects the scale.
// Shared memory: requires blockDim.x * blockDim.y floats of dynamic shared memory.
// Output: each block atomically adds (sum of |z|^2 over its channel) / (rows * cols)
// into out[blockId / n], so `out` must be zeroed before the launch.
// No bounds check: the launch must cover the data exactly.
__global__ void sqr_norm_kernel(int n, float *out, const float *data, float rows, float cols)
{
    extern __shared__ float sdata[];
    const unsigned int tid = blockDim.x * threadIdx.y + threadIdx.x;
    const unsigned int block_size = blockDim.x * blockDim.y;
    const int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    const int threadId = 2 * (blockId * block_size + tid);

    sdata[tid] = data[threadId] * data[threadId] + data[threadId + 1] * data[threadId + 1];
    // Every partial value must be in shared memory before any thread reads another's.
    __syncthreads();

    // Pairwise tree reduction that also handles non-power-of-two block sizes:
    // each step folds the upper floor(len / 2) entries onto the lower ones and
    // leaves ceil(len / 2) live entries. `len` is identical for all threads of the
    // block, so every thread reaches the barrier on every iteration.
    for (unsigned int len = block_size; len > 1;) {
        const unsigned int half = (len + 1) / 2;
        if (tid < len - half)  // equivalent to tid + half < len
            sdata[tid] += sdata[tid + half];
        __syncthreads();
        len = half;
    }

    // A single atomic per block rather than one per thread.
    if (tid == 0)
        atomicAdd(&out[blockId / n], sdata[0] / (rows * cols));
}
// Computes, for every scale s, the sum over that scale's channels of |z|^2
// divided by (rows * cols), writing one float per scale into result's device memory.
// Launch: one block per channel with rows x cols threads, so rows * cols must not
// exceed the device's threads-per-block limit; a violating launch is reported by
// the cudaGetLastError() check below. Work is queued on the default stream.
void ComplexMat::sqr_norm(DynMem &result) const
{
    // The kernel accumulates with atomicAdd, so the per-scale sums must start at zero.
    CudaSafeCall(cudaMemsetAsync(result.deviceMem(), 0, n_scales * sizeof(float)));

    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);

    // Dynamic shared memory: one float per thread for the in-block reduction.
    sqr_norm_kernel<<<numBlocks, threadsPerBlock, rows * cols * sizeof(float)>>>(
        n_channels / n_scales, result.deviceMem(), (float*)this->p_data.deviceMem(), rows, cols);
    // Kernel launches return no status; surface configuration errors here.
    CudaSafeCall(cudaGetLastError());
}
// Stores |z|^2 of every complex element of `data` in the real slot of `result`
// and zero in the imaginary slot. Data are interleaved (re, im) float pairs.
// Launch: one thread per complex element, one block per channel
// (blockDim.x * blockDim.y elements), 2D grid over channels; no shared memory
// and no bounds check, so the launch must match the data size exactly.
__global__ void sqr_mag_kernel(const float *data, float *result)
{
    const int linear_block = blockIdx.y * gridDim.x + blockIdx.x;
    const int local = threadIdx.x + threadIdx.y * blockDim.x;
    const int re_idx = 2 * (linear_block * (blockDim.x * blockDim.y) + local);
    const int im_idx = re_idx + 1;

    const float re = data[re_idx];
    const float im = data[im_idx];

    result[re_idx] = re * re + im * im;
    result[im_idx] = 0.0f;
}
// Element-wise squared magnitude: returns a new mat (same rows, cols, channels()
// and n_scales) whose real parts hold |z|^2 and whose imaginary parts are zero.
// The kernel is queued on the default stream; launch errors are checked.
ComplexMat ComplexMat::sqr_mag() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    sqr_mag_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(), (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex conjugate of every element: result = (re, -im).
// Data are interleaved (re, im) float pairs; one thread per complex element,
// one block per channel, 2D grid over channels; no shared memory, no bounds check.
__global__ void conj_kernel(const float *data, float *result)
{
    const int linear_block = blockIdx.y * gridDim.x + blockIdx.x;
    const int local = threadIdx.x + threadIdx.y * blockDim.x;
    const int re_idx = 2 * (linear_block * (blockDim.x * blockDim.y) + local);
    const int im_idx = re_idx + 1;

    result[re_idx] = data[re_idx];
    result[im_idx] = -data[im_idx];
}
// Element-wise complex conjugate: returns a new mat (same rows, cols, channels()
// and n_scales) with every imaginary part negated.
// The kernel is queued on the default stream; launch errors are checked.
ComplexMat ComplexMat::conj() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    conj_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(), (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Sums `channels` consecutive channels of `src` element by element into `dest`.
// Each channel is `num_channel_elem` floats long and the channels are stored back
// to back, so dest[k] = sum_i src[k + i * num_channel_elem].
// Launch: 1D grid with at least num_channel_elem threads in total; surplus
// threads exit immediately. No shared memory.
__global__ static void sum_channels(float *dest, const float *src, uint channels, uint num_channel_elem)
{
    // Unsigned to match num_channel_elem and avoid a signed/unsigned comparison.
    const uint idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= num_channel_elem)
        return;

    float acc = 0.0f;
    for (uint i = 0; i < channels; ++i)
        acc += src[idx + i * num_channel_elem];

    dest[idx] = acc;
}
// Sums the channels of each scale, returning a mat with one channel per scale:
// output channel s = sum of the n_channels / n_scales input channels of scale s.
// Blocks the host until the summation kernels have finished.
ComplexMat ComplexMat::sum_over_channels() const
{
    assert(p_data.num_elem == n_channels * rows * cols);

    uint n_channels_per_scale = n_channels / n_scales;
    // Pointer offsets below are in units of deviceMem()'s element type; the
    // reinterpret_cast to float* suggests that is a complex type (TODO confirm).
    // Input scales are n_channels_per_scale channels apart ...
    uint in_scale_offset = n_channels_per_scale * rows * cols;
    // ... but the result has a single channel per scale, so its scales are only
    // rows * cols apart. Using the input stride here would write past the result.
    uint out_scale_offset = rows * cols;

    ComplexMat result(this->rows, this->cols, 1, n_scales);

    // One channel = rows * cols complex values = 2 * rows * cols floats; the kernel
    // sums real and imaginary parts independently as plain floats.
    const uint total = rows * cols * 2;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint scale = 0; scale < n_scales; ++scale) {
        // Launch on the stream that is synchronized below so the wait covers these
        // kernels even when the default stream is the legacy one.
        sum_channels<<<blocks, threads, 0, cudaStreamPerThread>>>(
            reinterpret_cast<float*>(result.p_data.deviceMem() + scale * out_scale_offset),
            reinterpret_cast<const float*>(p_data.deviceMem() + scale * in_scale_offset),
            n_channels_per_scale, total);
        CudaSafeCall(cudaGetLastError());
    }
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));

    return result;
}
// Element-wise complex product result = data_l * data_r. Both operands and the
// result are interleaved (re, im) float pairs with identical layout.
// Launch: one thread per complex element, one block per channel, 2D grid over
// channels; no shared memory, no bounds check.
// All inputs are loaded before anything is stored, so `result` may alias
// `data_l` or `data_r` (in-place multiplication) without corrupting the output.
__global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[threadId];
    const float r_im = data_r[threadId + 1];

    result[threadId] = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise, per-channel multiplication, division and addition of two mats with the same number of channels.
// Element-wise complex product of two mats with the same rows, cols and channel
// count. Returns a new mat; the kernel is queued on the default stream and launch
// errors are checked.
ComplexMat ComplexMat::operator*(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                                    (float*)rhs.p_data.deviceMem(),
                                                                    (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex quotient result = data_l / data_r, computed as
// l * conj(r) / |r|^2. Operands and result are interleaved (re, im) float pairs.
// Launch: one thread per complex element, one block per channel, 2D grid over
// channels; no shared memory, no bounds check. A zero divisor is not special-cased.
// All inputs are loaded before anything is stored, so `result` may alias an input.
__global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result)
{
    int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[threadId];
    const float r_im = data_r[threadId + 1];
    // |r|^2 is shared by both components.
    const float denom = r_re * r_re + r_im * r_im;

    result[threadId] = (l_re * r_re + l_im * r_im) / denom;
    result[threadId + 1] = (l_im * r_re - l_re * r_im) / denom;
}
// Element-wise complex quotient of two mats with the same rows, cols and channel
// count. Returns a new mat; the kernel is queued on the default stream and launch
// errors are checked.
ComplexMat ComplexMat::operator/(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_div_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                                    (float*)rhs.p_data.deviceMem(),
                                                                    (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex sum result = data_l + data_r over interleaved (re, im)
// float pairs. One thread per complex element, one block per channel, 2D grid
// over channels; no shared memory, no bounds check.
__global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result)
{
    const int linear_block = blockIdx.y * gridDim.x + blockIdx.x;
    const int local = threadIdx.x + threadIdx.y * blockDim.x;
    const int re_idx = 2 * (linear_block * (blockDim.x * blockDim.y) + local);
    const int im_idx = re_idx + 1;

    result[re_idx] = data_l[re_idx] + data_r[re_idx];
    result[im_idx] = data_l[im_idx] + data_r[im_idx];
}
// Element-wise complex sum of two mats with the same rows, cols and channel
// count. Returns a new mat; the kernel is queued on the default stream and launch
// errors are checked.
ComplexMat ComplexMat::operator+(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    same_num_channels_add_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                                    (float*)rhs.p_data.deviceMem(),
                                                                    (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Scales every complex element by a real constant (both re and im are multiplied).
// Data are interleaved (re, im) float pairs; one thread per complex element, one
// block per channel, 2D grid over channels; no shared memory, no bounds check.
__global__ void constant_mul_kernel(const float *data_l, float constant, float *result)
{
    const int linear_block = blockIdx.y * gridDim.x + blockIdx.x;
    const int local = threadIdx.x + threadIdx.y * blockDim.x;
    const int base = 2 * (linear_block * (blockDim.x * blockDim.y) + local);

    // Real part first, then imaginary part.
    for (int part = 0; part < 2; ++part)
        result[base + part] = data_l[base + part] * constant;
}
// Multiplies every complex element by the real scalar `rhs`. Returns a new mat;
// the kernel is queued on the default stream and launch errors are checked.
ComplexMat ComplexMat::operator*(const float &rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    constant_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                           rhs,
                                                           (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Adds a real constant to every complex element: only the real part changes,
// the imaginary part is copied through. Data are interleaved (re, im) float pairs;
// one thread per complex element, one block per channel, 2D grid over channels;
// no shared memory, no bounds check.
__global__ void constant_add_kernel(const float *data_l, float constant, float *result)
{
    const int linear_block = blockIdx.y * gridDim.x + blockIdx.x;
    const int local = threadIdx.x + threadIdx.y * blockDim.x;
    const int re_idx = 2 * (linear_block * (blockDim.x * blockDim.y) + local);
    const int im_idx = re_idx + 1;

    result[re_idx] = data_l[re_idx] + constant;
    result[im_idx] = data_l[im_idx];
}
// Adds the real scalar `rhs` to the real part of every complex element. Returns a
// new mat; the kernel is queued on the default stream and launch errors are checked.
ComplexMat ComplexMat::operator+(const float &rhs) const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    constant_add_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                           rhs,
                                                           (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Multiplies every channel of `data_l` element-wise by the single channel
// `data_r` (broadcast over all blocks, i.e. all channels and scales).
// Data are interleaved (re, im) float pairs; one thread per complex element, one
// block per channel, 2D grid over channels; no shared memory, no bounds check.
// All inputs are loaded before anything is stored, so `result` may alias `data_l`.
__global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    // Same position inside the one-channel rhs, independent of the block.
    int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[one_ch_index];
    const float r_im = data_r[one_ch_index + 1];

    result[threadId] = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise multiplication of a multi-channel mat by a single-channel mat (rhs has exactly one channel, applied to every channel).
// Multiplies every channel of *this element-wise by the single-channel mat `rhs`
// (same rows and cols). Returns a new mat; the kernel is queued on the default
// stream and launch errors are checked.
ComplexMat ComplexMat::mul(const ComplexMat &rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One block per channel (rows x cols threads); grid = channels-per-scale x scales.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    one_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
                                                              (float*)rhs.p_data.deviceMem(),
                                                              (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Multiplies a multi-scale mat element-wise by a mat holding one scale's worth of
// channels: channel blockIdx.x of every scale (blockIdx.y) of `data_l` is
// multiplied by channel blockIdx.x of `data_r`.
// Data are interleaved (re, im) float pairs; one thread per complex element, one
// block per channel, grid = (channels per scale) x (scales); no shared memory,
// no bounds check.
// Inputs are read-only (const, like the sibling kernels) and are loaded before
// anything is stored, so `result` may alias `data_l`.
__global__ void scales_channel_mul_kernel(const float *data_l, const float *data_r, float *result)
{
    int blockId = blockIdx.x + blockIdx.y * gridDim.x;
    int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
    // Same element in channel blockIdx.x of rhs, shared by all scales.
    int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);

    const float l_re = data_l[threadId];
    const float l_im = data_l[threadId + 1];
    const float r_re = data_r[one_ch_index];
    const float r_im = data_r[one_ch_index + 1];

    result[threadId] = l_re * r_re - l_im * r_im;
    result[threadId + 1] = l_re * r_im + l_im * r_re;
}
// Element-wise multiplication of a multi-scale mat by a mat holding one scale's channels (rhs has n_channels / n_scales channels, reused for every scale).
283 // ComplexMat ComplexMat::mul2(const ComplexMat &rhs) const
285 // assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);
287 // ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
289 // dim3 threadsPerBlock(rows, cols);
290 // dim3 numBlocks(n_channels / n_scales, n_scales);
291 // scales_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data, result.p_data);
297 // void ComplexMat::operator=(ComplexMat &&rhs)
301 // n_channels = rhs.n_channels;
302 // n_scales = rhs.n_scales;
304 // p_data = rhs.p_data;
306 // rhs.p_data = nullptr;