mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-19 00:57:41 +00:00
context : add llama_context_recurrent
ggml-ci
This commit is contained in:
@ -433,15 +433,28 @@ public:
|
||||
int32_t n_tokens,
|
||||
bool worst_case) override;
|
||||
|
||||
// === recurrent ===
|
||||
protected:
|
||||
virtual size_t state_get_data(llama_io_write_i & io) override;
|
||||
virtual size_t state_set_data(llama_io_read_i & io) override;
|
||||
|
||||
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
||||
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
};
|
||||
|
||||
// TODO: add recurrent cache
|
||||
// TODO: add mamba-specific llama_context
|
||||
// a recurrent transformer (ie.e RWKV, Mamba)
|
||||
// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
|
||||
class llama_context_recurrent : public llama_context_kv_self {
|
||||
public:
|
||||
llama_context_recurrent(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params);
|
||||
|
||||
virtual ~llama_context_recurrent();
|
||||
|
||||
virtual ggml_cgraph * graph_init() override;
|
||||
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
// TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
|
||||
virtual ggml_tensor * build_inp_s_copy(
|
||||
ggml_context * ctx0,
|
||||
bool worst_case) override;
|
||||
@ -499,11 +512,10 @@ public:
|
||||
bool worst_case) override;
|
||||
|
||||
protected:
|
||||
virtual size_t state_get_data(llama_io_write_i & io) override;
|
||||
virtual size_t state_set_data(llama_io_read_i & io) override;
|
||||
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
||||
|
||||
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
// TODO: add recurrent cache
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
|
Reference in New Issue
Block a user