kv-cache : add SWA support (#13194)

* kv-cache : prepare for SWA

ggml-ci

* kv-cache : initial iSWA implementation

ggml-ci

* kv-cache : rework error recovery logic

ggml-ci

* models : fix Phi-3 SWA parameters

ggml-ci

* model : adjust Granite to rope factor changes

ggml-ci

* server : check if context can do shifts

ggml-ci

* iswa : for now, always enable shifts (experiment)

ggml-ci

* kv-cache : simplify SWA logic

ggml-ci

* kv-cache : apply defrag when we fail to find slots for the batch

ggml-ci

* llama : update docs about llama_decode

ggml-ci

* kv-cache : update warning logs when no space for the batch is available

ggml-ci

* llama : add llama_kv_self_seq_pos_min()

* kv-cache : keep track of partial SWA computes and print warnings

* server : disallow use cases involving partial SWA context

ggml-ci

* llama : add param to control SWA cache size

ggml-ci

* minor : clean-up

ggml-ci
Georgi Gerganov
2025-05-20 08:05:46 +03:00
committed by GitHub
parent f0adb80bf7
commit e298d2fbd0
15 changed files with 1426 additions and 650 deletions


@@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_keep = value;
         }
     ));
+    add_opt(common_arg(
+        {"--swa-full"},
+        string_format("use full-size SWA cache (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
+        [](common_params & params) {
+            params.swa_full = true;
+        }
+    ));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),


@@ -1136,6 +1136,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.flash_attn        = params.flash_attn;
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
+    cparams.swa_full          = params.swa_full;
 
     if (params.reranking) {
         cparams.embeddings = true;


@@ -323,6 +323,7 @@ struct common_params {
     bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
     bool ctx_shift         = true;  // context shift on inifinite text generation
+    bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap          = true;  // use mmap for faster loads
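For API users that do not go through the common argument parser, the same switch is exposed on llama_context_params (see the llama.h hunk below). A minimal sketch of enabling it programmatically; the model path, context size and the loader entry points used here are illustrative assumptions, not part of this diff:

```cpp
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 8192;
    cparams.swa_full = true; // full-size SWA cache: more memory, but old tokens stay reusable

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... llama_decode(...) calls go here ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```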


@@ -361,10 +361,11 @@ extern "C" {
 
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
@@ -730,10 +731,18 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
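Together with llama_kv_self_seq_pos_max(), the new minimum-position query lets a caller tell whether an SWA cache still holds the whole prefix of a sequence; this is the check behind the "disallow use cases involving partial SWA context" change in the server. A small sketch built only on the two declarations above (the helper name is illustrative):

```cpp
// A sequence whose smallest cached position is still 0 has its full prefix available;
// an SWA cache that has pruned old tokens reports a larger minimum, i.e. positions
// [0, p_min) can no longer be reused.
static bool seq_prefix_fully_cached(struct llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p_min = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p_max = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p_max < 0) {
        return true; // empty sequence - nothing cached, nothing lost
    }

    return p_min == 0;
}
```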
 
     // Defragment the KV cache
     // This will be applied:
@@ -943,9 +952,12 @@ extern "C" {
     // Requires KV cache.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    //   0 - success
-    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    // < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);
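A hedged sketch of how a caller might act on the updated return-code contract; the wrapper and the recovery policy are illustrative, only llama_decode() itself comes from the header above:

```cpp
#include "llama.h"

// illustrative wrapper, not part of this commit
static int32_t decode_checked(struct llama_context * ctx, struct llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return 0;   // success - outputs for the batch are available
    }
    if (ret == 1) {
        // no KV slot found (this commit also adds an internal defrag-and-retry before
        // returning 1): shrink the batch or create the context with a larger n_ctx
        return ret;
    }
    if (ret == 2) {
        return ret; // aborted via the abort callback - cache was restored, safe to retry later
    }
    if (ret == -1) {
        return ret; // invalid input batch
    }
    return ret;     // < -1: fatal error - do not keep decoding on this context
}
```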


@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
             /*.type_k   =*/ params.type_k,
             /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
@@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf    =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
     };
 
     return result;
@@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
+}
+
 // deprecated
 llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_self_seq_pos_max(ctx, seq_id);
@@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
@@ -2637,7 +2647,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }


@ -9,33 +9,6 @@
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;
if (bidirectional) {
n_buckets >>= 1;
}
const int64_t max_exact = n_buckets >> 1;
int32_t relative_position = x - y;
int32_t relative_bucket = 0;
if (bidirectional) {
relative_bucket += (relative_position > 0) * n_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
return relative_bucket;
}
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) { if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens;
@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) { if (pos_bucket) {
const int64_t n_tokens = ubatch->n_tokens; kv_self->set_input_pos_bucket(pos_bucket, ubatch);
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
int32_t * data = (int32_t *) pos_bucket->data;
const int64_t n_kv = kv_self->n;
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_kv; ++i) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
}
}
}
} }
} }
@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask || self_kq_mask_swa) { if (self_kq_mask) {
const int64_t n_kv = kv_self->n; kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
const int64_t n_tokens = ubatch->n_tokens; }
const int64_t n_seq_tokens = ubatch->n_seq_tokens; }
const int64_t n_seqs = ubatch->n_seqs;
float * data = nullptr; void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
float * data_swa = nullptr; if (self_kq_mask) {
kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_kq_mask) { if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
data = (float *) self_kq_mask->data;
}
if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
data_swa = (float *) self_kq_mask_swa->data;
}
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
// Causal mask:
// xxx-------
// xxxx------
// xxxxx-----
// Non-causal mask:
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
// mask the token if:
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}
if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
// may need to cut off old tokens for sliding window
// TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
if (data_swa) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
f = -INFINITY;
}
} else {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}
// mask padded tokens
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
// mask padded tokens
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
} }
} }
@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
n_layer (hparams.n_layer), n_layer (hparams.n_layer),
n_rot (hparams.n_rot), n_rot (hparams.n_rot),
n_ctx (cparams.n_ctx), n_ctx (cparams.n_ctx),
n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
n_head (hparams.n_head()), n_head (hparams.n_head()),
n_head_kv (hparams.n_head_kv()), n_head_kv (hparams.n_head_kv()),
n_embd_head_k (hparams.n_embd_head_k), n_embd_head_k (hparams.n_embd_head_k),
@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self); auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
const auto n_kv = kv_self->n; const auto n_kv = kv_self->get_n();
auto & cur = inp->pos_bucket; auto & cur = inp->pos_bucket;
@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * kq_b, ggml_tensor * kq_b,
ggml_tensor * kq_mask, ggml_tensor * kq_mask,
ggml_tensor * v_mla, ggml_tensor * v_mla,
bool v_trans,
float kq_scale) const { float kq_scale) const {
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const bool v_trans = v->nb[1] > v->nb[2];
//const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
//const int64_t n_head = hparams.n_head(il); q = ggml_permute(ctx0, q, 0, 2, 1, 3);
//const int64_t n_head_kv = hparams.n_head_kv(il); k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto n_tokens = q->ne[1]; const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2]; const auto n_head = q->ne[2];
@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); ggml_tensor * q = q_cur;
//cb(q, "q", il); ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
if (wo) { if (wo) {
@ -1369,22 +1235,17 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self); auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
const auto n_kv = kv_self->n; {
GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); const auto n_kv = kv_self->get_n();
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
if (hparams.n_swa_pattern > 1) { inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
GGML_ASSERT(hparams.n_swa > 0);
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
} }
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@ -1409,81 +1270,100 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, v_cur); ggml_build_forward_expand(gf, v_cur);
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory); const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto & n_ctx = cparams.n_ctx;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const auto n_tokens = q_cur->ne[2];
const bool v_trans = !cparams.flash_attn;
// store to KV cache // store to KV cache
{ {
const auto kv_head = kv_self->head; ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
GGML_ASSERT(kv_self->size == n_ctx);
ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
//cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
ggml_tensor * v_cache_view = nullptr;
if (!v_trans) {
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self->v_l[il]));
v_cur = ggml_transpose(ctx0, v_cur);
}
//cb(v_cache_view, "v_cache_view", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
} }
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_self->get_k(ctx0, il);
ggml_tensor * v = kv_self->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
{
const auto n_kv = kv_self->get_kv_base()->get_n();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
if (hparams.n_swa_pattern > 1) {
GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_self->get_kv_swa()->get_n();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const bool is_swa = hparams.is_swa(il); const bool is_swa = hparams.is_swa(il);
const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
// store to KV cache
{
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
const auto n_kv = kv_self->n; ggml_tensor * q = q_cur;
ggml_tensor * k = kv->get_k(ctx0, il);
ggml_tensor * v = kv->get_v(ctx0, il);
const int64_t n_head_kv = hparams.n_head_kv(il); ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
const auto & n_embd_head_k = hparams.n_embd_head_k;
const auto & n_embd_head_v = hparams.n_embd_head_v;
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);
ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
0);
//cb(k, "k", il);
ggml_tensor * v = !v_trans ?
ggml_view_3d(ctx0, kv_self->v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
0) :
ggml_view_3d(ctx0, kv_self->v_l[il],
n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self->v_l[il])*n_ctx,
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
if (wo) { if (wo) {
@ -1534,17 +1414,11 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask_cross(); const auto & kq_mask = inp->get_kq_mask_cross();
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); ggml_tensor * q = q_cur;
//cb(q, "q", il); ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
//cb(k, "k", il);
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
if (wo) { if (wo) {
@ -1712,3 +1586,30 @@ void llm_graph_context::build_pooling(
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
} }
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;
if (bidirectional) {
n_buckets >>= 1;
}
const int64_t max_exact = n_buckets >> 1;
int32_t relative_position = x - y;
int32_t relative_bucket = 0;
if (bidirectional) {
relative_bucket += (relative_position > 0) * n_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
return relative_bucket;
}


@ -19,6 +19,7 @@ struct llama_cparams;
class llama_memory_i; class llama_memory_i;
class llama_kv_cache_unified; class llama_kv_cache_unified;
class llama_kv_cache_unified_iswa;
class llama_kv_cache_recurrent; class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs // certain models (typically multi-modal) can produce different types of graphs
@ -255,6 +256,31 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa * kv_self) :
hparams(hparams),
cparams(cparams),
kv_self(kv_self) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
@ -266,7 +292,7 @@ public:
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self; const llama_kv_cache_unified_iswa * kv_self;
}; };
class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_attn_cross : public llm_graph_input_i {
@ -378,7 +404,6 @@ struct llm_graph_context {
const int64_t n_layer; const int64_t n_layer;
const int64_t n_rot; const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_ctx_per_seq;
const int64_t n_head; const int64_t n_head;
const int64_t n_head_kv; const int64_t n_head_kv;
const int64_t n_embd_head_k; const int64_t n_embd_head_k;
@ -507,13 +532,12 @@ struct llm_graph_context {
ggml_tensor * build_attn_mha( ggml_tensor * build_attn_mha(
ggml_cgraph * gf, ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b, ggml_tensor * kq_b,
ggml_tensor * kq_mask, ggml_tensor * kq_mask,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
bool v_trans,
float kq_scale) const; float kq_scale) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@ -546,6 +570,21 @@ struct llm_graph_context {
float kq_scale, float kq_scale,
int il) const; int il) const;
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const; llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn( ggml_tensor * build_attn(
@ -596,3 +635,6 @@ struct llm_graph_context {
ggml_tensor * cls_out, ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const; ggml_tensor * cls_out_b) const;
}; };
// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);


@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
}; };
enum llama_swa_type {
LLAMA_SWA_TYPE_NONE = 0,
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
};
struct llama_hparams_posnet { struct llama_hparams_posnet {
uint32_t n_embd; uint32_t n_embd;
uint32_t n_layer; uint32_t n_layer;
@ -35,8 +41,6 @@ struct llama_hparams {
uint32_t n_embd_features = 0; uint32_t n_embd_features = 0;
uint32_t n_layer; uint32_t n_layer;
uint32_t n_rot; uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0; uint32_t n_expert = 0;
@ -96,6 +100,12 @@ struct llama_hparams {
std::array<int, 4> rope_sections; std::array<int, 4> rope_sections;
// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
// for State Space Models // for State Space Models
uint32_t ssm_d_conv = 0; uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0; uint32_t ssm_d_inner = 0;
@ -116,11 +126,10 @@ struct llama_hparams {
bool causal_attn = true; bool causal_attn = true;
bool use_alibi = false; bool use_alibi = false;
bool attn_soft_cap = false; bool attn_soft_cap = false;
bool use_kq_norm = true;
// llama4
uint32_t n_moe_layer_step = 0; uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4; uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192; uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1; float f_attn_temp_scale = 0.1;
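
The hparams above only declare the SWA configuration (type, window size, layer pattern); the masking itself lives in the KV cache (see is_masked_swa() in llama-kv-cache.h further down). As a hedged sketch of the intended semantics, reusing the window test and the llama4 chunk test from the mask-building code this PR removes from llama-graph.cpp, and reading the "pattern: 3 chunked - 1 full" comment as "the last layer of every n_swa_pattern group uses full attention", it might look like this (function names are illustrative, not the real implementation):

```cpp
// Hedged sketch, not the actual implementation.
// p0 = position of the cached key token, p1 = position of the current query token.
static bool swa_masks_pair(llama_swa_type swa_type, uint32_t n_swa, llama_pos p0, llama_pos p1) {
    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            return false;
        case LLAMA_SWA_TYPE_STANDARD:
            // sliding window: keys that fell out of the window are masked
            return p1 - p0 >= (llama_pos) n_swa;
        case LLAMA_SWA_TYPE_CHUNKED: {
            // llama4-style chunked attention: the query only attends inside its own chunk
            const llama_pos chunk_start = (p1 / (llama_pos) n_swa) * (llama_pos) n_swa;
            return p0 < chunk_start;
        }
    }
    return false;
}

// one plausible reading of n_swa_pattern: in each group of n_swa_pattern layers,
// all but the last one use the sliding window ("3 chunked - 1 full" => pattern = 4)
static bool layer_uses_swa(uint32_t il, uint32_t n_swa_pattern) {
    return n_swa_pattern > 1 && (il % n_swa_pattern) < (n_swa_pattern - 1);
}
```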

File diff suppressed because it is too large


@ -8,6 +8,7 @@
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include <set> #include <set>
#include <unordered_map>
#include <vector> #include <vector>
struct llama_cparams; struct llama_cparams;
@ -40,6 +41,9 @@ struct llama_kv_cache : public llama_memory_i {
// batch processing // batch processing
// //
// =============================================================================================================
// TODO: refactor and simplify this
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies // different KV caches require different batch splitting strategies
@ -48,6 +52,8 @@ struct llama_kv_cache : public llama_memory_i {
// find an empty slot of size "n_tokens" in the cache // find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0; virtual bool find_slot(const llama_ubatch & batch) = 0;
// =============================================================================================================
// getters // getters
virtual int32_t get_n_tokens() const = 0; virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
@ -87,38 +93,24 @@ private:
// llama_kv_cache_unified // llama_kv_cache_unified
// //
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_unified : public llama_kv_cache {
public: public:
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
static uint32_t get_padding(const llama_cparams & cparams); static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_kv_cache_unified( llama_kv_cache_unified(
const llama_model & model, const llama_model & model,
ggml_type type_k, layer_filter_cb && filter,
ggml_type type_v, ggml_type type_k,
bool v_trans, ggml_type type_v,
bool offload, bool v_trans,
uint32_t kv_size, bool offload,
uint32_t padding); uint32_t kv_size,
uint32_t padding,
uint32_t n_swa,
llama_swa_type swa_type);
~llama_kv_cache_unified() = default; ~llama_kv_cache_unified() = default;
@ -130,10 +122,11 @@ public:
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override; void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// //
@ -150,7 +143,6 @@ public:
void set_full() override; void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
// updates the cache head // updates the cache head
@ -169,29 +161,72 @@ public:
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) //
uint32_t size = 0; // total number of cells, shared across all sequences // llama_kv_cache_unified specific API
uint32_t used = 0; // used cells (i.e. at least one seq_id) //
// computed before each graph build uint32_t get_n() const;
uint32_t n = 0; uint32_t get_size() const;
std::vector<kv_cell> cells; // get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
std::vector<ggml_tensor *> k_l; // per layer // store k_cur and v_cur in the cache based on the current head location
std::vector<ggml_tensor *> v_l; ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private: private:
const llama_model & model; const llama_model & model;
const llama_hparams & hparams; const llama_hparams & hparams;
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
// TODO: replace with bitset uint64_t
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
struct kv_layer {
// layer index in the model
// note: can be different from the layer index in the KV cache
uint32_t il;
ggml_tensor * k;
ggml_tensor * v;
};
bool has_shift = false; bool has_shift = false;
bool do_defrag = false; bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
// computed before each graph build
uint32_t n = 0;
// required padding // required padding
uint32_t padding = 1; uint32_t padding = 1;
@ -199,9 +234,29 @@ private:
ggml_type type_k = GGML_TYPE_F16; ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16;
// SWA
uint32_t n_swa = 0;
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
std::vector<kv_layer> layers;
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// recovery information used to restore the KV cells to their original state in case of a failure
struct {
void clear() {
cells.clear();
}
std::unordered_map<uint32_t, kv_cell> cells;
} recovery;
// defrag // defrag
struct { struct {
std::vector<uint32_t> ids; std::vector<uint32_t> ids;
@ -210,17 +265,6 @@ private:
// return true if cells have been moved // return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes); bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// find how many cells are currently in use // find how many cells are currently in use
uint32_t cell_max() const; uint32_t cell_max() const;
@ -229,6 +273,8 @@ private:
size_t size_k_bytes() const; size_t size_k_bytes() const;
size_t size_v_bytes() const; size_t size_v_bytes() const;
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
ggml_tensor * build_rope_shift( ggml_tensor * build_rope_shift(
const llama_cparams & cparams, const llama_cparams & cparams,
ggml_context * ctx, ggml_context * ctx,
@ -255,6 +301,106 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
//
// llama_kv_cache_unified_iswa
//
// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
llama_kv_cache_unified_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
bool swa_full,
uint32_t n_seq_max,
uint32_t n_batch,
uint32_t padding);
~llama_kv_cache_unified_iswa() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
//
// llama_kv_cache_unified_iswa specific API
//
llama_kv_cache_unified * get_kv_base() const;
llama_kv_cache_unified * get_kv_swa () const;
private:
const llama_hparams & hparams;
bool do_prune = true;
struct {
struct entry {
llama_pos pmin;
llama_pos pmax;
};
void clear() {
pos.clear();
}
// used to perform SWA pruning of old tokens
std::unordered_map<llama_seq_id, entry> pos;
} pending;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
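
In short, the iSWA cache is a router over two unified caches: non-SWA layers use kv_base, SWA layers use kv_swa, and the pending pmin/pmax bookkeeping drives the pruning of old SWA tokens on commit(). A condensed sketch of the per-layer dispatch performed by the new build_attn() overload (simplified from the llama-graph.cpp changes in this commit, not verbatim):

```cpp
// Condensed sketch of the per-layer routing done by the iSWA build_attn() path.
static void store_layer_kv(const llama_kv_cache_unified_iswa * kv_self,
                           const llama_hparams & hparams,
                           ggml_context * ctx0, ggml_cgraph * gf,
                           ggml_tensor * k_cur, ggml_tensor * v_cur, int il) {
    // SWA layers go to the small rolling cache, all other layers to the full-context cache
    const llama_kv_cache_unified * kv = hparams.is_swa(il) ? kv_self->get_kv_swa()
                                                           : kv_self->get_kv_base();

    ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
    ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
}
```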
// //
// llama_kv_cache_recurrent // llama_kv_cache_recurrent
// //
@ -302,6 +448,7 @@ public:
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// //
@ -318,7 +465,6 @@ public:
void set_full() override; void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override; bool find_slot(const llama_ubatch & batch) override;


@@ -7,8 +7,8 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;
 
-    // parameters for other types of memory
-    // ...
+    // use full-size SWA cache
+    bool swa_full;
 };
 
 // general concept of LLM memory
@@ -25,6 +25,7 @@
     virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
     virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
 
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;


@@ -571,9 +571,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP,   hparams.n_moe_layer_step);
+
+                hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
+                hparams.n_swa         = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
                 hparams.n_swa_pattern = 4;    // pattern: 3 chunked - 1 full
-                hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
 
                 switch (hparams.n_expert) {
                     case 16: type = LLM_TYPE_17B_16E; break;
@@ -855,20 +856,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
                 if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
                     // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
+                    LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+
                     hparams.n_swa = 2047;
                 } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-mini-128k-instruct
-                    // note: this seems incorrect because the window is bigger than the train context?
-                    hparams.n_swa = 262144;
+                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-medium-128k-instruct
-                    // note: this seems incorrect because the window is equal to the train context?
-                    hparams.n_swa = 131072;
+                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 }
+
                 bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 if (!found_swa && hparams.n_swa == 0) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
+
+                if (hparams.n_swa > hparams.n_ctx_train) {
+                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
+                }
             } break;
         case LLM_ARCH_PHIMOE:
             {
@@ -937,6 +960,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
                 hparams.n_swa_pattern = 2;
                 hparams.attn_soft_cap = true;
@@ -955,6 +979,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa_pattern = 6;
 
                 hparams.rope_freq_base_train_swa = 10000.0f;
@@ -1039,6 +1064,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
        case LLM_ARCH_COHERE2:
            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa_pattern = 4;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
@@ -4489,7 +4515,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
-ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
+}
+
+float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+}
+
+ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
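
These helpers are meant to be consumed once per layer by the graph builders; a hedged sketch of that usage pattern (the *_l locals are illustrative names, the surrounding variables come from llm_graph_context):

```cpp
// Per-layer RoPE parameter selection with the new helpers: SWA layers pick up the
// *_train_swa frequencies, the rest keep the user-provided cparams values.
for (int il = 0; il < n_layer; ++il) {
    const float freq_base_l  = model.get_rope_freq_base (cparams, il);
    const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

    // may be nullptr, e.g. for models without rope scaling factors
    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);

    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
            n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
            ext_factor, attn_factor, beta_fast, beta_slow);
    // (Kcur is roped the same way)
}
```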
@ -4517,22 +4553,13 @@ struct llm_build_llama : public llm_graph_context {
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
if (arch == LLM_ARCH_LLAMA4) {
inp_attn_scale = build_inp_attn_scale();
}
auto * inp_attn = build_attn_inp_kv_unified(); auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL; ggml_tensor * inpSA = inpL;
bool use_rope = arch == LLM_ARCH_LLAMA4
? (il + 1) % hparams.n_no_rope_layer_step != 0
: true;
// norm // norm
cur = build_norm(inpL, cur = build_norm(inpL,
model.layers[il].attn_norm, NULL, model.layers[il].attn_norm, NULL,
@ -4542,7 +4569,169 @@ struct llm_build_llama : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(cur, "ffn_moe_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_llama_iswa : public llm_graph_context {
llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// temperature tuning
ggml_tensor * inp_attn_scale = nullptr;
inp_attn_scale = build_inp_attn_scale();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
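Note: the new llm_build_llama_iswa builder requests an interleaved-SWA attention input (build_attn_inp_kv_unified_iswa) and, unlike the plain Llama builder, decides per layer whether to apply RoPE: a layer becomes a NoPE layer when (il + 1) is a multiple of hparams.n_no_rope_layer_step. Below is a minimal, self-contained sketch of that gating using hypothetical values for the layer count and the step (the real values come from the model's hparams).

#include <cstdint>
#include <cstdio>

// Sketch only: which layers keep RoPE for a given n_no_rope_layer_step.
// Mirrors the check above: use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0.
int main() {
    const uint32_t n_layer              = 8; // hypothetical layer count
    const uint32_t n_no_rope_layer_step = 4; // hypothetical; read from hparams in practice

    for (uint32_t il = 0; il < n_layer; ++il) {
        const bool use_rope = (il + 1) % n_no_rope_layer_step != 0;
        std::printf("layer %u: %s\n", il, use_rope ? "RoPE" : "NoPE (no rotary position encoding)");
    }
    return 0;
}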
@ -4590,7 +4779,7 @@ struct llm_build_llama : public llm_graph_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { if (use_rope && hparams.use_kq_norm) {
// Llama4TextL2Norm // Llama4TextL2Norm
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
@ -4614,23 +4803,7 @@ struct llm_build_llama : public llm_graph_context {
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il); cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE) {
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else if (arch == LLM_ARCH_LLAMA4) {
// llama4 MoE // llama4 MoE
ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL, model.layers[il].ffn_norm, NULL,
@ -4660,26 +4833,6 @@ struct llm_build_llama : public llm_graph_context {
cur = ggml_add(ctx0, moe_out, shexp_out); cur = ggml_add(ctx0, moe_out, shexp_out);
cb(cur, "ffn_moe_out_merged", il); cb(cur, "ffn_moe_out_merged", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(cur, "ffn_moe_out", il);
} }
cur = ggml_add(ctx0, cur, ffn_inp); cur = ggml_add(ctx0, cur, ffn_inp);
@ -4753,7 +4906,7 @@ struct llm_build_deci : public llm_graph_context {
} else if (n_head > 0) { } else if (n_head > 0) {
// self-attention // self-attention
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -7202,8 +7355,8 @@ struct llm_build_phi2 : public llm_graph_context {
} }
}; };
struct llm_build_phi3 : public llm_graph_context { struct llm_build_phi3_iswa : public llm_graph_context {
llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { llm_build_phi3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@ -7217,7 +7370,7 @@ struct llm_build_phi3 : public llm_graph_context {
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified(); auto * inp_attn = build_attn_inp_kv_unified_iswa();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
auto * residual = inpL; auto * residual = inpL;
@ -7225,7 +7378,7 @@ struct llm_build_phi3 : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for 128k context // rope freq factors for 128k context
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
ggml_tensor* attn_norm_output = build_norm(inpL, ggml_tensor* attn_norm_output = build_norm(inpL,
model.layers[il].attn_norm, model.layers[il].attn_norm,
@ -7977,7 +8130,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL; ggml_tensor * inpSA = inpL;
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// norm // norm
cur = build_norm(inpL, cur = build_norm(inpL,
@ -8277,8 +8430,8 @@ struct llm_build_gemma : public llm_graph_context {
} }
}; };
struct llm_build_gemma2 : public llm_graph_context { struct llm_build_gemma2_iswa : public llm_graph_context {
llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k; const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur; ggml_tensor * cur;
@ -8292,7 +8445,7 @@ struct llm_build_gemma2 : public llm_graph_context {
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified(); auto * inp_attn = build_attn_inp_kv_unified_iswa();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
// norm // norm
@ -8414,8 +8567,8 @@ struct llm_build_gemma2 : public llm_graph_context {
} }
}; };
struct llm_build_gemma3 : public llm_graph_context { struct llm_build_gemma3_iswa : public llm_graph_context {
llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k; const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur; ggml_tensor * cur;
@ -8433,13 +8586,11 @@ struct llm_build_gemma3 : public llm_graph_context {
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes // TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified(); auto * inp_attn = build_attn_inp_kv_unified_iswa();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.is_swa(il); const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
// norm // norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
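Note: in the Gemma 3 builder the per-layer RoPE frequency base/scale selection moves out of the graph code and behind the new model helpers. Based on the inline logic removed here, the helpers presumably look like the sketch below (the actual bodies live in llama-model.cpp and may differ).

// Sketch, assuming the helpers reproduce the inline selection removed above.
float llama_model::get_rope_freq_base(const llama_cparams & cparams, int il) const {
    // SWA layers keep the frequency base they were trained with;
    // full-attention layers use the (possibly user-overridden) context value.
    return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
}

float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
    return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
}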
@ -9016,8 +9167,8 @@ struct llm_build_command_r : public llm_graph_context {
} }
}; };
struct llm_build_cohere2 : public llm_graph_context { struct llm_build_cohere2_iswa : public llm_graph_context {
llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@ -9032,7 +9183,7 @@ struct llm_build_cohere2 : public llm_graph_context {
// inp_pos - contains the positions // inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified(); auto * inp_attn = build_attn_inp_kv_unified_iswa();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
const bool is_swa = hparams.is_swa(il); const bool is_swa = hparams.is_swa(il);
@ -9045,7 +9196,7 @@ struct llm_build_cohere2 : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for 128k context // rope freq factors for 128k context
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -9983,7 +10134,7 @@ struct llm_build_deepseek : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -11347,7 +11498,7 @@ struct llm_build_exaone : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -12263,7 +12414,7 @@ struct llm_build_granite : public llm_graph_context {
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (use_rope) { if (use_rope) {
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
Qcur = ggml_rope_ext( Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors, ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -12916,7 +13067,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// self-attention // self-attention
{ {
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them // compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -13068,14 +13219,31 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
res = new llama_kv_cache_unified( if (hparams.n_swa > 0) {
*this, res = new llama_kv_cache_unified_iswa(
params.type_k, *this,
params.type_v, params.type_k,
!cparams.flash_attn, params.type_v,
cparams.offload_kqv, !cparams.flash_attn,
cparams.n_ctx, cparams.offload_kqv,
padding); cparams.n_ctx,
params.swa_full,
cparams.n_seq_max,
cparams.n_batch,
padding);
} else {
res = new llama_kv_cache_unified(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
padding,
hparams.n_swa,
hparams.swa_type);
}
} }
} }
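Note: create_memory now branches on hparams.n_swa. Models with a sliding window get a llama_kv_cache_unified_iswa which, judging by the constructor arguments above, keeps a separate cache for the SWA layers whose size is governed by swa_full, n_seq_max and n_batch; all other models keep a single unified cache, now also passed n_swa and swa_type. A purely illustrative helper (not part of the API) that states which layout a model ends up with:

// Hypothetical helper, illustration only: report the cache layout selected above.
static const char * kv_cache_kind(const llama_hparams & hparams, bool swa_full) {
    if (hparams.n_swa == 0) {
        return "unified (single cache sized to n_ctx)";
    }
    return swa_full ? "iSWA (separate SWA cache, kept at full context size)"
                    : "iSWA (separate SWA cache, reduced size)";
}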
@ -13090,11 +13258,14 @@ llm_graph_result_ptr llama_model::build_graph(
switch (arch) { switch (arch) {
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_MINICPM: case LLM_ARCH_MINICPM:
{ {
llm = std::make_unique<llm_build_llama>(*this, params, gf); llm = std::make_unique<llm_build_llama>(*this, params, gf);
} break; } break;
case LLM_ARCH_LLAMA4:
{
llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
} break;
case LLM_ARCH_DECI: case LLM_ARCH_DECI:
{ {
llm = std::make_unique<llm_build_deci>(*this, params, gf); llm = std::make_unique<llm_build_deci>(*this, params, gf);
@ -13169,7 +13340,7 @@ llm_graph_result_ptr llama_model::build_graph(
case LLM_ARCH_PHI3: case LLM_ARCH_PHI3:
case LLM_ARCH_PHIMOE: case LLM_ARCH_PHIMOE:
{ {
llm = std::make_unique<llm_build_phi3>(*this, params, gf); llm = std::make_unique<llm_build_phi3_iswa>(*this, params, gf);
} break; } break;
case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO:
{ {
@ -13201,11 +13372,11 @@ llm_graph_result_ptr llama_model::build_graph(
} break; } break;
case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA2:
{ {
llm = std::make_unique<llm_build_gemma2>(*this, params, gf); llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
} break; } break;
case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3:
{ {
llm = std::make_unique<llm_build_gemma3>(*this, params, gf); llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
} break; } break;
case LLM_ARCH_STARCODER2: case LLM_ARCH_STARCODER2:
{ {
@ -13225,7 +13396,7 @@ llm_graph_result_ptr llama_model::build_graph(
} break; } break;
case LLM_ARCH_COHERE2: case LLM_ARCH_COHERE2:
{ {
llm = std::make_unique<llm_build_cohere2>(*this, params, gf); llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
} break; } break;
case LLM_ARCH_DBRX: case LLM_ARCH_DBRX:
{ {
View File
@ -398,7 +398,10 @@ struct llama_model {
const struct ggml_tensor * get_tensor(const char * name) const; const struct ggml_tensor * get_tensor(const char * name) const;
ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const; float get_rope_freq_base (const llama_cparams & cparams, int il) const;
float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
// note: can mutate `cparams` // note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface // TODO: move this to new llm_arch_model_i interface
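Note: the header change shows the new model-level helpers, and get_rope_factors now receives the whole llama_cparams instead of a precomputed n_ctx_per_seq. Assuming it still picks between the long/short rope-scaling factor tensors based on the effective context per sequence, the updated function is probably close to the sketch below (the real body may differ).

// Sketch only; the actual implementation in llama-model.cpp may differ.
ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

    // models with a single factor tensor simply return it
    if (layers[il].rope_freqs != nullptr) {
        return layers[il].rope_freqs;
    }

    // Phi-3-style long/short factors: pick based on the per-sequence context size
    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
        return layers[il].rope_long;
    }
    return layers[il].rope_short;
}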
View File
@ -991,6 +991,7 @@ struct cmd_params_instance {
cparams.flash_attn = flash_attn; cparams.flash_attn = flash_attn;
cparams.embeddings = embeddings; cparams.embeddings = embeddings;
cparams.op_offload = !no_op_offload; cparams.op_offload = !no_op_offload;
cparams.swa_full = false;
return cparams; return cparams;
} }
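Note: llama-bench pins swa_full to false, so benchmark runs always exercise the reduced SWA cache. Code that creates contexts directly can opt into the full-size SWA cache through the same llama_context_params field; a minimal sketch with placeholder values:

// Sketch: requesting a full-size SWA cache when creating a context directly.
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx    = 8192;  // placeholder context size
cparams.swa_full = true;  // keep the SWA cache at full context size (more memory, more reuse of cached prompts)
// llama_context * ctx = llama_init_from_model(model, cparams);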
View File
@ -2004,6 +2004,23 @@ struct server_context {
} }
} }
if (!llama_kv_self_can_shift(ctx)) {
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
}
if (params_base.n_cache_reuse) {
params_base.n_cache_reuse = 0;
SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
}
if (!params_base.speculative.model.path.empty()) {
SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
return false;
}
}
return true; return true;
} }
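Note: at startup the server now probes the context with llama_kv_self_can_shift() and downgrades features that depend on shifting: context shift and cache reuse are disabled, and speculative decoding is rejected outright. An embedding application can apply the same guard; a minimal sketch:

// Sketch: disable shift-dependent features when the KV cache cannot shift positions.
static void apply_shift_capability(llama_context * ctx, common_params & params) {
    if (llama_kv_self_can_shift(ctx)) {
        return; // shifting supported, nothing to change
    }
    params.ctx_shift     = false; // context shift moves positions inside the KV cache
    params.n_cache_reuse = 0;     // cache reuse shifts retained chunks into place
}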
@ -3181,7 +3198,15 @@ struct server_context {
// if we don't cache the prompt, we have to remove the entire KV cache // if we don't cache the prompt, we have to remove the entire KV cache
llama_kv_self_seq_rm(ctx, slot.id, 0, -1); llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
slot.n_past = 0; slot.n_past = 0;
slot.cache_tokens.clear(); slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()"
}
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;
}
} }
} }
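Note: the second server-side change handles prompt-cache reuse on an SWA cache. If a slot wants to keep a prefix (n_past > 0) but the cache has already discarded positions below that prefix for the sequence (llama_kv_self_seq_pos_min() > 0), continuing from the middle would attend to missing context, so the full prompt is re-processed. The same check in client code, sketched with placeholder arguments:

// Sketch: decide how much of a cached prefix can actually be continued on an SWA cache.
static int32_t usable_prefix(llama_context * ctx, llama_seq_id seq_id, int32_t n_past) {
    if (n_past > 0 && llama_kv_self_seq_pos_min(ctx, seq_id) > 0) {
        // the sliding window already dropped the oldest positions of this sequence
        return 0; // force full re-processing of the prompt
    }
    return n_past;
}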