1 #include "complexmat.hpp"
// Per-block partial sum of squared magnitudes of `total` complex numbers.
// `in` holds interleaved (re, im) float pairs; each block writes its partial
// sum to block_res[blockIdx.x].
// Launch requirements: 1-D grid/block, dynamic shared memory of
// blockDim.x * sizeof(float) bytes.
__global__ void sqr_norm_kernel(const float *in, float *block_res, int total)
{
    extern __shared__ float sdata[];
    const unsigned i = threadIdx.x;
    const int in_idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    // Threads past the end contribute zero so the whole block can reduce.
    if (in_idx >= total * 2)
        sdata[i] = 0.0f;
    else
        sdata[i] = in[in_idx] * in[in_idx] + in[in_idx + 1] * in[in_idx + 1];

    // Tree reduction in shared memory. `n` tracks the number of live
    // partials and is halved with ceil division each pass, so the reduction
    // is correct for any block size (a plain `s >>= 1` starting from
    // (blockDim.x + 1) / 2 silently drops the last partial whenever an
    // intermediate count is odd and not 1).
    for (unsigned n = blockDim.x; n > 1;) {
        const unsigned s = (n + 1) / 2; // ceil(n / 2)
        __syncthreads();                // uniform: every thread runs each pass
        if (i + s < n)
            sdata[i] += sdata[i + s];
        n = s;
    }

    if (i == 0)
        block_res[blockIdx.x] = sdata[0];
}
// Computes, for each scale, sum(|x|^2) / (rows * cols) over that scale's
// channels and stores it into result.hostMem()[s].
// `result` must have exactly n_scales elements.
void ComplexMat_::sqr_norm(DynMem &result) const
{
    assert(result.num_elem == n_scales);

    const uint total = n_channels / n_scales * rows * cols; // complex elems per scale
    const dim3 threads(1024);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    DynMem block_res(blocks.x * n_scales); // one partial sum per block per scale

    for (uint s = 0; s < n_scales; ++s) {
        sqr_norm_kernel<<<blocks, threads, threads.x * sizeof(float)>>>(
            (const float *)(p_data.deviceMem() + s * total),
            block_res.deviceMem() + s * blocks.x, total);
        CudaSafeCall(cudaGetLastError()); // catch bad launch configuration
    }
    // block_res is read on the host below; the kernels must have finished
    // first, otherwise we sum stale/unwritten partials.
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));

    for (uint s = 0; s < n_scales; ++s) {
        T res = 0;
        for (uint i = 0; i < blocks.x; i++)
            res += block_res[s * blocks.x + i];
        result.hostMem()[s] = res / static_cast<T>(cols * rows);
    }
}
// result[i] = |data[i]|^2, stored as a complex number with zero imaginary
// part. Buffers are interleaved (re, im); `total` counts complex elements.
// Launch requirements: 1-D grid/block, one thread per complex element.
__global__ void sqr_mag_kernel(const float *data, float *result, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        result[idx] = data[idx] * data[idx] + data[idx + 1] * data[idx + 1];
        // The squared magnitude is purely real; without this store the
        // imaginary part of the result would be left uninitialized.
        result[idx + 1] = 0.0f;
    }
}
// Element-wise squared magnitude; returns a matrix of the same geometry.
ComplexMat_ ComplexMat_::sqr_mag() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // Launch configuration is <<<grid, block>>> — the grid (blocks) comes
    // first. The original had the two arguments swapped, launching `threads.x`
    // blocks of `blocks.x` threads.
    sqr_mag_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                           (float *)result.p_data.deviceMem(),
                                           total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex conjugate: copy the real part, negate the imaginary part.
