ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366
commit c1a581a10b
parent 7b50d589a8
Author:    Radoslav Gerganov
Date:      2025-06-19 11:04:23 +03:00
Committer: Georgi Gerganov

5 changed files with 98 additions and 2 deletions
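For context, ggml_set_rows(ctx, a, b, c) returns a view of the F16 destination 'a' whose indexed rows are overwritten with the F32 rows of 'b' (converted to F16) when the graph is computed. A minimal usage sketch follows; it is illustrative rather than part of this commit, the tensor shapes and index values are made up, and it assumes the CPU backend with ggml_graph_compute_with_ctx() coming from ggml-cpu.h:

// usage sketch (illustrative, not from this commit)
#include "ggml.h"
#include "ggml-cpu.h" // assumed home of ggml_graph_compute_with_ctx()

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8, 4); // destination (F16)
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 2); // source rows (F32)
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);    // row indices (I32)

    for (int i = 0; i < 16; ++i) {
        ((float *) b->data)[i] = (float) i; // sample source data
    }
    ((int32_t *) c->data)[0] = 3; // row 0 of 'b' -> row 3 of 'a'
    ((int32_t *) c->data)[1] = 1; // row 1 of 'b' -> row 1 of 'a'

    struct ggml_tensor * out = ggml_set_rows(ctx, a, b, c); // a view of 'a'

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    ggml_free(ctx);
    return 0;
}

Because the result is a view of 'a', the op writes in place: rows of 'a' that 'c' does not reference keep their previous contents.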


@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -1375,6 +1376,12 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);


@@ -1814,6 +1814,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2167,6 +2171,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading

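Judging by the FIXME above, GET_ROWS (and now SET_ROWS) is kept at a low task count, but the kernel in the next file still splits its rows evenly across whatever thread count it receives. A standalone sketch of that ceiling-division partition (illustrative, not from this commit; nr and nth are example values):

// partition sketch (illustrative): nr rows split across nth threads,
// each thread taking a contiguous chunk of at most dr rows
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows   (example value)
    const int nth = 4;  // thread count (example value)

    const int dr = (nr + nth - 1)/nth; // rows per thread, rounded up

    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;            // first row of thread ith
        const int ir1 = MIN(ir0 + dr, nr); // one past its last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}

With nr = 10 and nth = 4 this prints the ranges [0, 3), [3, 6), [6, 9) and [9, 10).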

@@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(src0) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne1);
+
+        ggml_cpu_fp32_to_fp16(
+                (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
+                (ggml_fp16_t *) ((char *)  dst->data + i01*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_get_rows_back
+
 static void ggml_compute_forward_get_rows_back_f32_f16(

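A note on the kernel above: the flat row counter i runs over all elements of the index tensor src1 and is decomposed into coordinates (i10, i11, i12), with i10 varying fastest, before the source row is converted to F16 and stored at row i01 of the destination. A standalone sketch of that decomposition (illustrative, not from this commit; the extents are example values):

// index-decomposition sketch (illustrative): flat i -> (i10, i11, i12)
#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int64_t ne10 = 2, ne11 = 3; // example extents of the index tensor
    const int64_t nr   = 12;          // total elements: ne10*ne11*ne12 with ne12 = 2

    for (int64_t i = 0; i < nr; ++i) {
        const int64_t i12 = i/(ne11*ne10);                  // slowest-varying
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); // fastest-varying
        printf("i=%2lld -> (i10=%lld, i11=%lld, i12=%lld)\n",
               (long long) i, (long long) i10, (long long) i11, (long long) i12);
    }
    return 0;
}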

@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);


@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TRANSPOSE",
     "GET_ROWS",
     "GET_ROWS_BACK",
+    "SET_ROWS",
     "DIAG",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "transpose(x)",
     "get_rows(x)",
     "get_rows_back(x)",
+    "set_rows(x)",
     "diag(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -3395,6 +3397,28 @@ struct ggml_tensor * ggml_get_rows_back(
     return result;
 }
 
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(b->ne[2] == c->ne[1]);
+    GGML_ASSERT(c->ne[3] == 1);
+    GGML_ASSERT(a->type == GGML_TYPE_F16);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    GGML_ASSERT(c->type == GGML_TYPE_I32);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_diag
 
 struct ggml_tensor * ggml_diag(