ggml : ggml_set_rows support broadcast

This commit is contained in:
Georgi Gerganov
2025-06-22 10:28:07 +03:00
parent 313a444b22
commit df71c803b4
3 changed files with 39 additions and 15 deletions

View File

@ -1379,6 +1379,15 @@ extern "C" {
struct ggml_tensor * b, // row indices struct ggml_tensor * b, // row indices
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
// a TD [n_embd, ne1, ne2, ne3]
// b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3
// c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
//
// broadcast:
// ne2 % ne11 == 0
// ne3 % ne12 == 0
//
// return view(a)
GGML_API struct ggml_tensor * ggml_set_rows( GGML_API struct ggml_tensor * ggml_set_rows(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, // destination struct ggml_tensor * a, // destination

View File

@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00; const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1); const int64_t nr = ne01;
assert(ne0 == nc); assert(ne0 == nc);
assert(ne02 == ne11); assert(ne2 == ne02);
assert(nb00 == sizeof(float)); assert(ne3 == ne03);
assert(ggml_nrows(src0) == nr); assert(src0->type == GGML_TYPE_F32);
assert(ne02 % ne11 == 0);
assert(ne03 % ne12 == 0);
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
const int ir0 = dr*ith; const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
for (int64_t i = ir0; i < ir1; ++i) { for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i/(ne11*ne10); const int64_t i12 = i03%ne12;
const int64_t i11 = (i - i12*ne11*ne10)/ne10; const int64_t i11 = i02%ne11;
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i10 = i;
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
GGML_ASSERT(i01 >= 0 && i01 < ne1); GGML_ASSERT(i01 >= 0 && i01 < ne1);
ggml_cpu_fp32_to_fp16( ggml_cpu_fp32_to_fp16(
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03), (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc); (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc);
}
}
} }
} }

View File

@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
struct ggml_tensor * c) { struct ggml_tensor * c) {
GGML_ASSERT(b->ne[2] == c->ne[1]); GGML_ASSERT(a->ne[0] == b->ne[0]);
GGML_ASSERT(a->ne[2] == b->ne[2]);
GGML_ASSERT(a->ne[3] == b->ne[3]);
GGML_ASSERT(b->ne[1] == c->ne[0]);
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
GGML_ASSERT(c->ne[3] == 1); GGML_ASSERT(c->ne[3] == 1);
GGML_ASSERT(a->type == GGML_TYPE_F16); GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(c->type == GGML_TYPE_I64); GGML_ASSERT(c->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_is_contiguous_rows(a));
GGML_ASSERT(ggml_is_contiguous_rows(b));
struct ggml_tensor * result = ggml_view_tensor(ctx, a); struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_SET_ROWS; result->op = GGML_OP_SET_ROWS;