mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-12 14:14:22 +00:00
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
ggml-ci
This commit is contained in:
@ -1510,8 +1510,14 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// a [ne0, ne01, ne02, ne03]
|
||||
// mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
|
||||
//
|
||||
// broadcast:
|
||||
// ne02 % ne12 == 0
|
||||
// ne03 % ne13 == 0
|
||||
//
|
||||
// fused soft_max(a*scale + mask*(ALiBi slope))
|
||||
// mask is optional
|
||||
// max_bias = 0.0f for no ALiBi
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
||||
struct ggml_context * ctx,
|
||||
@ -1974,11 +1980,16 @@ extern "C" {
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
|
||||
// q: [n_embd_k, n_batch, n_head, 1]
|
||||
// k: [n_embd_k, n_kv, n_head_kv, 1]
|
||||
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
|
||||
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
||||
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
|
||||
// q: [n_embd_k, n_batch, n_head, ne3]
|
||||
// k: [n_embd_k, n_kv, n_head_kv, ne3]
|
||||
// v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
|
||||
// mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
||||
// res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
|
||||
//
|
||||
// broadcast:
|
||||
// n_head % n_head_kv == 0
|
||||
// ne3 % ne32 == 0
|
||||
//
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
Reference in New Issue
Block a user