Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <functional>
#include <memory>
#include <vector>

//
// llama_kv_cache_hybrid_recurrent
//

// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
// support models where each layer may be either attention-based or recurrent

class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
public:

    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    llama_kv_cache_hybrid_recurrent(
            const llama_model & model,
            /* attn */
            ggml_type       attn_type_k,
            ggml_type       attn_type_v,
            bool            attn_v_trans,
            uint32_t        attn_kv_size,
            uint32_t        attn_n_pad,
            uint32_t        attn_n_swa,
            llama_swa_type  attn_swa_type,
            /* recurrent */
            ggml_type       recurrent_type_k,
            ggml_type       recurrent_type_v,
            uint32_t        recurrent_kv_size,
            /* common */
            uint32_t        n_seq_max,
            bool            offload,
            /* layer filters */
            layer_filter_cb && attn_filter      = nullptr,
            layer_filter_cb && recurrent_filter = nullptr);

    ~llama_kv_cache_hybrid_recurrent() = default;

    //
    // llama_memory_i
    //
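
    // split the input batch into a set of ubatches and check that they can fit
    // into both of the child caches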
    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled) override;
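
    // simulate a full cache, used for reserving worst-case compute buffers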
    llama_memory_state_ptr init_full() override;
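
    // prepare a state for applying any pending updates to the child caches
    // (e.g. shift or defrag)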
    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_hybrid_recurrent specific API
    //
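
    // direct access to the child caches owned by the hybrid cache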
    llama_kv_cache_unified   * get_kv_attn     () const;
    llama_kv_cache_recurrent * get_kv_recurrent() const;

private:
    const llama_hparams & hparams;

    const std::unique_ptr<llama_kv_cache_unified>   kv_attn;
    const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
};
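
//
// llama_kv_cache_hybrid_recurrent_state
//

// wraps one state object from each child cache and iterates over the shared
// list of ubatches produced when the batch was split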

class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
public:
    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;

    // init failure
    explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);

    // init full
    explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);

    // init update
    explicit llama_kv_cache_hybrid_recurrent_state(
            llama_kv_cache_unified_state   * state_unified,
            llama_kv_cache_recurrent_state * state_recurrent);

    // init success
    llama_kv_cache_hybrid_recurrent_state(
            llama_kv_cache_hybrid_recurrent * kv,
            llama_sbatch sbatch,
            std::vector<uint32_t> heads_attn,
            std::vector<llama_ubatch> ubatches);

    ~llama_kv_cache_hybrid_recurrent_state() = default;
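
    // next() advances to the next ubatch, returning false when none are left;
    // apply() commits the memory changes for the current ubatch to both child states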
    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_hybrid_recurrent_state
    //
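
    // per-child views of the current state, one for the attention layers and
    // one for the recurrent layers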
    const llama_kv_cache_unified_state   * get_state_attn     () const;
    const llama_kv_cache_recurrent_state * get_state_recurrent() const;

private:
    const llama_memory_status status;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_state_ptr state_attn;
    const llama_memory_state_ptr state_recurrent;
};
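
For reference, here is a rough sketch of how a caller might drive the state returned by init_batch through the llama_memory_state_i methods declared above. It is illustrative only and simplified from what llama_context actually does; the `compute` callback is a hypothetical stand-in for building and running the graph of one ubatch.

// illustrative sketch only -- not the actual llama_context implementation
// assumes: a constructed hybrid cache `mem`, an input `batch`, and a
// hypothetical `compute` callback that processes a single ubatch
static bool decode_with_hybrid_cache(
        llama_kv_cache_hybrid_recurrent & mem,
        const llama_batch & batch,
        uint32_t n_ubatch,
        const std::function<void(const llama_ubatch &)> & compute) {
    // split the batch into ubatches and verify they fit into both child caches
    llama_memory_state_ptr state = mem.init_batch(batch, n_ubatch, /*embd_pooled =*/ false);
    if (!state || state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;
    }

    do {
        // commit the memory changes required for the current ubatch
        if (!state->apply()) {
            return false;
        }

        compute(state->get_ubatch());
    } while (state->next()); // proceed until all ubatches are consumed

    return true;
}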