context : add llama_context_recurrent

ggml-ci
Georgi Gerganov
2025-02-19 14:56:01 +02:00
parent 5f11a5502a
commit e17e4b72d1
5 changed files with 266 additions and 83 deletions


@@ -433,15 +433,28 @@ public:
            int32_t n_tokens,
            bool worst_case) override;

    // === recurrent ===

protected:
    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i  & io) override;

    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
    struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id) override;
};

// TODO: add recurrent cache
// TODO: add mamba-specific llama_context

// a recurrent transformer (i.e. RWKV, Mamba)
// TODO: temporarily reuse kv_self; in the future, implement a recurrent-specific context with its own cache
class llama_context_recurrent : public llama_context_kv_self {
public:
    llama_context_recurrent(
            const llama_model & model,
            const llama_context_params & params);

    virtual ~llama_context_recurrent();

    virtual ggml_cgraph * graph_init() override;

    virtual void input_set(const llama_ubatch & ubatch) override;

    // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
    virtual ggml_tensor * build_inp_s_copy(
            ggml_context * ctx0,
            bool worst_case) override;
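
For orientation, a minimal sketch of what the two build_inp_s_* overrides could look like in llama-context.cpp. This is an illustration based on how the pre-refactor llama.cpp built these inputs, not the commit's actual implementation; the inherited kv_self member and its size/n fields are assumptions here.

    // Sketch (not the commit's code): plausible definitions of the two
    // graph-input builders for the recurrent state tensors.
    ggml_tensor * llama_context_recurrent::build_inp_s_copy(
            ggml_context * ctx0,
            bool worst_case) {
        // during the worst-case (scheduling) pass, reserve for the full cache
        const int64_t n_kv = worst_case ? kv_self.size : kv_self.n;

        // one source cell index per KV cell: where each state is copied from
        inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
        ggml_set_input(inp_s_copy);

        return inp_s_copy;
    }

    ggml_tensor * llama_context_recurrent::build_inp_s_mask(
            ggml_context * ctx0,
            bool worst_case) {
        const int64_t n_kv = worst_case ? kv_self.size : kv_self.n;

        // 0.0f clears a state, 1.0f keeps it
        inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
        ggml_set_input(inp_s_mask);

        return inp_s_mask;
    }

Per the TODO in the header, these are expected to be folded into a single build_mamba_inp later, with state_copy and state_mask hidden inside the llama_context implementation.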
@@ -499,11 +512,10 @@ public:
            bool worst_case) override;

protected:
    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i  & io) override;

    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
    struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id) override;

    // TODO: add recurrent cache
};
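
This hunk moves the inp_s_copy/inp_s_mask members into the new class; input_set is what fills them each ubatch. A rough sketch of that override follows; the cache fields used (kv_self.head, kv_self.n, cells[i].src) mirror the pre-refactor llama.cpp recurrent-cache code and are assumptions, not this commit's code.

    // Sketch: populate the recurrent input tensors after graph_init.
    void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
        // let the base class set up the regular kv_self inputs first
        llama_context_kv_self::input_set(ubatch);

        const int64_t n_kv = kv_self.n;

        if (inp_s_mask) {
            // writing through ->data requires a host buffer
            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
            float * data = (float *) inp_s_mask->data;

            // clear states of cells that were freed since the last ubatch
            for (int64_t i = 0; i < n_kv; ++i) {
                const int64_t cell_id = i + kv_self.head;
                auto & cell = kv_self.cells[cell_id];

                data[i] = (float) (cell.src >= 0);

                // only clear once: from now on the cell is its own source
                if (cell.src < 0) {
                    cell.src = (int32_t) cell_id;
                }
            }
        }

        if (inp_s_copy) {
            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
            int32_t * data = (int32_t *) inp_s_copy->data;

            // each cell records which cell its state is copied from
            for (int64_t i = 0; i < n_kv; ++i) {
                const int64_t cell_id = i + kv_self.head;
                auto & cell = kv_self.cells[cell_id];

                // after the clearing above, src is always a valid cell
                data[i] = cell.src;

                // ensure the copy only happens once per graph
                cell.src = (int32_t) cell_id;
            }
        }
    }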
// For internal test use