1 #include "complexmat.hpp"
// Per-block partial sum of squared magnitudes of interleaved complex data.
//   in        : interleaved (re, im) pairs, `total` complex elements
//   block_res : one partial sum per block (length gridDim.x)
// Launch: 1-D grid/block with dynamic shared memory of
// blockDim.x * sizeof(float) bytes.
__global__ void sqr_norm_kernel(const float *in, float *block_res, int total)
{
    extern __shared__ float sdata[];
    int i = threadIdx.x;
    int in_idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    // Out-of-range threads contribute zero so the whole block can reduce.
    if (in_idx >= total * 2)
        sdata[i] = 0;
    else
        sdata[i] = in[in_idx] * in[in_idx] + in[in_idx + 1] * in[in_idx + 1];

    // Shared-memory tree reduction. The barrier sits outside the divergent
    // branch so every thread in the block reaches it, and the (n + 1) / 2
    // halving plus the `i + half < n` guard keep it correct for any
    // blockDim.x, not only powers of two.
    for (unsigned n = blockDim.x; n > 1;) {
        unsigned half = (n + 1) / 2;
        __syncthreads();
        if (i + half < n)
            sdata[i] += sdata[i + half];
        n = half;
    }

    if (i == 0)
        block_res[blockIdx.x] = sdata[0];
}
// Computes the squared L2 norm of each scale, normalized by rows * cols,
// and stores one value per scale into `result` (host-visible memory).
void ComplexMat_::sqr_norm(DynMem &result) const
{
    assert(result.num_elem == n_scales);

    const uint total = n_channels / n_scales * rows * cols;
    const dim3 threads(1024);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // One partial sum per block, per scale.
    DynMem block_res(blocks.x * n_scales);

    for (uint s = 0; s < n_scales; ++s) {
        sqr_norm_kernel<<<blocks, threads, threads.x * sizeof(float)>>>((const float*)(p_data.deviceMem() + s * total),
                                                                        block_res.deviceMem() + s * blocks.x, total);
        CudaSafeCall(cudaGetLastError());
    }

    // The kernels run asynchronously on the per-thread stream; synchronize
    // before reading the partial sums on the host.
    cudaSync();

    for (uint s = 0; s < n_scales; ++s) {
        T res = 0;
        for (uint i = 0; i < blocks.x; i++)
            res += block_res[s * blocks.x + i];
        result.hostMem()[s] = res / static_cast<T>(cols * rows);
    }
}
// Squared magnitude of each complex element: writes (re^2 + im^2) into the
// real slot and zero into the imaginary slot. One thread per complex
// element; `total` is the number of complex elements.
__global__ void sqr_mag_kernel(const float *data, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data[re] * data[re] + data[im] * data[im];
        result[im] = 0;
    }
}
// Returns a same-sized matrix holding |z|^2 + 0i for every element z.
ComplexMat_ ComplexMat_::sqr_mag() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    sqr_mag_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                           (float*)result.p_data.deviceMem(),
                                           total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex conjugate: copies the real part and negates the imaginary part.
