ggml : add ggml_set_rows (#14274)

* ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c), which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366
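
A minimal usage sketch for illustration (shapes are invented for this
example and 'ctx' is assumed to be an existing ggml_context; the call
signature matches the tests added in this commit):

    // 'dst' has 8 rows of 16 floats; we overwrite two of them
    struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8);
    // 'src' holds the 2 replacement rows (F32)
    struct ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 2);
    // 'idxs' holds the 2 destination row indices (I64)
    struct ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);
    // the result has the shape and type of 'dst'
    struct ggml_tensor * out  = ggml_set_rows(ctx, dst, src, idxs);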

* use I64 for indices

* ggml : add repeat impl for i64

* ggml : add ggml_is_contiguous_rows

* ggml : ggml_set_rows support broadcast (see the sketch after this change list)

* ggml : ggml_set_rows support quantized dst

ggml-ci

* ggml : support GGML_TYPE_F32 ".from_float" trait

* ggml : ggml_set_rows update comment + better index name

* tests : add ggml_set_rows

* metal : add ggml_set_rows implementation

ggml-ci

* ggml : simplify forward_dup_f32

* ggml : fix supports_op

* tests : add comment to set_rows

* ggml : leave the repeat_i64 for a separate PR

ggml-ci

* ggml : set_rows use std::min instead of MIN

* ggml : better error message for set_rows unsupported type

* metal : perform op->type check only once

* tests : more consistent implementation + more tests

ggml-ci
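
To make the broadcast and quantized-dst bullets above concrete, a hedged
sketch (the shapes and the Q8_0 type are only examples; the layout mirrors
the test_set_rows cases added below):

    // quantized destination: ne[0] must be a multiple of the type's block size
    struct ggml_tensor * dst  = ggml_new_tensor_4d(ctx, GGML_TYPE_Q8_0, 256, 16, 4, 1);
    // source rows stay F32 and are quantized into 'dst' when written
    struct ggml_tensor * src  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,  256,  3, 4, 1);
    // the index tensor may have smaller dims 2/3 and is broadcast across them (2 -> 4 here)
    struct ggml_tensor * idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, 3, 2, 1);
    struct ggml_tensor * out  = ggml_set_rows(ctx, dst, src, idxs);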

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Radoslav Gerganov
Date: 2025-06-27 16:41:40 +03:00
Committed by: GitHub
Commit: 8d94219a4a (parent f667f1e624)

12 changed files with 653 additions and 204 deletions


@@ -1213,6 +1213,76 @@ struct test_get_rows_back : public test_case {
    }
};

// GGML_OP_SET_ROWS
struct test_set_rows : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const std::array<int, 2> nr23; // broadcast only dims 2 and 3
    const int r; // rows to set
    const bool v; // view (non-contiguous src1)

    std::string vars() override {
        return VARS_TO_STR5(type, ne, nr23, r, v);
    }

    test_set_rows(ggml_type type,
            std::array<int64_t, 4> ne,
            std::array<int, 2> nr23,
            int r, bool v = false)
        : type(type), ne(ne), nr23(nr23), r(r), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
        ggml_set_name(dst, "dst");

        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r, ne[2]*nr23[0], ne[3]*nr23[1]);
        ggml_set_name(src, "src");

        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
        ggml_set_name(row_idxs, "row_idxs");

        if (v) {
            src = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
            ggml_set_name(row_idxs, "view_of_rows");
        }

        ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
        ggml_set_name(out, "out");

        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        std::random_device rd;
        std::default_random_engine rng(rd());

        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I64) {
                if (ggml_is_view_op(t->op)) {
                    continue;
                }

                for (int i2 = 0; i2 < t->ne[2]; i2++) {
                    for (int i1 = 0; i1 < t->ne[1]; i1++) {
                        // generate a shuffled subset of row indices
                        std::vector<int64_t> data(ne[1]);
                        for (int i = 0; i < ne[1]; i++) {
                            data[i] = i;
                        }
                        std::shuffle(data.begin(), data.end(), rng);
                        data.resize(t->ne[0]);

                        const size_t offs = i1*t->nb[1] + i2*t->nb[2];
                        ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
                    }
                }
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};

// GGML_OP_ARGMAX
struct test_argmax : public test_case {
    const ggml_type type;
@@ -3984,6 +4054,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
    }

    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
    for (ggml_type type : all_types) {
        for (int b : {1, 7}) {
            for (bool v : {false, true}) {
                test_cases.emplace_back(new test_set_rows(type, { 256, 5, b, 3 }, { 1, 1, }, 1, v));
                test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
                test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));

                if (ggml_blck_size(type) == 1) {
                    test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
                    test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
                }
            }
        }
    }

    for (ggml_type type_input : {GGML_TYPE_F32}) {
        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
            for (int k0 : {1, 3}) {