Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-28 20:25:20 +00:00)
kv-cache : separate recurrent vs non-recurrent impl (#12799)
* kv-cache : separate recurrent vs non-recurrent impl (wip) ggml-ci
* kv-cache : init -> constructor + add llama_memory_params ggml-ci
* kv-cache : fix callback reference ggml-ci
* context : llama_kv_cache -> llama_memory_i ggml-ci
* context : move memory creation logic to model ggml-ci
* llama : remove reference of memory during encode ggml-ci
* kv-cache : hide padding details in the implementation ggml-ci
* kv-cache : add ubatch_next() ggml-ci
* context : simplify sbatch logic ggml-ci
* kv-cache : hide defrag logic in the implementation ggml-ci
* context : hide kv cache details in implementation ggml-ci
* build : fix ggml-ci
* cont : another fix ggml-ci
* kv-cache : simplify interface (wip) ggml-ci
* kv-cache : use separate KV cell structs for unified/recurrent ggml-ci
* kv-cache : clean-up ggml-ci
* model : better llama_model::create_model() signature ggml-ci
* kv-cache : fix recurrent seq_rm() ggml-ci
* kv-cache : replace `struct callbacks` with `llama_model &` ggml-ci
* kv-cache : replace `struct graph_params` with `llama_context &` ggml-ci
* kv-cache : fix offload check ggml-ci
* context : avoid passing unique_ptr ggml-ci
* kv-cache : avoid using the backends from the llama_context ref #13113 ggml-ci
* kv-cache : more consistent debug logs [no ci]
* kv-cache : do not pass the full llama_context for kv graphs ggml-ci
* kv-cache : remove comment
* kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext ggml-ci
* kv-cache : fix recurrent multi-user case ggml-ci
* memory : remove comments [no ci]
@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }
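The mutating bookkeeping does not disappear with this change; it moves behind the recurrent cache's own s_copy() accessor, so the graph-input callback no longer needs the const_cast. Below is a minimal, self-contained sketch of what such an accessor can look like, reconstructed from the removed loop body above. kv_cell_mock and kv_cache_recurrent_mock are illustrative stand-ins, not the actual llama.cpp types.

#include <cstdint>
#include <vector>

// illustrative mock of a recurrent-cache cell: src records which cell's
// state should be copied into this one (-1 = no valid source yet)
struct kv_cell_mock {
    int32_t src = -1;
};

struct kv_cache_recurrent_mock {
    uint32_t head = 0;
    uint32_t size = 0;
    std::vector<kv_cell_mock> cells;

    // same logic the diff removes from llm_graph_input_s_copy::set_input:
    // clamp out-of-bound sources and make sure each copy happens only once
    int32_t s_copy(int i) {
        const uint32_t cell_id = i + head;
        kv_cell_mock & cell = cells[i];

        // prevent out-of-bound sources
        if (cell.src < 0 || (uint32_t) cell.src >= size) {
            cell.src = cell_id;
        }

        const int32_t res = cell.src;

        // ensure the copy only happens once
        if (cell.src != (int32_t) cell_id) {
            cell.src = cell_id;
        }

        return res;
    }
};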
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            const uint32_t cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
         }
     }
 }
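The mask input gets the same treatment. A sketch of the matching accessor, written as a free function against the mock cache type from the previous sketch and mirroring the removed loop body (again a reconstruction, not the real implementation):

// counterpart accessor for the mask input: report whether the cell holds
// a valid state, and mark unused cells so the clear happens only once
float s_mask_mock(kv_cache_recurrent_mock & kv, int i) {
    const uint32_t cell_id = i + kv.head;
    kv_cell_mock & cell = kv.cells[i];

    // 1.0f if the cell holds a valid state, 0.0f otherwise
    const float res = (float) (cell.src >= 0);

    // only clear once
    if (cell.src < 0) {
        cell.src = cell_id;
    }

    return res;
}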
@@ -1105,7 +1077,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
@@ -1122,7 +1094,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
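Both builders now cast `memory` straight to llama_kv_cache_recurrent with no runtime check. An unchecked static_cast like this is only sound if the dynamic type is fixed once, at construction time, which is what the "move memory creation logic to model" bullet in the commit message points at. A hypothetical sketch of that invariant (all names here are mocks, not the real factory):

#include <memory>

// illustrative stand-ins for the two cache implementations and their
// common interface; names mirror the diff but the definitions are mocks
struct llama_memory_mock { virtual ~llama_memory_mock() = default; };
struct kv_cache_unified_mock   : llama_memory_mock {};
struct kv_cache_recurrent_mem_mock : llama_memory_mock {};

// hypothetical factory: the model decides once which implementation backs
// the context's memory. After that, the recurrent-only graph builders can
// static_cast without a runtime check, because they are only ever reached
// for recurrent architectures.
std::unique_ptr<llama_memory_mock> make_memory_mock(bool is_recurrent) {
    if (is_recurrent) {
        return std::make_unique<kv_cache_recurrent_mem_mock>();
    }
    return std::make_unique<kv_cache_unified_mock>();
}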
@@ -1436,8 +1408,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         GGML_ASSERT(kv_self->size == n_ctx);
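The dropped GGML_ASSERT(!kv_self->recurrent) illustrates the payoff of the split: what used to be a per-build runtime check becomes a property of the type system. A contrived before/after sketch, assuming a single cache type with a boolean flag on the "before" side:

#include <cassert>

// before the split (sketch): one cache type served both modes, so the
// attention path had to assert at runtime that it was not handed a
// recurrent cache
struct kv_cache_single_type_mock {
    bool recurrent = false;
};

void build_attn_before_mock(const kv_cache_single_type_mock * kv) {
    (void) kv;
    assert(!kv->recurrent); // runtime invariant, checked on every build
}

// after the split (sketch): the attention builder only ever receives the
// unified type, so the invariant holds by construction and the assert
// disappears, exactly as in the hunk above
struct kv_cache_unified_only_mock {};

void build_attn_after_mock(const kv_cache_unified_only_mock * /*kv*/) {
    // nothing to check: a recurrent cache cannot reach this code path
}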
@@ -1587,7 +1557,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv = kv_self->n;
     const auto kv_head = kv_self->head;
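build_copy_mask_state consumes the s_copy indices and s_mask values prepared above. A hedged sketch of the gather-and-mask pattern those inputs suggest, written against the public ggml API; the function name and tensor shapes here are assumptions, not the actual implementation:

#include "ggml.h"

ggml_tensor * copy_mask_state_sketch(
        ggml_context * ctx,
        ggml_tensor  * states,       // [n_state, kv_size] recurrent states
        ggml_tensor  * state_copy,   // [n_kv] int32 source cell ids from s_copy()
        ggml_tensor  * state_mask) { // [1, n_kv] 0/1 values from s_mask()
    // gather each cell's state from its source cell
    ggml_tensor * cur = ggml_get_rows(ctx, states, state_copy);

    // zero out the states of cells whose sequences start fresh in this ubatch
    cur = ggml_mul(ctx, cur, state_mask);

    return cur;
}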
@@ -1619,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
@@ -1640,7 +1610,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
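These two RWKV builders load and store per-layer token-shift state through the same recurrent cache, reading token_shift_count, n_embd, and the cache head, as the hunks above show. A speculative sketch of the store side using only public ggml calls; kv_state and the exact layout are assumptions:

#include "ggml.h"

ggml_tensor * token_shift_store_sketch(
        ggml_context * ctx,
        ggml_tensor  * token_shift, // [n_embd, token_shift_count, n_seqs]
        ggml_tensor  * kv_state,    // assumed cache-owned shift buffer for this layer
        int64_t        n_embd,
        int64_t        token_shift_count,
        int64_t        n_seqs,
        int64_t        kv_head) {
    const int64_t n_per_seq = n_embd * token_shift_count;

    // flatten this ubatch's shift states and write them back at the cache head
    return ggml_cpy(ctx,
            ggml_view_1d(ctx, token_shift, n_per_seq * n_seqs, 0),
            ggml_view_1d(ctx, kv_state, n_per_seq * n_seqs,
                    kv_head * n_per_seq * ggml_element_size(kv_state)));
}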