// Buffers are interleaved (re, im); `total` counts complex elements.
__global__ void conj_kernel(const float *data, float *result, int total)
{
    const int pair = blockIdx.x * blockDim.x + threadIdx.x; // complex index
    if (pair >= total)
        return;

    const int re = 2 * pair;
    const int im = re + 1;
    result[re] = data[re];
    result[im] = -data[im];
}
// Element-wise complex conjugate; returns a matrix of the same geometry.
ComplexMat_ ComplexMat_::conj() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    conj_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                        (float *)result.p_data.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// dest[idx] = sum over `channels` of src[idx + c * num_channel_elem].
// `num_channel_elem` is the per-channel element count in floats
// (2 * rows * cols for interleaved complex data), and also the number of
// output elements written.
__global__ static void sum_channels(float *dest, const float *src, uint channels, uint num_channel_elem)
{
    const uint idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_channel_elem)
        return;

    float acc = 0.0f;
    for (uint c = 0; c < channels; ++c)
        acc += src[idx + c * num_channel_elem];
    dest[idx] = acc;
}
// Sums all channels within each scale, producing a 1-channel matrix per
// scale (result has n_scales channels total).
ComplexMat_ ComplexMat_::sum_over_channels() const
{
    assert(p_data.num_elem == n_channels * rows * cols);

    const uint n_channels_per_scale = n_channels / n_scales;
    const uint scale_offset = n_channels_per_scale * rows * cols; // complex elems per scale in *this

    ComplexMat_ result(this->rows, this->cols, 1, n_scales);

    const uint total = rows * cols * 2; // floats in a single channel
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint scale = 0; scale < n_scales; ++scale) {
        // The destination holds ONE channel per scale, so it advances by
        // rows * cols complex elements per scale. Advancing it by the
        // multi-channel source stride (scale * scale_offset) writes past the
        // end of `result` whenever n_channels_per_scale > 1.
        sum_channels<<<blocks, threads>>>(reinterpret_cast<float *>(result.p_data.deviceMem() + scale * rows * cols),
                                          reinterpret_cast<const float *>(p_data.deviceMem() + scale * scale_offset),
                                          n_channels_per_scale, total);
        CudaSafeCall(cudaGetLastError());
    }
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));
    return result;
}
// Element-wise complex multiply: result = data_l * data_r.
// Buffers are interleaved (re, im); `total` counts complex elements.
__global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    if (pair >= total)
        return;

    const int re = 2 * pair;
    const int im = re + 1;
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    result[re] = data_l[re] * data_r[re] - data_l[im] * data_r[im];
    result[im] = data_l[re] * data_r[im] + data_l[im] * data_r[re];
}
145 // element-wise per channel multiplication, division and addition
// Element-wise complex multiplication. `rhs` holds one scale's worth of
// channels and is reused against every scale of *this.
ComplexMat_ ComplexMat_::operator*(const ComplexMat_ &rhs) const
{
    assert(n_channels == n_scales * rhs.n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels / n_scales * rows * cols; // complex elems per scale
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint s = 0; s < n_scales; ++s) {
        same_num_channels_mul_kernel<<<blocks, threads, 0>>>((float *)(this->p_data.deviceMem() + s * total),
                                                             (float *)rhs.p_data.deviceMem(),
                                                             (float *)(result.p_data.deviceMem() + s * total),
                                                             total);
        CudaSafeCall(cudaGetLastError()); // launches report errors lazily; check each
    }

    return result;
}
// Element-wise complex division: result = data_l / data_r.
// Buffers are interleaved (re, im); `total` counts complex elements.
// No guard against |data_r| == 0 — callers must ensure a non-zero divisor.
__global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result, unsigned total)
{
    const unsigned pair = blockIdx.x * blockDim.x + threadIdx.x;
    if (pair >= total)
        return;

    const unsigned re = 2 * pair;
    const unsigned im = re + 1;
    // (a + bi)/(c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
    result[re] = (data_l[re] * data_r[re] + data_l[im] * data_r[im]) /
                 (data_r[re] * data_r[re] + data_r[im] * data_r[im]);
    result[im] = (data_l[im] * data_r[re] - data_l[re] * data_r[im]) /
                 (data_r[re] * data_r[re] + data_r[im] * data_r[im]);
}
// Element-wise complex division; both operands must have identical geometry.
ComplexMat_ ComplexMat_::operator/(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    same_num_channels_div_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                                         (float *)rhs.p_data.deviceMem(),
                                                         (float *)result.p_data.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex addition: result = data_l + data_r.
// Buffers are interleaved (re, im); `total` counts complex elements.
__global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    if (pair >= total)
        return;

    const int re = 2 * pair;
    const int im = re + 1;
    result[re] = data_l[re] + data_r[re];
    result[im] = data_l[im] + data_r[im];
}
// Element-wise complex addition; both operands must have identical geometry.
ComplexMat_ ComplexMat_::operator+(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    same_num_channels_add_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                                         (float *)rhs.p_data.deviceMem(),
                                                         (float *)result.p_data.deviceMem(),
                                                         total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Scale every complex element by a real constant: result = data_l * constant.
// Buffers are interleaved (re, im); `total` counts complex elements.
__global__ void constant_mul_kernel(const float *data_l, float constant, float *result, int total)
{
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    if (pair >= total)
        return;

    const int re = 2 * pair;
    const int im = re + 1;
    result[re] = data_l[re] * constant;
    result[im] = data_l[im] * constant;
}
// Multiply every complex element by the real scalar `rhs`.
ComplexMat_ ComplexMat_::operator*(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    constant_mul_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                                rhs,
                                                (float *)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Add a real constant to every complex element: result = data_l + constant.
// Only the real part changes; the imaginary part is copied through.
__global__ void constant_add_kernel(const float *data_l, float constant, float *result, int total)
{
    const int pair = blockIdx.x * blockDim.x + threadIdx.x;
    if (pair >= total)
        return;

    const int re = 2 * pair;
    const int im = re + 1;
    result[re] = data_l[re] + constant;
    result[im] = data_l[im];
}
// Add the real scalar `rhs` to every complex element.
ComplexMat_ ComplexMat_::operator+(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    constant_add_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                                rhs,
                                                (float *)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex multiply of a multi-channel buffer by a single-channel buffer:
// result[i] = data_l[i] * data_r[i mod channel_total]. `channel_total` is the
// complex-element count of one channel; `total` that of the whole data_l.
__global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result,
                                       int channel_total, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (idx / 2 >= total)
        return;

    // Wrap into the single-channel rhs buffer; idx is even, so the wrapped
    // index still points at a real component.
    const int rhs_idx = idx % (2 * channel_total);
    result[idx] = data_l[idx] * data_r[rhs_idx] - data_l[idx + 1] * data_r[rhs_idx + 1];
    result[idx + 1] = data_l[idx] * data_r[rhs_idx + 1] + data_l[idx + 1] * data_r[rhs_idx];
}
292 // multiplying element-wise multichannel by one channel mats (rhs mat is with one channel)
// Element-wise multiplication of a multi-channel matrix by a one-channel
// matrix (rhs is broadcast across all channels).
ComplexMat_ ComplexMat_::mul(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first. The original swapped the two, launching
    // `threads.x` blocks of `blocks.x` threads.
    one_channel_mul_kernel<<<blocks, threads, 0>>>((float *)this->p_data.deviceMem(),
                                                   (float *)rhs.p_data.deviceMem(),
                                                   (float *)result.p_data.deviceMem(),
                                                   rows * cols, total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
312 // __global__ void scales_channel_mul_kernel(float *data_l, float *data_r, float *result)
314 // int blockId = blockIdx.x + blockIdx.y * gridDim.x;
315 // int idx = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
316 // int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);
318 // result[idx] = data_l[idx] * data_r[one_ch_index] - data_l[idx + 1] * data_r[one_ch_index + 1];
319 // result[idx + 1] = data_l[idx] * data_r[one_ch_index + 1] + data_l[idx + 1] * data_r[one_ch_index];
322 // multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
323 // ComplexMat_ ComplexMat_::mul2(const ComplexMat_ &rhs) const
325 // assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);
327 // ComplexMat_ result(this->rows, this->cols, this->channels(), this->n_scales);
329 // dim3 threadsPerBlock(rows, cols);
330 // dim3 numBlocks(n_channels / n_scales, n_scales);
// scales_channel_mul_kernel<<<numBlocks, threadsPerBlock, 0>>>(this->p_data, rhs.p_data, result.p_data);
337 // void ComplexMat_::operator=(ComplexMat_ &&rhs)
341 // n_channels = rhs.n_channels;
342 // n_scales = rhs.n_scales;
344 // p_data = rhs.p_data;
346 // rhs.p_data = nullptr;
349 void ComplexMat_::cudaSync() const
351 CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));