1 #include "complexmat.hpp"
// Per-block partial reduction of the squared L2 norm of a complex array.
// `in` holds `total` interleaved (re, im) float pairs; each block writes the
// sum of re^2 + im^2 over its slice to block_res[blockIdx.x].
// Launch: 1-D grid, dynamic shared memory = blockDim.x * sizeof(float);
// blockDim.x must be a power of two (the host launches 1024).
__global__ void sqr_norm_kernel(const float *in, float *block_res, int total)
{
    extern __shared__ float sdata[];
    const unsigned tid = threadIdx.x;
    const int in_idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    // Out-of-range threads contribute 0 instead of returning early, so that
    // every thread of the block reaches the __syncthreads() barriers below
    // (a divergent early return before a barrier is undefined behaviour).
    sdata[tid] = (in_idx < 2 * total)
                     ? in[in_idx] * in[in_idx] + in[in_idx + 1] * in[in_idx + 1]
                     : 0.0f;
    __syncthreads();

    // Shared-memory tree reduction; all threads hit each barrier.
    for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s)
            sdata[tid] += sdata[tid + s];
        __syncthreads();
    }

    if (tid == 0)
        block_res[blockIdx.x] = sdata[0];
}
// Computes sum(|z_i|^2) / (rows * cols) over all channels of a single-scale
// matrix and stores the scalar in result.hostMem()[0].
void ComplexMat_::sqr_norm(DynMem &result) const
{
    assert(n_scales == 1);

    const uint total = n_channels * rows * cols;
    const dim3 threads(1024);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // One partial sum per block; the final reduction happens on the host.
    DynMem block_res(blocks.x);

    sqr_norm_kernel<<<blocks, threads, threads.x * sizeof(float)>>>((const float*)p_data.deviceMem(),
                                                                    block_res.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());
    // The launch is asynchronous: wait for the kernel before reading the
    // partial sums through the host mapping.
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));

    T res = 0;
    // uint index: matches the type of blocks.x, avoids a signed/unsigned
    // comparison warning.
    for (uint i = 0; i < blocks.x; i++)
        res += block_res.hostMem()[i];
    result.hostMem()[0] = res / static_cast<T>(cols * rows);
}
// result[i] = |data[i]|^2 stored as a complex value with zero imaginary
// part. Both arrays hold `total` interleaved (re, im) float pairs.
__global__ void sqr_mag_kernel(const float *data, float *result, int total)
{
    int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        result[idx] = data[idx] * data[idx] + data[idx + 1] * data[idx + 1];
        // Explicitly zero the imaginary slot: the result buffer is freshly
        // allocated by the caller and would otherwise keep garbage there.
        result[idx + 1] = 0.0f;
    }
}
// Element-wise squared magnitude; returns a new matrix of the same size.
ComplexMat_ ComplexMat_::sqr_mag() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: the grid (blocks) comes first, then threads per
    // block — the original launch had the two arguments swapped.
    sqr_mag_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                           (float*)result.p_data.deviceMem(),
                                           total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Complex conjugate: copy the real part, negate the imaginary part.
// `data` and `result` hold `total` interleaved (re, im) float pairs.
__global__ void conj_kernel(const float *data, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data[re];
        result[im] = -data[im];
    }
}
// Element-wise complex conjugate; returns a new matrix of the same size.
ComplexMat_ ComplexMat_::conj() const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    conj_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                        (float*)result.p_data.deviceMem(),
                                        total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Sums `channels` consecutive planes of `src` element-wise into `dest`.
