CUDA: add set rows for f32 and f16 (#14551)

* CUDA: add set rows for f32 and f16

* Review: change kernel params, use strides from host

* Use 1-d kernel

* Review: use int64_t for blockDim.x, rename nb->s for clarity
Author: Aman Gupta
Date: 2025-07-12 21:31:38 +08:00
Committed by: GitHub
Parent: 8eff95544e
Commit: 7de5c7cab6
3 changed files with 147 additions and 0 deletions
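For orientation, the new GGML_OP_SET_ROWS path copies each F32 source row into the destination row selected by an I64 index tensor, converting to F16 when the destination is F16; the index tensor is broadcast over the two outer dimensions. A minimal host-side sketch of that semantics (plain C++; the contiguous layout and the helper name are illustrative assumptions, not part of this commit):

#include <cstdint>

// Illustrative CPU reference for the SET_ROWS semantics (contiguous tensors only).
// src0: F32 rows of shape [ne00, ne01, ne02, ne03]
// src1: I64 row indices of shape [ne10 == ne01, ne11, ne12], broadcast over
//       src0's dims 2/3 via modulo
// dst : shape [ne00, dst_ne1, ne02, ne03]; dst_t stands in for float or a half-like type
template <typename dst_t>
static void set_rows_reference(const float * src0, const int64_t * src1, dst_t * dst,
                               int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                               int64_t ne11, int64_t ne12, int64_t dst_ne1) {
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                const int64_t i11 = i02 % ne11;                        // broadcast over dim 2
                const int64_t i12 = i03 % ne12;                        // broadcast over dim 3
                const int64_t row = src1[(i12*ne11 + i11)*ne01 + i01]; // destination row index
                const float * src_row = src0 + ((i03*ne02 + i02)*ne01    + i01)*ne00;
                dst_t       * dst_row = dst  + ((i03*ne02 + i02)*dst_ne1 + row)*ne00;
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    dst_row[i00] = static_cast<dst_t>(src_row[i00]);   // F32 -> F32/F16
                }
            }
        }
    }
}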

ggml-cuda.cu

@@ -43,6 +43,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml.h"
#include <algorithm>
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_GET_ROWS_BACK:
            ggml_cuda_op_get_rows_back(ctx, dst);
            break;
        case GGML_OP_SET_ROWS:
            ggml_cuda_op_set_rows(ctx, dst);
            break;
        case GGML_OP_DUP:
            ggml_cuda_dup(ctx, dst);
            break;
@@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
            {
                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
            } break;
        case GGML_OP_SET_ROWS:
            {
                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
                       op->src[0]->type == GGML_TYPE_F32 &&
                       op->src[1]->type == GGML_TYPE_I64;
            } break;
        case GGML_OP_CPY:
            {
                ggml_type src0_type = op->src[0]->type;

ggml-cuda/set-rows.cu

@@ -0,0 +1,130 @@
#include "set-rows.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

template<typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}

template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
    *dst_h = __float2half(*src_f);
}

template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
    *dst_f = *src_f;
}
template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    // flat 1-D index over all elements of src0
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;

    if (i >= ne_total) {
        return;
    }

    // decompose the flat index into 4-D coordinates (i00 fastest)
    const int64_t i03 = i / (ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    // src1 (the I64 row indices) is broadcast over dims 2 and 3 of src0
    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row    = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t       * dst_row_ptr = dst  + dst_row*s1 + i02*s2 + i03*s3;

    const src_t * src_elem = src0_row + i00;
    dst_t       * dst_elem = dst_row_ptr + i00;
    set_rows_1(src_elem, dst_elem);
}
template<typename src_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total   = ne00 * ne01 * ne02 * ne03;
    const int     num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3 grid_size(num_blocks);

    // convert the byte strides (nb*) coming from the host tensors into element strides (s*)
    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);

    const int64_t s10 = nb10/sizeof(int64_t);
    const int64_t s11 = nb11/sizeof(int64_t);
    const int64_t s12 = nb12/sizeof(int64_t);

    const int64_t s1 = nb1/sizeof(dst_t);
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);

    if (ne_total > 0) {
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
                src0_d, src1_d, dst_d,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                s01, s02, s03,
                s10, s11, s12,
                s1, s2, s3);
    }
}
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float   * src0_d = (const float   *)src0->data;
    const int64_t * src1_d = (const int64_t *)src1->data;

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type");
    }
}
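Per the "Use 1-d kernel" note above, k_set_rows flattens the whole src0 tensor into a single 1-D grid and recovers the 4-D coordinates by division and modulo. A small host-side check of that decomposition (standalone C++, illustrative only, not part of the commit):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Same index decomposition as k_set_rows, checked against nested loops.
    const int64_t ne00 = 4, ne01 = 3, ne02 = 2, ne03 = 2;
    int64_t i = 0;
    for (int64_t i03 = 0; i03 < ne03; ++i03)
    for (int64_t i02 = 0; i02 < ne02; ++i02)
    for (int64_t i01 = 0; i01 < ne01; ++i01)
    for (int64_t i00 = 0; i00 < ne00; ++i00, ++i) {
        const int64_t j03 = i / (ne00*ne01*ne02);
        const int64_t j02 = (i - j03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t j01 = (i - j03*ne00*ne01*ne02 - j02*ne00*ne01) / ne00;
        const int64_t j00 = i - j03*ne00*ne01*ne02 - j02*ne00*ne01 - j01*ne00;
        assert(j00 == i00 && j01 == i01 && j02 == i02 && j03 == i03);
    }
    printf("decomposition matches for %lld elements\n", (long long) i);
    return 0;
}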

ggml-cuda/set-rows.cuh

@@ -0,0 +1,7 @@
#pragma once

#include "common.cuh"

#define CUDA_SET_ROWS_BLOCK_SIZE 256

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
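As a quick illustration of the nb->s renaming above: the wrapper receives byte strides (nb*) from the tensors and set_rows_cuda divides them by the element size to get the element strides (s*) the kernel indexes with. A sketch with made-up contiguous F16 shapes (standalone C++, values are illustrative, not from the commit):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical contiguous F16 destination of shape (ne0, ne1, ne2, ne3).
    const int64_t ne0 = 128, ne1 = 64, ne2 = 4, ne3 = 2;
    const size_t  elt = 2;                      // sizeof(half)

    const size_t nb1 = ne0 * elt;               // byte stride between rows
    const size_t nb2 = nb1 * ne1;               // byte stride between dim-2 slices
    const size_t nb3 = nb2 * ne2;               // byte stride between dim-3 slices

    // element strides as passed to k_set_rows
    const int64_t s1 = nb1 / elt;               // = ne0
    const int64_t s2 = nb2 / elt;               // = ne0*ne1
    const int64_t s3 = nb3 / elt;               // = ne0*ne1*ne2

    // 1-D launch size over a same-shaped F32 source
    const int64_t ne_total   = ne0 * ne1 * ne2 * ne3;
    const int     num_blocks = (ne_total + 256 - 1) / 256;   // CUDA_SET_ROWS_BLOCK_SIZE

    printf("s1=%lld s2=%lld s3=%lld ne_total=%lld blocks=%d\n",
           (long long) s1, (long long) s2, (long long) s3, (long long) ne_total, num_blocks);
    return 0;
}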