// One thread per complex element; `total` is the number of complex elements.
__global__ void conj_kernel(const float *data, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data[re];
        result[im] = -data[im];
    }
}
// Returns the element-wise complex conjugate of *this.
ComplexMat_ ComplexMat_::conj() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    conj_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(), (float*)result.p_data.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Sums `channels` consecutive planes of `src` (each `num_channel_elem`
// floats long) into `dest`. One thread per output float.
__global__ static void sum_channels(float *dest, const float *src, uint channels, uint num_channel_elem)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx >= num_channel_elem)
        return;

    float acc = 0;
    for (uint ch = 0; ch < channels; ++ch)
        acc += src[idx + ch * num_channel_elem];

    dest[idx] = acc;
}
// Collapses the channels of each scale by summation, returning a matrix
// with one channel per scale.
ComplexMat_ ComplexMat_::sum_over_channels() const
{
    assert(p_data.num_elem == n_channels * rows * cols);

    uint n_channels_per_scale = n_channels / n_scales;

    ComplexMat_ result(this->rows, this->cols, 1, n_scales);

    // Each channel plane holds rows * cols complex values = 2x that in floats.
    const uint total = rows * cols * 2;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint scale = 0; scale < n_scales; ++scale) {
        sum_channels<<<blocks, threads>>>(reinterpret_cast<float*>(result.p_data.deviceMem() + scale * rows * cols),
                                          reinterpret_cast<const float*>(p_data.deviceMem() + scale * n_channels_per_scale * rows * cols),
                                          n_channels_per_scale, total);
        // Kernel launches fail silently without this check.
        CudaSafeCall(cudaGetLastError());
    }

    return result;
}
// Element-wise complex multiplication: result = data_l * data_r.
// One thread per complex element; `total` is the number of complex elements.
__global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        const float lre = data_l[idx], lim = data_l[idx + 1];
        const float rre = data_r[idx], rim = data_r[idx + 1];
        // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        result[idx] = lre * rre - lim * rim;
        result[idx + 1] = lre * rim + lim * rre;
    }
}
144 // element-wise per channel multiplication, division and addition
// Element-wise complex multiplication. rhs holds a single scale that is
// applied to every scale of *this (rhs.n_channels == n_channels / n_scales).
ComplexMat_ ComplexMat_::operator*(const ComplexMat_ &rhs) const
{
    assert(n_channels == n_scales * rhs.n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    // Number of complex elements in one scale.
    const uint total = n_channels / n_scales * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint s = 0; s < n_scales; ++s) {
        same_num_channels_mul_kernel<<<blocks, threads, 0>>>((float*)(this->p_data.deviceMem() + s * total),
                                                             (float*)rhs.p_data.deviceMem(),
                                                             (float*)(result.p_data.deviceMem() + s * total),
                                                             total);
        // Kernel launches fail silently without this check.
        CudaSafeCall(cudaGetLastError());
    }

    return result;
}
// Element-wise complex division: result = data_l / data_r.
// One thread per complex element; `total` is the number of complex elements.
__global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result, unsigned total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        const float lre = data_l[idx], lim = data_l[idx + 1];
        const float rre = data_r[idx], rim = data_r[idx + 1];
        // (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
        const float denom = rre * rre + rim * rim;
        result[idx] = (lre * rre + lim * rim) / denom;
        result[idx + 1] = (lim * rre - lre * rim) / denom;
    }
}
// Element-wise complex division of *this by rhs (same shape required).
ComplexMat_ ComplexMat_::operator/(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    same_num_channels_div_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                         (float*)rhs.p_data.deviceMem(),
                                                         (float*)result.p_data.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex addition: result = data_l + data_r.
// One thread per complex element; `total` is the number of complex elements.
__global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] + data_r[re];
        result[im] = data_l[im] + data_r[im];
    }
}
// Element-wise complex addition of *this and rhs (same shape required).
ComplexMat_ ComplexMat_::operator+(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    same_num_channels_add_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                         (float*)rhs.p_data.deviceMem(),
                                                         (float*)result.p_data.deviceMem(),
                                                         total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Scales every complex element by a real constant.
// One thread per complex element; `total` is the number of complex elements.
__global__ void constant_mul_kernel(const float *data_l, float constant, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] * constant;
        result[im] = data_l[im] * constant;
    }
}
// Multiplies every complex element by the real scalar rhs.
ComplexMat_ ComplexMat_::operator*(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    constant_mul_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                rhs,
                                                (float*)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Adds a real constant to every complex element (imaginary part is copied
// through unchanged). One thread per complex element.
__global__ void constant_add_kernel(const float *data_l, float constant, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] + constant;
        result[im] = data_l[im];
    }
}
// Adds the real scalar rhs to every complex element of *this.
ComplexMat_ ComplexMat_::operator+(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    constant_add_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                rhs,
                                                (float*)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex multiply of a multi-channel matrix (data_l) by a single-channel
// matrix (data_r); the rhs index wraps around every `channel_total` complex
// elements so the one channel is reused for each lhs channel.
__global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result,
                                       int channel_total, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int rhs_idx = idx % (2 * channel_total);

    if (idx / 2 < total) {
        const float lre = data_l[idx], lim = data_l[idx + 1];
        const float rre = data_r[rhs_idx], rim = data_r[rhs_idx + 1];
        // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        result[idx] = lre * rre - lim * rim;
        result[idx + 1] = lre * rim + lim * rre;
    }
}
291 // multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
// Element-wise multiplication of a multi-channel matrix by a one-channel
// matrix; rhs's single channel is applied to every channel of *this.
ComplexMat_ ComplexMat_::mul(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Grid size comes first in the execution configuration; the previous
    // <<<threads, blocks>>> order had grid and block dimensions swapped.
    one_channel_mul_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                   (float*)rhs.p_data.deviceMem(),
                                                   (float*)result.p_data.deviceMem(),
                                                   rows * cols, total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
311 // __global__ void scales_channel_mul_kernel(float *data_l, float *data_r, float *result)
313 // int blockId = blockIdx.x + blockIdx.y * gridDim.x;
314 // int idx = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
315 // int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);
317 // result[idx] = data_l[idx] * data_r[one_ch_index] - data_l[idx + 1] * data_r[one_ch_index + 1];
318 // result[idx + 1] = data_l[idx] * data_r[one_ch_index + 1] + data_l[idx + 1] * data_r[one_ch_index];
321 // multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
322 // ComplexMat_ ComplexMat_::mul2(const ComplexMat_ &rhs) const
324 // assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);
326 // ComplexMat_ result(this->rows, this->cols, this->channels(), this->n_scales);
328 // dim3 threadsPerBlock(rows, cols);
329 // dim3 numBlocks(n_channels / n_scales, n_scales);
330 // scales_channel_mul_kernel<<<threads, blocks, 0>>>(this->p_data, rhs.p_data, result.p_data);
336 // void ComplexMat_::operator=(ComplexMat_ &&rhs)
340 // n_channels = rhs.n_channels;
341 // n_scales = rhs.n_scales;
343 // p_data = rhs.p_data;
345 // rhs.p_data = nullptr;
// Blocks the calling host thread until all work previously enqueued on the
// per-thread default stream has completed.
void ComplexMat_::cudaSync() const
{
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
}