// Buffers are viewed as flat float arrays; `num_channel_elem` is the number
// of floats in one plane (an interleaved re/im pair counts as two).
__global__ static void sum_channels(float *dest, const float *src, uint channels, uint num_channel_elem)
{
    const uint elem = blockIdx.x * blockDim.x + threadIdx.x;

    if (elem >= num_channel_elem)
        return;

    float sum = 0.0f;
    for (uint ch = 0; ch < channels; ++ch)
        sum += src[ch * num_channel_elem + elem];

    dest[elem] = sum;
}
// Collapses the channel dimension: for each scale, produces one channel
// holding the element-wise sum over that scale's channels.
ComplexMat_ ComplexMat_::sum_over_channels() const
{
    assert(p_data.num_elem == n_channels * rows * cols);

    uint n_channels_per_scale = n_channels / n_scales;
    uint scale_offset = n_channels_per_scale * rows * cols; // input stride per scale

    ComplexMat_ result(this->rows, this->cols, 1, n_scales);

    const uint total = rows * cols * 2; // floats per channel (re + im)
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    for (uint scale = 0; scale < n_scales; ++scale) {
        // The result has a single channel per scale, so its per-scale stride
        // is rows * cols — not scale_offset (the multi-channel input stride),
        // which wrote out of bounds for n_channels_per_scale > 1.
        sum_channels<<<blocks, threads>>>(reinterpret_cast<float*>(result.p_data.deviceMem() + scale * rows * cols),
                                          reinterpret_cast<const float*>(p_data.deviceMem() + scale * scale_offset),
                                          n_channels_per_scale, total);
        CudaSafeCall(cudaGetLastError());
    }
    // Synchronize once after all per-scale launches: kernels on the same
    // stream already execute in order, so a per-iteration sync would only
    // serialize launch overhead.
    CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));

    return result;
}
// Element-wise complex product: result = data_l * data_r.
// All arrays hold `total` interleaved (re, im) float pairs.
__global__ void same_num_channels_mul_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        const float ar = data_l[idx], ai = data_l[idx + 1];
        const float br = data_r[idx], bi = data_r[idx + 1];
        // (ar + i*ai) * (br + i*bi)
        result[idx] = ar * br - ai * bi;
        result[idx + 1] = ar * bi + ai * br;
    }
}
141 // element-wise per channel multiplication, division and addition
// Element-wise complex multiplication by a matrix of identical shape.
ComplexMat_ ComplexMat_::operator*(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint num_elems = n_channels * rows * cols; // complex elements
    const dim3 block_dim(256);
    const dim3 grid_dim((num_elems + block_dim.x - 1) / block_dim.x);

    same_num_channels_mul_kernel<<<grid_dim, block_dim, 0>>>((float*)this->p_data.deviceMem(),
                                                             (float*)rhs.p_data.deviceMem(),
                                                             (float*)result.p_data.deviceMem(),
                                                             num_elems);

    return result;
}
// Element-wise complex division: result = data_l / data_r, computed as
// (a * conj(b)) / |b|^2. All arrays hold `total` interleaved (re, im) pairs.
__global__ void same_num_channels_div_kernel(const float *data_l, const float *data_r, float *result, unsigned total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);

    if (idx / 2 < total) {
        const float ar = data_l[idx], ai = data_l[idx + 1];
        const float br = data_r[idx], bi = data_r[idx + 1];
        const float denom = br * br + bi * bi; // |b|^2
        result[idx] = (ar * br + ai * bi) / denom;
        result[idx + 1] = (ai * br - ar * bi) / denom;
    }
}
// Element-wise complex division by a matrix of identical shape.
ComplexMat_ ComplexMat_::operator/(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    same_num_channels_div_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                         (float*)rhs.p_data.deviceMem(),
                                                         (float*)result.p_data.deviceMem(), total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex sum: result = data_l + data_r.
// All arrays hold `total` interleaved (re, im) float pairs.
__global__ void same_num_channels_add_kernel(const float *data_l, const float *data_r, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] + data_r[re];
        result[im] = data_l[im] + data_r[im];
    }
}
// Element-wise complex addition of a matrix of identical shape.
ComplexMat_ ComplexMat_::operator+(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == n_channels && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    same_num_channels_add_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                         (float*)rhs.p_data.deviceMem(),
                                                         (float*)result.p_data.deviceMem(),
                                                         total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Scales every complex element by a real constant:
// result = data_l * constant (both re and im parts are scaled).
__global__ void constant_mul_kernel(const float *data_l, float constant, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] * constant;
        result[im] = data_l[im] * constant;
    }
}
// Multiplies every complex element by the real scalar rhs.
ComplexMat_ ComplexMat_::operator*(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    constant_mul_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                rhs,
                                                (float*)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Adds a real constant to every complex element: only the real part
// changes, the imaginary part is copied through unchanged.
__global__ void constant_add_kernel(const float *data_l, float constant, float *result, int total)
{
    const int re = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int im = re + 1;

    if (re / 2 < total) {
        result[re] = data_l[re] + constant;
        result[im] = data_l[im];
    }
}
// Adds the real scalar rhs to every complex element (real part only).
ComplexMat_ ComplexMat_::operator+(const float &rhs) const
{
    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    constant_add_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                rhs,
                                                (float*)result.p_data.deviceMem(),
                                                total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
// Element-wise complex product of a multi-channel matrix (data_l) with a
// single-channel matrix (data_r): the one channel of data_r (channel_total
// complex elements) is reused for every channel of data_l.
__global__ void one_channel_mul_kernel(const float *data_l, const float *data_r, float *result,
                                       int channel_total, int total)
{
    const int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
    const int rhs_idx = idx % (2 * channel_total); // wrap into the single rhs channel

    if (idx / 2 < total) {
        const float ar = data_l[idx], ai = data_l[idx + 1];
        const float br = data_r[rhs_idx], bi = data_r[rhs_idx + 1];
        result[idx] = ar * br - ai * bi;
        result[idx + 1] = ar * bi + ai * br;
    }
}
// Element-wise multiplication of this multi-channel matrix by a
// single-channel matrix (rhs must have exactly one channel).
ComplexMat_ ComplexMat_::mul(const ComplexMat_ &rhs) const
{
    assert(rhs.n_channels == 1 && rhs.cols == cols && rhs.rows == rows);

    ComplexMat_ result = ComplexMat_::same_size(*this);

    const uint total = n_channels * rows * cols;
    const dim3 threads(256);
    const dim3 blocks((total + threads.x - 1) / threads.x);

    // <<<grid, block>>>: grid first, then block — the original launch had
    // the two configuration arguments swapped.
    one_channel_mul_kernel<<<blocks, threads, 0>>>((float*)this->p_data.deviceMem(),
                                                   (float*)rhs.p_data.deviceMem(),
                                                   (float*)result.p_data.deviceMem(),
                                                   rows * cols, total);
    CudaSafeCall(cudaGetLastError());

    return result;
}
306 // __global__ void scales_channel_mul_kernel(float *data_l, float *data_r, float *result)
308 // int blockId = blockIdx.x + blockIdx.y * gridDim.x;
309 // int idx = 2 * (blockId * (blockDim.x * blockDim.y) + (threadIdx.y * blockDim.x) + threadIdx.x);
310 // int one_ch_index = 2 * ((threadIdx.y * blockDim.x) + threadIdx.x + blockIdx.x * blockDim.x * blockDim.y);
312 // result[idx] = data_l[idx] * data_r[one_ch_index] - data_l[idx + 1] * data_r[one_ch_index + 1];
313 // result[idx + 1] = data_l[idx] * data_r[one_ch_index + 1] + data_l[idx + 1] * data_r[one_ch_index];
316 // multiplying element-wise multichannel by one channel mats (rhs mat is with multiple channel)
317 // ComplexMat_ ComplexMat_::mul2(const ComplexMat_ &rhs) const
319 // assert(rhs.n_channels == n_channels / n_scales && rhs.cols == cols && rhs.rows == rows);
321 // ComplexMat_ result(this->rows, this->cols, this->channels(), this->n_scales);
323 // dim3 threadsPerBlock(rows, cols);
324 // dim3 numBlocks(n_channels / n_scales, n_scales);
325 // scales_channel_mul_kernel<<<threads, blocks, 0>>>(this->p_data, rhs.p_data, result.p_data);
331 // void ComplexMat_::operator=(ComplexMat_ &&rhs)
335 // n_channels = rhs.n_channels;
336 // n_scales = rhs.n_scales;
338 // p_data = rhs.p_data;
340 // rhs.p_data = nullptr;
343 void ComplexMat_::cudaSync() const
345 CudaSafeCall(cudaStreamSynchronize(cudaStreamPerThread));