mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-28 03:55:06 -04:00
tests : add non-cont K,V FA tests
ggml-ci
This commit is contained in:
@ -4366,26 +4366,32 @@ struct test_flash_attn_ext : public test_case {
|
||||
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
|
||||
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
|
||||
|
||||
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
|
||||
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
|
||||
int64_t ne[4] = {ne0, ne1, ne2, ne3};
|
||||
int64_t ne_perm[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
ne_perm[permute[i]] = ne[i];
|
||||
}
|
||||
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
|
||||
ggml_tensor * t;
|
||||
if (is_view) {
|
||||
ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);
|
||||
t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);
|
||||
} else {
|
||||
t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
|
||||
}
|
||||
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
|
||||
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
|
||||
}
|
||||
return t;
|
||||
};
|
||||
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
|
||||
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);
|
||||
ggml_set_name(q, "q");
|
||||
|
||||
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
|
||||
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
|
||||
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
ggml_tensor * m = nullptr;
|
||||
|
Reference in New Issue
Block a user