1 #include "complexmat.hpp"
// Partial sum of squared magnitudes of `total` complex elements stored as
// interleaved (re, im) float pairs in `in`; one partial sum per block is
// written to block_res[blockIdx.x], the host adds the per-block results.
// Launch: 1-D grid, dynamic shared memory of blockDim.x * sizeof(float);
// the reduction assumes blockDim.x is a power of two (host launches 1024).
__global__ void sqr_norm_kernel(const float *in, float *block_res, int total)
{
    extern __shared__ float sdata[];
    int i = threadIdx.x;
    int in_idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    // Zero-fill the slots of out-of-range threads instead of returning early:
    // an early return would leave sdata[i] uninitialized (garbage folded into
    // the sum) and would skip the block-wide barriers below from a divergent
    // path, which is undefined behavior.
    if (in_idx >= total * 2)
        sdata[i] = 0.0f;
    else
        sdata[i] = in[in_idx] * in[in_idx] + in[in_idx + 1] * in[in_idx + 1];

    __syncthreads();

    // Tree reduction in shared memory. The i + s bound guards the read of the
    // upper half when blockDim.x is odd at the first step.
    for (unsigned s = (blockDim.x + 1) / 2; s > 0; s >>= 1) {
        if (i < s && i + s < blockDim.x)
            sdata[i] += sdata[i + s];
        __syncthreads();
    }

    if (i == 0)
        block_res[blockIdx.x] = sdata[0];
}
// Sum of squared magnitudes over all channels and elements, normalized by
// (rows * cols); the scalar result is stored in result.hostMem()[0].
// Single-scale matrices only (asserted).
void ComplexMat_::sqr_norm(DynMem &result) const
{
    assert(n_scales == 1);

    const uint total = n_channels * rows * cols;
    const dim3 threads(1024);
    const dim3 blocks((total + threads.x - 1) / threads.x); // ceil-div

    // One partial sum per block, reduced on the host below.
    DynMem block_res(blocks.x);

    sqr_norm_kernel<<<blocks, threads, threads.x * sizeof(float)>>>((const float*)p_data.deviceMem(),
                                                                    block_res.deviceMem(), total);
    CudaSafeCall(cudaGetLastError()); // catch bad launch configuration
    // Wait for the kernel so block_res is safe to read from the host.
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));

    T res = 0;
    for (uint i = 0; i < blocks.x; i++) // uint: matches blocks.x, no sign mismatch
        res += block_res.hostMem()[i];
    result.hostMem()[0] = res / static_cast<T>(cols * rows);
}
// Per-element squared magnitude: result = re^2 + im^2, imaginary part zeroed.
// Data is interleaved (re, im) float pairs; one thread per complex element,
// flattened from a 2-D block within a 2-D grid.
__global__ void sqr_mag_kernel(const float *data, float *result)
{
    const int block = blockIdx.y * gridDim.x + blockIdx.x;
    const int elem = block * (blockDim.x * blockDim.y) + threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * elem;
    const int im = re + 1;

    result[re] = data[re] * data[re] + data[im] * data[im];
    result[im] = 0;
}
// Element-wise squared magnitude; returns a mat of identical geometry whose
// imaginary parts are zero.
ComplexMat ComplexMat::sqr_mag() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One thread per complex element, one block per channel.
    // NOTE(review): this geometry requires rows * cols <= 1024 (the per-block
    // thread limit) — confirm all callers satisfy it.
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    sqr_mag_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(), (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError()); // surface launch failures (e.g. oversized block)

    return result;
}
// Complex conjugate: copy the real part, negate the imaginary part.
// Data is interleaved (re, im) float pairs; one thread per complex element,
// flattened from a 2-D block within a 2-D grid.
__global__ void conj_kernel(const float *data, float *result)
{
    const int block = blockIdx.y * gridDim.x + blockIdx.x;
    const int elem = block * (blockDim.x * blockDim.y) + threadIdx.y * blockDim.x + threadIdx.x;
    const int re = 2 * elem;
    const int im = re + 1;

    result[re] = data[re];
    result[im] = -data[im];
}
// Element-wise complex conjugate; returns a new mat of identical geometry.
ComplexMat ComplexMat::conj() const
{
    ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);

    // One thread per complex element, one block per channel.
    // NOTE(review): requires rows * cols <= 1024 (per-block thread limit).
    dim3 threadsPerBlock(rows, cols);
    dim3 numBlocks(n_channels / n_scales, n_scales);
    conj_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(), (float*)result.p_data.deviceMem());
    CudaSafeCall(cudaGetLastError()); // surface launch failures

    return result;
}
88 __global__ static void sum_channels(float *dest, const float *src, uint channels, uint num_channel_elem)
90 int idx = blockIdx.x * blockDim.x + threadIdx.x;
92 if (idx >= num_channel_elem)
96 for (uint i = 0; i < channels; ++i)
97 acc += src[idx + i * num_channel_elem];
101 ComplexMat ComplexMat::sum_over_channels() const
103 assert(p_data.num_elem == n_channels * rows * cols);
105 uint n_channels_per_scale = n_channels / n_scales;
106 uint scale_offset = n_channels_per_scale * rows * cols;
108 ComplexMat_ result(this->rows, this->cols, 1, n_scales);
110 const uint total = rows * cols * 2;
111 const dim3 threads(256);
112 const dim3 blocks((total + threads.x - 1) / threads.x);
114 for (uint scale = 0; scale < n_scales; ++scale) {
115 sum_channels<<<blocks, threads>>>(reinterpret_cast<float*>(result.p_data.deviceMem() + scale * scale_offset),
116 reinterpret_cast<const float*>(p_data.deviceMem() + scale * scale_offset),
117 n_channels_per_scale, total);
119 CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
123 __global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result)
125 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
126 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
128 result[threadId] = data_l[threadId] * data_r[threadId] - data_l[threadId + 1] * data_r[threadId + 1];
129 result[threadId + 1] = data_l[threadId] * data_r[threadId + 1] + data_l[threadId + 1] * data_r[threadId];
132 // element-wise per channel multiplication, division and addition
133 ComplexMat ComplexMat::operator*(const ComplexMat &rhs) const
135 assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);
137 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
139 dim3 threadsPerBlock(rows, cols);
140 dim3 numBlocks(n_channels / n_scales, n_scales);
141 same_num_channels_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
142 (float*)rhs.p_data.deviceMem(),
143 (float*)result.p_data.deviceMem());
149 __global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result)
151 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
152 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
154 result[threadId] = (data_l[threadId] * data_r[threadId] + data_l[threadId + 1] * data_r[threadId + 1]) /
155 (data_r[threadId] * data_r[threadId] + data_r[threadId + 1] * data_r[threadId + 1]);
156 result[threadId + 1] = (data_l[threadId + 1] * data_r[threadId] - data_l[threadId] * data_r[threadId + 1]) /
157 (data_r[threadId] * data_r[threadId] + data_r[threadId + 1] * data_r[threadId + 1]);
160 ComplexMat ComplexMat::operator/(const ComplexMat &rhs) const
162 assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);
164 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
166 dim3 threadsPerBlock(rows, cols);
167 dim3 numBlocks(n_channels / n_scales, n_scales);
168 same_num_channels_div_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
169 (float*)rhs.p_data.deviceMem(),
170 (float*)result.p_data.deviceMem());
176 __global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result)
178 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
179 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
181 result[threadId] = data_l[threadId] + data_r[threadId];
182 result[threadId + 1] = data_l[threadId + 1] + data_r[threadId + 1];
185 ComplexMat ComplexMat::operator+(const ComplexMat &rhs) const
187 assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);
189 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
191 dim3 threadsPerBlock(rows, cols);
192 dim3 numBlocks(n_channels / n_scales, n_scales);
193 same_num_channels_add_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
194 (float*)rhs.p_data.deviceMem(),
195 (float*)result.p_data.deviceMem());
201 __global__ void constant_mul_kernel(const float *data_l, float constant, float *result)
203 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
204 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
206 result[threadId] = data_l[threadId] * constant;
207 result[threadId + 1] = data_l[threadId + 1] * constant;
210 ComplexMat ComplexMat::operator*(const float &rhs) const
212 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
214 dim3 threadsPerBlock(rows, cols);
215 dim3 numBlocks(n_channels / n_scales, n_scales);
216 constant_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
218 (float*)result.p_data.deviceMem());
224 __global__ void constant_add_kernel(const float *data_l, float constant, float *result)
226 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
227 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
229 result[threadId] = data_l[threadId] + constant;
230 result[threadId + 1] = data_l[threadId + 1];
233 ComplexMat ComplexMat::operator+(const float &rhs) const
235 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
237 dim3 threadsPerBlock(rows, cols);
238 dim3 numBlocks(n_channels / n_scales, n_scales);
239 constant_add_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
241 (float*)result.p_data.deviceMem());
247 __global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result)
249 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
250 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
251 int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x);
253 result[threadId] = data_l[threadId] * data_r[one_ch_index] - data_l[threadId + 1] * data_r[one_ch_index + 1];
254 result[threadId + 1] = data_l[threadId] * data_r[one_ch_index + 1] + data_l[threadId + 1] * data_r[one_ch_index];
257 // multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
258 ComplexMat ComplexMat::mul(const ComplexMat &rhs) const
260 assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);
262 ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
264 dim3 threadsPerBlock(rows, cols);
265 dim3 numBlocks(n_channels / n_scales, n_scales);
266 one_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>((float*)this->p_data.deviceMem(),
267 (float*)rhs.p_data.deviceMem(),
268 (float*)result.p_data.deviceMem());
274 __global__ void scales_channel_mul_kernel(float *data_l, float *data_r, float *result)
276 int blockId = blockIdx.x + blockIdx.y * gridDim.x;
277 int threadId = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
278 int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);
280 result[threadId] = data_l[threadId] * data_r[one_ch_index] - data_l[threadId + 1] * data_r[one_ch_index + 1];
281 result[threadId + 1] = data_l[threadId] * data_r[one_ch_index + 1] + data_l[threadId + 1] * data_r[one_ch_index];
284 // multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
285 // ComplexMat ComplexMat::mul2(const ComplexMat &rhs) const
287 // assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);
289 // ComplexMat result(this->rows, this->cols, this->channels(), this->n_scales);
291 // dim3 threadsPerBlock(rows, cols);
292 // dim3 numBlocks(n_channels / n_scales, n_scales);
293 // scales_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data, result.p_data);
299 // void ComplexMat::operator=(ComplexMat &&rhs)
303 // n_channels = rhs.n_channels;
304 // n_scales = rhs.n_scales;
306 // p_data = rhs.p_data;
308 // rhs.p_data = nullptr;