Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-29)
kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface ggml-ci
* kv-cache : revert the (n_swa + n_ubatch) change (for next PR) ggml-ci
* kv-cache : some comments ggml-ci
* context : fix graph reserve for multiple sequences ggml-ci
* kv-cache : fix typo [no ci]
* kv-cache : fix find_slot() logic for free slots ggml-ci
* llama : add TODO for deprecating the defrag API in the future
* kv-cache : improve find_slot() using min/max seq pos info ggml-ci
* llama : handle aborts and compute errors ggml-ci
* memory : extract state into llama_memory_state ggml-ci
* kv-cache : add comments ggml-ci
* server : update batching logic to reset n_batch on successful decode
* server : upon full re-processing, remove the sequence from the cache
* kv-cache : add TODO for doing split_equal when split_simple fails ggml-ci
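For readers skimming the diff below: the graph-building code used to downcast the generic `memory` pointer to a concrete cache type (`kv_self`) and read its fields directly; after this change it queries an immutable per-graph memory state (`mstate`, e.g. `llama_kv_cache_unified_state`) through accessors such as `get_n_kv()`. The following is a minimal sketch of that pattern — the type and member names are taken from the diff, but the signatures are illustrative assumptions, not the actual llama.cpp declarations.

// Illustrative sketch only -- simplified stand-ins for the real llama.cpp types.
#include <cstdint>

struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;

// Base interface introduced by this refactor (the real definition lives in llama.cpp).
struct llama_memory_state_i {
    virtual ~llama_memory_state_i() = default;
};

// Read-only view of the unified KV cache for the graph currently being built.
// Accessor names mirror the diff; parameter lists are assumed for illustration.
struct llama_kv_cache_unified_state : llama_memory_state_i {
    uint32_t get_n_kv() const;                                  // number of KV cells visible to this graph
    void set_input_kq_mask(ggml_tensor * dst,
                           const llama_ubatch * ubatch,
                           bool causal_attn) const;             // fill the attention-mask input tensor
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;  // K view for layer il
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;  // V view for layer il
};

// before: const auto * kv_self  = static_cast<const llama_kv_cache_unified *>(memory);
// after:  const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);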
@@ -83,7 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
+        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -234,7 +234,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_self->n;
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -242,7 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_copy(i);
+            data[i] = kv_state->s_copy(i);
         }
     }
 }
@@ -250,7 +250,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_self->n;
+    const int64_t n_kv = kv_state->get_n_kv();
 
     if (s_mask) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
@@ -258,7 +258,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_mask(i);
+            data[i] = kv_state->s_mask(i);
         }
     }
 }
@@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -448,7 +448,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu (params.backend_cpu),
     cvec        (params.cvec),
     loras       (params.loras),
-    memory      (params.memory),
+    mstate      (params.mstate),
     cross       (params.cross),
     cb_func     (params.cb),
     res         (std::make_unique<llm_graph_result>()) {
@@ -954,11 +954,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_copy;
 
@@ -971,11 +971,11 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
+    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->s_mask;
 
@@ -1025,11 +1025,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
-    const auto n_kv = kv_self->get_n();
+    const auto n_kv = kv_state->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -1231,14 +1231,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n();
+        const auto n_kv = kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1268,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1301,12 +1301,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
 
     {
-        const auto n_kv = kv_self->get_kv_base()->get_n();
+        const auto n_kv = kv_state->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1318,7 +1318,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_self->get_kv_swa()->get_n();
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1348,23 +1348,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
+    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
     const bool is_swa = hparams.is_swa(il);
 
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
-
-    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv->get_k(ctx0, il);
-    ggml_tensor * v = kv->get_v(ctx0, il);
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1446,12 +1446,12 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    const auto n_kv = kv_self->n;
-    const auto kv_head = kv_self->head;
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
 
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -1478,13 +1478,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_self->k_l[il];
+    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
     ggml_tensor * token_shift = build_copy_mask_state(
             gf, token_shift_all, state_copy, state_mask,
@@ -1499,19 +1499,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_self->head;
+    const auto kv_head = kv_state->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
     );
 }
 