From e17e4b72d16710ee430b6858d58ce6ab3f4a31bb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 19 Feb 2025 14:56:01 +0200 Subject: [PATCH] context : add llama_context_recurrent ggml-ci --- src/llama-context.cpp | 151 ++++++++++++++++++++++++------------------ src/llama-context.h | 32 ++++++--- src/llama-graph.cpp | 135 +++++++++++++++++++++++++++++++++++++ src/llama-graph.h | 16 ++--- src/llama.cpp | 15 ++++- 5 files changed, 266 insertions(+), 83 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bec82b446..b571c9343 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -20,6 +20,8 @@ llama_context::llama_context( model (model), t_start_us(model.t_start_us), t_load_us (model.t_load_us) { + LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); + const auto & hparams = model.hparams; cparams.n_seq_max = std::max(1u, params.n_seq_max); @@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self( const llama_context_params & params) : llama_context(model, params), kv_self(model.hparams) { + LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); + const auto & hparams = model.hparams; LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); @@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() { inp_KQ_mask_swa_cnv = nullptr; inp_KQ_mask_cross = nullptr; inp_k_shift = nullptr; - inp_s_copy = nullptr; - inp_s_mask = nullptr; inp_embd_enc = nullptr; inp_pos_bucket = nullptr; @@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; - - if (inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer)); - float * data = (float *) inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // TODO: do not mutate the KV cache - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - - if (inp_s_copy) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer)); - int32_t * data = (int32_t *) inp_s_copy->data; - - // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = kv_self.cells[cell_id]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } - } - } - } - if (inp_pos_bucket) { const int64_t n_tokens = ubatch.n_tokens; @@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp( void llama_context_kv_self::build_attn_kv_store( ggml_context * ctx0, - ggml_cgraph * graph, + ggml_cgraph * gf, ggml_tensor * k_cur, ggml_tensor * v_cur, int32_t n_tokens, @@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store( //cb(k_cache_view, "k_cache_view", il); // note: storing RoPE-ed version of K in the KV cache - ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); @@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store( } 
//cb(v_cache_view, "v_cache_view", il); - ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); } ggml_tensor * llama_context_kv_self::build_attn_qkv( ggml_context * ctx0, - ggml_cgraph * graph, + ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv( } } - ggml_build_forward_expand(graph, cur); + ggml_build_forward_expand(gf, cur); if (wo) { cur = build_lora_mm(ctx0, wo, cur); @@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross( return inp_KQ_mask_cross; } -ggml_tensor * llama_context_kv_self::build_inp_s_copy( +// +// llama_context_recurrent +// + +llama_context_recurrent::llama_context_recurrent( + const llama_model & model, + const llama_context_params & params) : + llama_context_kv_self(model, params) { + LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); +} + +llama_context_recurrent::~llama_context_recurrent() = default; + +ggml_cgraph * llama_context_recurrent::graph_init() { + inp_s_copy = nullptr; + inp_s_mask = nullptr; + + return llama_context_kv_self::graph_init(); +} + +void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { + // call base functionality + llama_context_kv_self::input_set(ubatch); + + GGML_ASSERT(kv_self.recurrent); + + const int64_t n_kv = kv_self.n; + + if (inp_s_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer)); + float * data = (float *) inp_s_mask->data; + + // clear unused states + for (int i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = kv_self.cells[cell_id]; + + data[i] = (float) (kv_cell.src >= 0); + + // TODO: do not mutate the KV cache + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; + } + } + } + + if (inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer)); + int32_t * data = (int32_t *) inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = kv_self.cells[cell_id]; + + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; + } + } + } +} + +ggml_tensor * llama_context_recurrent::build_inp_s_copy( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? kv_self.size : kv_self.n; @@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy( return inp_s_copy; } -ggml_tensor * llama_context_kv_self::build_inp_s_mask( +ggml_tensor * llama_context_recurrent::build_inp_s_mask( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? 
kv_self.size : kv_self.n; @@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask( return inp_s_mask; } -ggml_tensor * llama_context_kv_self::build_copy_mask_state( +ggml_tensor * llama_context_recurrent::build_copy_mask_state( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * s, @@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state( } // TODO: split -ggml_tensor * llama_context_kv_self::build_mamba_layer( +ggml_tensor * llama_context_recurrent::build_mamba_layer( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * cur, @@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer( } -ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load( +ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * state_copy, @@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load( return token_shift; } - -ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store( +ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, @@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store( ); } - -ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix( +ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * cur, diff --git a/src/llama-context.h b/src/llama-context.h index a256f3042..133eb8b36 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -433,15 +433,28 @@ public: int32_t n_tokens, bool worst_case) override; - // === recurrent === +protected: + virtual size_t state_get_data(llama_io_write_i & io) override; + virtual size_t state_set_data(llama_io_read_i & io) override; - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; + virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; +}; - // TODO: add recurrent cache - // TODO: add mamba-specific llama_context +// a recurrent transformer (ie.e RWKV, Mamba) +// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache +class llama_context_recurrent : public llama_context_kv_self { +public: + llama_context_recurrent( + const llama_model & model, + const llama_context_params & params); + + virtual ~llama_context_recurrent(); + + virtual ggml_cgraph * graph_init() override; + + virtual void input_set(const llama_ubatch & ubatch) override; - // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl virtual ggml_tensor * build_inp_s_copy( ggml_context * ctx0, bool worst_case) override; @@ -499,11 +512,10 @@ public: bool worst_case) override; protected: - virtual size_t state_get_data(llama_io_write_i & io) override; - virtual size_t state_set_data(llama_io_read_i & io) override; + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; - virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; + // TODO: add recurrent cache }; // For internal test use diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 20f2ee0bd..17605e74c 100644 --- a/src/llama-graph.cpp +++ 
b/src/llama-graph.cpp @@ -1 +1,136 @@ #include "llama-graph.h" + +#include "llama-impl.h" + +ggml_tensor * llama_graph_i::build_inp_s_copy ( + ggml_context * ctx0, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_inp_s_mask( + ggml_context * ctx0, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_copy_mask_state( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + int32_t n_tokens, + int32_t n_state, + int32_t n_seqs, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(s); + GGML_UNUSED(state_copy); + GGML_UNUSED(state_mask); + GGML_UNUSED(n_tokens); + GGML_UNUSED(n_state); + GGML_UNUSED(n_seqs); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_mamba_layer( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(cur); + GGML_UNUSED(state_copy); + GGML_UNUSED(state_mask); + GGML_UNUSED(ubatch); + GGML_UNUSED(il); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_rwkv_token_shift_load( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(state_copy); + GGML_UNUSED(state_mask); + GGML_UNUSED(ubatch); + GGML_UNUSED(il); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_rwkv_token_shift_store( + ggml_context * ctx0, + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(token_shift); + GGML_UNUSED(ubatch); + GGML_UNUSED(il); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} + +ggml_tensor * llama_graph_i::build_rwkv6_time_mix( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(cur); + GGML_UNUSED(x_prev); + GGML_UNUSED(state_copy); + GGML_UNUSED(state_mask); + GGML_UNUSED(ubatch); + GGML_UNUSED(il); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + + return nullptr; // NOLINT +} diff --git a/src/llama-graph.h b/src/llama-graph.h index bb51b9a91..b9456e3d1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -55,7 +55,7 @@ public: ggml_tensor * cur, ggml_tensor * shift, ggml_tensor * factors, - ggml_backend_buffer * bbuft) = 0; + ggml_backend_buffer * bbuf) = 0; // graph build API (context-specific) @@ -137,11 +137,11 @@ public: virtual ggml_tensor * build_inp_s_copy( ggml_context * ctx0, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_inp_s_mask( ggml_context * ctx0, - bool 
worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_copy_mask_state( ggml_context * ctx0, @@ -152,7 +152,7 @@ public: int32_t n_tokens, int32_t n_state, int32_t n_seqs, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_mamba_layer( ggml_context * ctx0, @@ -162,7 +162,7 @@ public: ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_rwkv_token_shift_load( ggml_context * ctx0, @@ -171,14 +171,14 @@ public: ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, int il, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_rwkv6_time_mix( ggml_context * ctx0, @@ -189,5 +189,5 @@ public: ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case) = 0; + bool worst_case); }; diff --git a/src/llama.cpp b/src/llama.cpp index a677902f0..3db164477 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -326,8 +326,19 @@ struct llama_context * llama_init_from_model( llama_context * ctx = nullptr; try { - // TODO: add logic which llama_context implementation to construct - ctx = new llama_context_kv_self(*model, params); + // TODO: make static method of llama_context + switch (model->arch) { + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_MAMBA: + GGML_ASSERT(llama_model_is_recurrent(model)); + ctx = new llama_context_recurrent(*model, params); + break; + default: + GGML_ASSERT(!llama_model_is_recurrent(model)); + ctx = new llama_context_kv_self(*model, params); + }; + ctx->init(); } catch (const std::exception & e) { LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
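
Notes on the dispatch change in `llama_init_from_model`: the patch replaces the unconditional `new llama_context_kv_self(...)` with an architecture switch that constructs `llama_context_recurrent` for RWKV6, RWKV6QWEN2 and Mamba models, and calls `init()` on whichever context was built. Below is a minimal standalone sketch of that pattern under simplified assumptions; the type names (`llm_arch`, `context_base`, `make_context`, etc.) are illustrative stand-ins, not the actual llama.cpp API.

```cpp
#include <memory>

// simplified stand-ins for the llama.cpp types touched by this patch
enum class llm_arch { MAMBA, RWKV6, RWKV6QWEN2, LLAMA };

struct model_t  { llm_arch arch; };
struct params_t { /* context parameters */ };

struct context_base {
    virtual ~context_base() = default;
    virtual void init() {}            // both branches call init() after construction
};

// default self-attention KV-cache context
struct context_kv_self : context_base {
    context_kv_self(const model_t &, const params_t &) {}
};

// recurrent context (RWKV, Mamba); derives from the KV-self context for now,
// mirroring the patch's TODO about a future recurrent-specific cache
struct context_recurrent : context_kv_self {
    context_recurrent(const model_t & m, const params_t & p) : context_kv_self(m, p) {}
};

// architecture-based factory, analogous to the switch added in llama_init_from_model
std::unique_ptr<context_base> make_context(const model_t & model, const params_t & params) {
    std::unique_ptr<context_base> ctx;
    switch (model.arch) {
        case llm_arch::RWKV6:
        case llm_arch::RWKV6QWEN2:
        case llm_arch::MAMBA:
            ctx = std::make_unique<context_recurrent>(model, params);
            break;
        default:
            ctx = std::make_unique<context_kv_self>(model, params);
            break;
    }
    ctx->init();
    return ctx;
}
```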
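
The `llama-graph.h` / `llama-graph.cpp` hunks change the recurrent-specific builders from pure virtual (`= 0`) to plain virtuals with stub defaults that log an error and return `nullptr`, so non-recurrent contexts no longer need to override them. A minimal sketch of that pattern is shown here, assuming a placeholder `LOG_ERROR` macro in place of `LLAMA_LOG_ERROR` and illustrative class names:

```cpp
#include <cstdio>

struct ggml_tensor;   // opaque, as in the real headers
struct ggml_context;

#define LOG_ERROR(fmt, ...) std::fprintf(stderr, fmt, __VA_ARGS__)

struct graph_i {
    virtual ~graph_i() = default;

    // previously `= 0`; now a defaulted stub, so only contexts that actually
    // support recurrent layers need to provide an implementation
    virtual ggml_tensor * build_inp_s_copy(ggml_context * ctx0, bool worst_case) {
        (void) ctx0; (void) worst_case;
        LOG_ERROR("%s: not implemented\n", __func__);
        return nullptr;
    }
};

struct graph_recurrent : graph_i {
    ggml_tensor * build_inp_s_copy(ggml_context * ctx0, bool worst_case) override {
        // a real implementation would create the I32 [kv_size] copy-indices tensor here
        (void) ctx0; (void) worst_case;
        return nullptr; // placeholder
    }
};
```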