context : add llama_context_recurrent

ggml-ci
Georgi Gerganov
2025-02-19 14:56:01 +02:00
parent 5f11a5502a
commit e17e4b72d1
5 changed files with 266 additions and 83 deletions

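In short, the commit splits the recurrent-state handling (Mamba, RWKV) out of llama_context_kv_self into a new llama_context_recurrent subclass, and turns the corresponding llama_graph_i pure virtuals into defaulted "not implemented" stubs. An abridged, illustrative sketch of the resulting hierarchy (the real declarations in the header diff below carry many more members):

// abridged sketch -- not the actual declarations
struct ggml_tensor; // opaque ggml type

struct llama_context { /* model, cparams, timings, ... */ };

struct llama_context_kv_self : llama_context {
    // unified attention KV cache and its graph inputs (inp_KQ_mask, inp_k_shift, ...)
};

struct llama_context_recurrent : llama_context_kv_self {
    // recurrent-state inputs, moved out of llama_context_kv_self:
    ggml_tensor * inp_s_copy = nullptr; // I32 [kv_size]
    ggml_tensor * inp_s_mask = nullptr; // F32 [1, n_kv]

    // overrides: build_inp_s_copy, build_inp_s_mask, build_copy_mask_state,
    //            build_mamba_layer, build_rwkv_token_shift_*, build_rwkv6_time_mix
};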
View File

@@ -20,6 +20,8 @@ llama_context::llama_context(
model (model),
t_start_us(model.t_start_us),
t_load_us (model.t_load_us) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
const auto & hparams = model.hparams;
cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
const llama_context_params & params) :
llama_context(model, params),
kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
const auto & hparams = model.hparams;
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
inp_KQ_mask_swa_cnv = nullptr;
inp_KQ_mask_cross = nullptr;
inp_k_shift = nullptr;
inp_s_copy = nullptr;
inp_s_mask = nullptr;
inp_embd_enc = nullptr;
inp_pos_bucket = nullptr;
@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
}
}
if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;
if (inp_s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
float * data = (float *) inp_s_mask->data;
// clear unused states
for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
data[i] = (float) (kv_cell.src >= 0);
// TODO: do not mutate the KV cache
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
}
}
if (inp_s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
int32_t * data = (int32_t *) inp_s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
}
}
}
if (inp_pos_bucket) {
const int64_t n_tokens = ubatch.n_tokens;
@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(
void llama_context_kv_self::build_attn_kv_store(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_cgraph * gf,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
int32_t n_tokens,
@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
//cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
@@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store(
}
//cb(v_cache_view, "v_cache_view", il);
ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}
ggml_tensor * llama_context_kv_self::build_attn_qkv(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
}
}
ggml_build_forward_expand(graph, cur);
ggml_build_forward_expand(gf, cur);
if (wo) {
cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
return inp_KQ_mask_cross;
}
ggml_tensor * llama_context_kv_self::build_inp_s_copy(
//
// llama_context_recurrent
//
llama_context_recurrent::llama_context_recurrent(
const llama_model & model,
const llama_context_params & params) :
llama_context_kv_self(model, params) {
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
}
llama_context_recurrent::~llama_context_recurrent() = default;
ggml_cgraph * llama_context_recurrent::graph_init() {
inp_s_copy = nullptr;
inp_s_mask = nullptr;
return llama_context_kv_self::graph_init();
}
void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
// call base functionality
llama_context_kv_self::input_set(ubatch);
GGML_ASSERT(kv_self.recurrent);
const int64_t n_kv = kv_self.n;
if (inp_s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
float * data = (float *) inp_s_mask->data;
// clear unused states
for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
data[i] = (float) (kv_cell.src >= 0);
// TODO: do not mutate the KV cache
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
}
}
if (inp_s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
int32_t * data = (int32_t *) inp_s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
}
}
}
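For orientation, the two buffers filled here feed the state-gathering step of the graph: inp_s_copy supplies a source row index per cell and inp_s_mask supplies a 0/1 factor for clearing states whose cell has no valid source. The graph-side consumption roughly follows the pattern below; this is a hypothetical standalone helper, not the actual build_copy_mask_state body (which is declared further down but not part of this hunk):

#include "ggml.h"

// hypothetical illustration of how the graph consumes the inputs set above
static ggml_tensor * gather_and_mask_states(
        ggml_context * ctx0,
        ggml_tensor  * s,          // all recurrent states, one row per KV cell
        ggml_tensor  * state_copy, // inp_s_copy: I32 [n_kv], source row per cell
        ggml_tensor  * state_mask, // inp_s_mask: F32 [1, n_kv], 0/1 per cell
        int64_t        n_state,
        int64_t        kv_size) {
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_size);
    states = ggml_get_rows(ctx0, states, state_copy); // each cell gathers its source state
    states = ggml_mul     (ctx0, states, state_mask); // zero states of cells without a valid source
    return states;
}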
ggml_tensor * llama_context_recurrent::build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
return inp_s_copy;
}
ggml_tensor * llama_context_kv_self::build_inp_s_mask(
ggml_tensor * llama_context_recurrent::build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
return inp_s_mask;
}
ggml_tensor * llama_context_kv_self::build_copy_mask_state(
ggml_tensor * llama_context_recurrent::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
}
// TODO: split
ggml_tensor * llama_context_kv_self::build_mamba_layer(
ggml_tensor * llama_context_recurrent::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
}
ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
return token_shift;
}
ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
);
}
ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,

View File

@@ -433,15 +433,28 @@ public:
int32_t n_tokens,
bool worst_case) override;
// === recurrent ===
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};
// TODO: add recurrent cache
// TODO: add mamba-specific llama_context
// a recurrent transformer (i.e. RWKV, Mamba)
// TODO: temporarily reuses kv_self; in the future, implement a recurrent-specific context with its own cache
class llama_context_recurrent : public llama_context_kv_self {
public:
llama_context_recurrent(
const llama_model & model,
const llama_context_params & params);
virtual ~llama_context_recurrent();
virtual ggml_cgraph * graph_init() override;
virtual void input_set(const llama_ubatch & ubatch) override;
// TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
virtual ggml_tensor * build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) override;
@@ -499,11 +512,10 @@ public:
bool worst_case) override;
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
// TODO: add recurrent cache
};
// For internal test use

View File

@@ -1 +1,136 @@
#include "llama-graph.h"
#include "llama-impl.h"
ggml_tensor * llama_graph_i::build_inp_s_copy (
ggml_context * ctx0,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_tokens,
int32_t n_state,
int32_t n_seqs,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(s);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(n_tokens);
GGML_UNUSED(n_state);
GGML_UNUSED(n_seqs);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(token_shift);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}
ggml_tensor * llama_graph_i::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(x_prev);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr; // NOLINT
}

View File

@@ -55,7 +55,7 @@ public:
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
ggml_backend_buffer * bbuft) = 0;
ggml_backend_buffer * bbuf) = 0;
// graph build API (context-specific)
@@ -137,11 +137,11 @@ public:
virtual ggml_tensor * build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_copy_mask_state(
ggml_context * ctx0,
@@ -152,7 +152,7 @@ public:
int32_t n_tokens,
int32_t n_state,
int32_t n_seqs,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_mamba_layer(
ggml_context * ctx0,
@@ -162,7 +162,7 @@ public:
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_rwkv_token_shift_load(
ggml_context * ctx0,
@@ -171,14 +171,14 @@ public:
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case) = 0;
bool worst_case);
virtual ggml_tensor * build_rwkv6_time_mix(
ggml_context * ctx0,
@@ -189,5 +189,5 @@ public:
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) = 0;
bool worst_case);
};
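Design note: instead of keeping these builders pure virtual, llama_graph_i now ships the error-logging defaults shown in the previous file, so only llama_context_recurrent has to override them and the non-recurrent contexts need no dummy implementations. A generic, hypothetical mini-example of the same pattern (names are placeholders, not from the codebase):

#include <cstdio>

// hypothetical illustration of the interface-with-default-stub pattern
struct graph_iface {
    virtual ~graph_iface() = default;

    // defaulted stub instead of "= 0": non-recurrent implementations can ignore it
    virtual int build_recurrent_op() {
        fprintf(stderr, "%s: not implemented\n", __func__);
        return -1;
    }
};

struct recurrent_graph : graph_iface {
    int build_recurrent_op() override { return 0; } // only the recurrent context overrides
};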

View File

@@ -326,8 +326,19 @@ struct llama_context * llama_init_from_model(
llama_context * ctx = nullptr;
try {
// TODO: add logic to choose which llama_context implementation to construct
ctx = new llama_context_kv_self(*model, params);
// TODO: make static method of llama_context
switch (model->arch) {
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_MAMBA:
GGML_ASSERT(llama_model_is_recurrent(model));
ctx = new llama_context_recurrent(*model, params);
break;
default:
GGML_ASSERT(!llama_model_is_recurrent(model));
ctx = new llama_context_kv_self(*model, params);
};
ctx->init();
} catch (const std::exception & e) {
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
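From the public API side nothing changes: llama_init_from_model() still returns a llama_context *, and the concrete subclass is picked internally from the model architecture. A hypothetical caller (the model path is a placeholder) looks the same for Mamba/RWKV and attention models:

#include "llama.h"
#include <cstdio>

// hypothetical caller: the context subclass (recurrent vs kv_self) is chosen
// internally, so user code is identical for recurrent and attention models
int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    printf("recurrent architecture: %s\n", llama_model_is_recurrent(model) ? "yes" : "no");

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        llama_model_free(model);
        return 1;
    }

    // ... evaluate batches as usual ...

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}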