diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 5bb85b480..73b981334 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -1,65 +1,75 @@
 #include "im2col.cuh"
 
+#define MIN(a, b) (a) < (b) ? (a) : (b)
+
+#define MAX_GRIDDIM_Z 65535
+
 template <typename T>
 static __global__ void im2col_kernel(
-        const float * x, T * dst, int64_t batch_offset,
-        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        const float * x, T * dst,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
-    if (i >= pelements) {
+    if (i >= IC_KH_KW) {
         return;
     }
 
-    const int64_t ksize = OW * KH;
-    const int64_t kx = i / ksize;
-    const int64_t kd = kx * ksize;
-    const int64_t ky = (i - kd) / OW;
-    const int64_t ix = i % OW;
+    const int64_t iic = i / (KH_KW);
+    const int64_t rem = i - iic * KH_KW;
+    const int64_t ikh = rem / KW;
+    const int64_t ikw = rem - ikh * KW;
 
-    const int64_t oh = blockIdx.y;
-    const int64_t batch = blockIdx.z / IC;
-    const int64_t ic = blockIdx.z % IC;
+    const int64_t iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
+        const int64_t in = iz / OH;
+        const int64_t ioh = iz - in * OH;
 
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = oh * s1 + ky * d1 - p1;
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;
 
-    const int64_t offset_dst =
-        ((batch * OH + oh) * OW + ix) * CHW +
-        (ic * (KW * KH) + ky * KW + kx);
+        const int64_t offset_dst =
+            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = 0.0f;
-    } else {
-        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
+            dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        }
     }
 }
 
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 template <typename T>
 static void im2col_cuda(const float * x, T* dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, batch * IC);
-    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    const int64_t N_OH = N * OH;
+    const int64_t KH_KW = KW*KH;
+    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
+                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
+                                                                     s0, s1, p0, p1, d0, d1);
 }
 
 static void im2col_cuda_f16(const float * x, half * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 static void im2col_cuda_f32(const float * x, float * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -91,13 +101,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW = dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[is_2D ? 3 : 2];
-    const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t N = src1->ne[is_2D ? 3 : 2];
+    const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
-        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
    }
 }
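Note on the new decomposition: each thread of the rewritten kernel owns one (iic, ikh, ikw) element of a destination row, blockIdx.y selects the output column iow, and the z-loop walks the flattened (N, OH) range in strides of MAX_GRIDDIM_Z so that N*OH may exceed the 65535 grid-z limit. The sketch below is a hypothetical, standalone CPU reference (im2col_cpu_ref is not part of ggml) that reproduces the same offset_dst/offset_src arithmetic for a contiguous float32 input, assuming the shapes named in the comment "im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]"; it is intended only for sanity-checking the index math, not as a drop-in implementation.

// Hypothetical CPU reference for the index math used by the patched kernel.
// Assumes a contiguous float32 input x of shape [N, IC, IH, IW] and a
// destination dst of size N*OH*OW*IC*KH*KW laid out as [N, OH, OW, IC*KH*KW].
#include <cstdint>
#include <vector>

static void im2col_cpu_ref(const std::vector<float> & x, std::vector<float> & dst,
                           int64_t N, int64_t IC, int64_t IH, int64_t IW,
                           int64_t OH, int64_t OW, int64_t KH, int64_t KW,
                           int s0, int s1, int p0, int p1, int d0, int d1) {
    const int64_t KH_KW    = KH * KW;
    const int64_t IC_KH_KW = IC * KH_KW;

    for (int64_t in = 0; in < N; ++in) {
        for (int64_t ioh = 0; ioh < OH; ++ioh) {
            for (int64_t iow = 0; iow < OW; ++iow) {
                for (int64_t iic = 0; iic < IC; ++iic) {
                    for (int64_t ikh = 0; ikh < KH; ++ikh) {
                        for (int64_t ikw = 0; ikw < KW; ++ikw) {
                            // same index expressions as the kernel
                            const int64_t iiw = iow * s0 + ikw * d0 - p0;
                            const int64_t iih = ioh * s1 + ikh * d1 - p1;

                            const int64_t offset_dst =
                                ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;

                            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                dst[offset_dst] = 0.0f; // padding region
                            } else {
                                // contiguous input: per-channel stride IH*IW, per-batch stride IC*IH*IW
                                const int64_t offset_src = (in * IC + iic) * IH * IW;
                                dst[offset_dst] = x[offset_src + iih * IW + iiw];
                            }
                        }
                    }
                }
            }
        }
    }
}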