mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-19 09:08:04 +00:00)

context : add llama_context_recurrent
ggml-ci
@@ -20,6 +20,8 @@ llama_context::llama_context(
    model (model),
    t_start_us(model.t_start_us),
    t_load_us (model.t_load_us) {
    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

    const auto & hparams = model.hparams;

    cparams.n_seq_max = std::max(1u, params.n_seq_max);

@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
        const llama_context_params & params) :
    llama_context(model, params),
    kv_self(model.hparams) {
    LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);

    const auto & hparams = model.hparams;

    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
    inp_KQ_mask_swa_cnv = nullptr;
    inp_KQ_mask_cross = nullptr;
    inp_k_shift = nullptr;
    inp_s_copy = nullptr;
    inp_s_mask = nullptr;
    inp_embd_enc = nullptr;
    inp_pos_bucket = nullptr;

@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
        }
    }

    if (kv_self.recurrent) {
        const int64_t n_kv = kv_self.n;

        if (inp_s_mask) {
            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
            float * data = (float *) inp_s_mask->data;

            // clear unused states
            for (int i = 0; i < n_kv; ++i) {
                const uint32_t cell_id = i + kv_self.head;
                llama_kv_cell & kv_cell = kv_self.cells[cell_id];

                data[i] = (float) (kv_cell.src >= 0);

                // TODO: do not mutate the KV cache
                // only clear once
                if (kv_cell.src < 0) {
                    kv_cell.src = cell_id;
                }
            }
        }

        if (inp_s_copy) {
            GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
            int32_t * data = (int32_t *) inp_s_copy->data;

            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
            for (uint32_t i = 0; i < n_kv; ++i) {
                const uint32_t cell_id = i + kv_self.head;
                llama_kv_cell & kv_cell = kv_self.cells[cell_id];

                // prevent out-of-bound sources
                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
                    kv_cell.src = cell_id;
                }

                data[i] = kv_cell.src;

                // TODO: do not mutate the KV cache
                // ensure copy only happens once
                if (kv_cell.src != (int32_t) cell_id) {
                    kv_cell.src = cell_id;
                }
            }
        }
    }

    if (inp_pos_bucket) {
        const int64_t n_tokens = ubatch.n_tokens;

@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(

void llama_context_kv_self::build_attn_kv_store(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_cgraph * gf,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        int32_t n_tokens,

@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
    //cb(k_cache_view, "k_cache_view", il);

    // note: storing RoPE-ed version of K in the KV cache
    ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store(
    }
    //cb(v_cache_view, "v_cache_view", il);

    ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}

ggml_tensor * llama_context_kv_self::build_attn_qkv(
        ggml_context * ctx0,
        ggml_cgraph * graph,
        ggml_cgraph * gf,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,

@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
        }
    }

    ggml_build_forward_expand(graph, cur);
    ggml_build_forward_expand(gf, cur);

    if (wo) {
        cur = build_lora_mm(ctx0, wo, cur);

@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
    return inp_KQ_mask_cross;
}

ggml_tensor * llama_context_kv_self::build_inp_s_copy(
//
// llama_context_recurrent
//

llama_context_recurrent::llama_context_recurrent(
        const llama_model & model,
        const llama_context_params & params) :
    llama_context_kv_self(model, params) {
    LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
}

llama_context_recurrent::~llama_context_recurrent() = default;

ggml_cgraph * llama_context_recurrent::graph_init() {
    inp_s_copy = nullptr;
    inp_s_mask = nullptr;

    return llama_context_kv_self::graph_init();
}

void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
    // call base functionality
    llama_context_kv_self::input_set(ubatch);

    GGML_ASSERT(kv_self.recurrent);

    const int64_t n_kv = kv_self.n;

    if (inp_s_mask) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
        float * data = (float *) inp_s_mask->data;

        // clear unused states
        for (int i = 0; i < n_kv; ++i) {
            const uint32_t cell_id = i + kv_self.head;
            llama_kv_cell & kv_cell = kv_self.cells[cell_id];

            data[i] = (float) (kv_cell.src >= 0);

            // TODO: do not mutate the KV cache
            // only clear once
            if (kv_cell.src < 0) {
                kv_cell.src = cell_id;
            }
        }
    }

    if (inp_s_copy) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
        int32_t * data = (int32_t *) inp_s_copy->data;

        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t cell_id = i + kv_self.head;
            llama_kv_cell & kv_cell = kv_self.cells[cell_id];

            // prevent out-of-bound sources
            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
                kv_cell.src = cell_id;
            }

            data[i] = kv_cell.src;

            // TODO: do not mutate the KV cache
            // ensure copy only happens once
            if (kv_cell.src != (int32_t) cell_id) {
                kv_cell.src = cell_id;
            }
        }
    }
}

ggml_tensor * llama_context_recurrent::build_inp_s_copy(
        ggml_context * ctx0,
        bool worst_case) {
    const auto n_kv = worst_case ? kv_self.size : kv_self.n;

@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
    return inp_s_copy;
}

ggml_tensor * llama_context_kv_self::build_inp_s_mask(
ggml_tensor * llama_context_recurrent::build_inp_s_mask(
        ggml_context * ctx0,
        bool worst_case) {
    const auto n_kv = worst_case ? kv_self.size : kv_self.n;

@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
    return inp_s_mask;
}

ggml_tensor * llama_context_kv_self::build_copy_mask_state(
ggml_tensor * llama_context_recurrent::build_copy_mask_state(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * s,

@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
}

// TODO: split
ggml_tensor * llama_context_kv_self::build_mamba_layer(
ggml_tensor * llama_context_recurrent::build_mamba_layer(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * cur,

@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
}


ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * state_copy,

@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
    return token_shift;
}


ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
        ggml_context * ctx0,
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,

@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
    );
}


ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * cur,
@@ -433,15 +433,28 @@ public:
            int32_t n_tokens,
            bool worst_case) override;

    // === recurrent ===
protected:
    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i & io) override;

    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
    struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};

// TODO: add recurrent cache
// TODO: add mamba-specific llama_context
// a recurrent transformer (i.e. RWKV, Mamba)
// TODO: reuse kv_self for now, but in the future implement a recurrent-specific context with its own cache
class llama_context_recurrent : public llama_context_kv_self {
public:
    llama_context_recurrent(
            const llama_model & model,
            const llama_context_params & params);

    virtual ~llama_context_recurrent();

    virtual ggml_cgraph * graph_init() override;

    virtual void input_set(const llama_ubatch & ubatch) override;

    // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
    virtual ggml_tensor * build_inp_s_copy(
            ggml_context * ctx0,
            bool worst_case) override;

@@ -499,11 +512,10 @@ public:
            bool worst_case) override;

protected:
    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i & io) override;
    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
    struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
    // TODO: add recurrent cache
};

// For internal test use
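For orientation, here is a minimal standalone sketch of the layering this header introduces, using simplified stand-in types rather than the real llama.cpp classes: the recurrent context derives from the KV-self context and overrides graph_init() and input_set(), handling its recurrent-specific inputs around a call to the base implementation. Names and printouts below are illustrative only.

// illustrative sketch, not the actual llama.cpp classes
#include <cstdio>

struct context_base {
    virtual ~context_base() = default;

    virtual void graph_init() {
        std::printf("base: reset common graph inputs\n");
    }
    virtual void input_set() {
        std::printf("base: set common inputs (tokens, positions, KQ mask)\n");
    }
};

struct context_recurrent : context_base {
    void graph_init() override {
        // reset the recurrent-only inputs first (mirrors inp_s_copy/inp_s_mask = nullptr)
        std::printf("recurrent: reset s_copy/s_mask inputs\n");
        context_base::graph_init();   // then run the base initialization
    }
    void input_set() override {
        context_base::input_set();    // base functionality first
        std::printf("recurrent: fill s_mask/s_copy from the cache cells\n");
    }
};

int main() {
    context_recurrent ctx;
    ctx.graph_init();
    ctx.input_set();
}

The real llama_context_recurrent follows the same shape, with inp_s_copy/inp_s_mask and the kv_self cells taking the place of the printouts.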
@@ -1 +1,136 @@
#include "llama-graph.h"

#include "llama-impl.h"

ggml_tensor * llama_graph_i::build_inp_s_copy (
        ggml_context * ctx0,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_inp_s_mask(
        ggml_context * ctx0,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_copy_mask_state(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * s,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        int32_t n_tokens,
        int32_t n_state,
        int32_t n_seqs,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(gf);
    GGML_UNUSED(s);
    GGML_UNUSED(state_copy);
    GGML_UNUSED(state_mask);
    GGML_UNUSED(n_tokens);
    GGML_UNUSED(n_state);
    GGML_UNUSED(n_seqs);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_mamba_layer(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * cur,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(gf);
    GGML_UNUSED(cur);
    GGML_UNUSED(state_copy);
    GGML_UNUSED(state_mask);
    GGML_UNUSED(ubatch);
    GGML_UNUSED(il);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_rwkv_token_shift_load(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(gf);
    GGML_UNUSED(state_copy);
    GGML_UNUSED(state_mask);
    GGML_UNUSED(ubatch);
    GGML_UNUSED(il);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_rwkv_token_shift_store(
        ggml_context * ctx0,
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(token_shift);
    GGML_UNUSED(ubatch);
    GGML_UNUSED(il);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}

ggml_tensor * llama_graph_i::build_rwkv6_time_mix(
        ggml_context * ctx0,
        ggml_cgraph * gf,
        ggml_tensor * cur,
        ggml_tensor * x_prev,
        ggml_tensor * state_copy,
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il,
        bool worst_case) {
    GGML_UNUSED(ctx0);
    GGML_UNUSED(gf);
    GGML_UNUSED(cur);
    GGML_UNUSED(x_prev);
    GGML_UNUSED(state_copy);
    GGML_UNUSED(state_mask);
    GGML_UNUSED(ubatch);
    GGML_UNUSED(il);
    GGML_UNUSED(worst_case);

    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);

    return nullptr; // NOLINT
}
@@ -55,7 +55,7 @@ public:
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            ggml_backend_buffer * bbuft) = 0;
            ggml_backend_buffer * bbuf) = 0;

    // graph build API (context-specific)

@@ -137,11 +137,11 @@ public:

    virtual ggml_tensor * build_inp_s_copy(
            ggml_context * ctx0,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_inp_s_mask(
            ggml_context * ctx0,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_copy_mask_state(
            ggml_context * ctx0,

@@ -152,7 +152,7 @@ public:
            int32_t n_tokens,
            int32_t n_state,
            int32_t n_seqs,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_mamba_layer(
            ggml_context * ctx0,

@@ -162,7 +162,7 @@ public:
            ggml_tensor * state_mask,
            const llama_ubatch & ubatch,
            int il,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_rwkv_token_shift_load(
            ggml_context * ctx0,

@@ -171,14 +171,14 @@ public:
            ggml_tensor * state_mask,
            const llama_ubatch & ubatch,
            int il,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_rwkv_token_shift_store(
            ggml_context * ctx0,
            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il,
            bool worst_case) = 0;
            bool worst_case);

    virtual ggml_tensor * build_rwkv6_time_mix(
            ggml_context * ctx0,

@@ -189,5 +189,5 @@ public:
            ggml_tensor * state_mask,
            const llama_ubatch & ubatch,
            int il,
            bool worst_case) = 0;
            bool worst_case);
};
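The change from `bool worst_case) = 0;` to `bool worst_case);` turns these graph-build hooks from pure virtual into plain virtual functions whose default bodies (added above in llama_graph_i) log "not implemented" and return nullptr, so only the contexts that actually need a given hook override it. A minimal sketch of the same pattern with illustrative names, not the real llama_graph_i interface:

// illustrative sketch of virtual-with-default vs. pure virtual
#include <cstdio>

struct graph_iface {
    virtual ~graph_iface() = default;

    // default: not implemented, callers get nullptr and an error log
    virtual int * build_inp_s_copy() {
        std::fprintf(stderr, "%s: not implemented\n", __func__);
        return nullptr;
    }
};

struct graph_recurrent : graph_iface {
    int value = 42;

    // only the recurrent implementation provides a real input here
    int * build_inp_s_copy() override { return &value; }
};

int main() {
    graph_iface     base;
    graph_recurrent rec;

    std::printf("base: %p\n", (void *) base.build_inp_s_copy()); // nullptr, plus error log
    std::printf("rec : %d\n", *rec.build_inp_s_copy());          // 42
}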
@@ -326,8 +326,19 @@ struct llama_context * llama_init_from_model(
    llama_context * ctx = nullptr;

    try {
        // TODO: add logic to decide which llama_context implementation to construct
        ctx = new llama_context_kv_self(*model, params);
        // TODO: make this a static method of llama_context
        switch (model->arch) {
            case LLM_ARCH_RWKV6:
            case LLM_ARCH_RWKV6QWEN2:
            case LLM_ARCH_MAMBA:
                GGML_ASSERT(llama_model_is_recurrent(model));
                ctx = new llama_context_recurrent(*model, params);
                break;
            default:
                GGML_ASSERT(!llama_model_is_recurrent(model));
                ctx = new llama_context_kv_self(*model, params);
        };

        ctx->init();
    } catch (const std::exception & e) {
        LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
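From the caller's point of view this dispatch is transparent: llama_init_from_model() still returns a plain llama_context pointer, and the concrete subclass is chosen internally from the model architecture. A hedged usage sketch, assuming the current public entry points (llama_model_load_from_file, llama_init_from_model) and a placeholder model path:

#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (!model) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();

    // for a Mamba/RWKV6 model this now constructs a llama_context_recurrent internally,
    // for other architectures a llama_context_kv_self
    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        std::fprintf(stderr, "failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}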