diff --git a/src/cuda_wrapper/kernels.cu b/src/cuda_wrapper/kernels.cu index bfa044183..967af5106 100644 --- a/src/cuda_wrapper/kernels.cu +++ b/src/cuda_wrapper/kernels.cu @@ -72,25 +72,10 @@ /** * modified @ref vc_copylineRG48toR12L - * @todo fix the last block for widths not divisible by 8 */ -__global__ void -kernel_rg48_to_r12l(uint8_t *in, uint8_t *out, unsigned size_x) +__device__ static void +rt48_to_r12l_compute_blk(const uint8_t *src, uint8_t *dst) { - unsigned position_x = threadIdx.x + blockIdx.x * blockDim.x; - unsigned position_y = threadIdx.y + blockIdx.y * blockDim.y; - if (position_x >= (size_x + 7) / 8) { - return; - } - // drop last block if not complete (prevent overriding start of - // following line with junk and also possibly OOB access) - if (position_x > size_x / 8) { - return; - } - uint8_t *src = in + 2 * (position_y * 3 * size_x + position_x * 3 * 8); - uint8_t *dst = - out + (position_y * ((size_x + 7) / 8) + position_x) * 36; - // 0 dst[0] = src[0] >> 4; dst[0] |= src[1] << 4; @@ -208,6 +193,40 @@ kernel_rg48_to_r12l(uint8_t *in, uint8_t *out, unsigned size_x) src += 2; } +__device__ static void +rt48_to_r12l_compute_last_blk(uint8_t *src, uint8_t *dst, unsigned width) +{ + uint8_t tmp[48]; + for (unsigned i = 0; i < width * 6; ++i) { + tmp[i] = src[i]; + } + rt48_to_r12l_compute_blk(tmp, dst); +} + +/** + * @todo fix the last block for widths not divisible by 8 + */ +__global__ static void +kernel_rg48_to_r12l(uint8_t *in, uint8_t *out, unsigned size_x) +{ + unsigned position_x = threadIdx.x + blockIdx.x * blockDim.x; + unsigned position_y = threadIdx.y + blockIdx.y * blockDim.y; + if (position_x >= (size_x + 7) / 8) { + return; + } + uint8_t *src = in + 2 * (position_y * 3 * size_x + position_x * 3 * 8); + uint8_t *dst = + out + (position_y * ((size_x + 7) / 8) + position_x) * 36; + + // handle incomplete blocks + if (position_x == size_x / 8) { + rt48_to_r12l_compute_last_blk(src, dst, + size_x - position_x * 8); + return; + } + rt48_to_r12l_compute_blk(src, dst); +} + /** * @sa cmpto_j2k_dec_postprocessor_run_callback_cuda */