mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i ggml-ci * cont : fix comments * cont : use "mctx" for referencing a memory context ggml-ci
This commit is contained in:
@ -280,8 +280,8 @@ llama_context::llama_context(
|
|||||||
|
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
|
|
||||||
const auto mstate = memory->init_full();
|
const auto mctx = memory->init_full();
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
throw std::runtime_error("failed to initialize KV cache");
|
throw std::runtime_error("failed to initialize KV cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,7 +289,7 @@ llama_context::llama_context(
|
|||||||
|
|
||||||
// reserve pp graph first so that buffers are only allocated once
|
// reserve pp graph first so that buffers are only allocated once
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
@ -300,7 +300,7 @@ llama_context::llama_context(
|
|||||||
|
|
||||||
// reserve with tg graph to get the number of splits and nodes
|
// reserve with tg graph to get the number of splits and nodes
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||||
}
|
}
|
||||||
@ -311,7 +311,7 @@ llama_context::llama_context(
|
|||||||
|
|
||||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||||
{
|
{
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||||
}
|
}
|
||||||
@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
|
|||||||
optimize |= memory_force_optimize;
|
optimize |= memory_force_optimize;
|
||||||
memory_force_optimize = false;
|
memory_force_optimize = false;
|
||||||
|
|
||||||
const auto mstate = memory->init_update(this, optimize);
|
const auto mctx = memory->init_update(this, optimize);
|
||||||
switch (mstate->get_status()) {
|
switch (mctx->get_status()) {
|
||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
// noop
|
// noop
|
||||||
@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!mstate->apply()) {
|
if (!mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the memory module did any computation, we have to reserve a new worst-case graph
|
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||||
{
|
{
|
||||||
const auto mstate = memory->init_full();
|
const auto mctx = memory->init_full();
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
throw std::runtime_error("failed to initialize memory state");
|
throw std::runtime_error("failed to initialize memory context");
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint32_t n_seqs = cparams.n_seq_max;
|
const uint32_t n_seqs = cparams.n_seq_max;
|
||||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||||
}
|
}
|
||||||
@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
|
|||||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||||
if (mstate && !mstate->apply()) {
|
if (mctx && !mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||||
ret = GGML_STATUS_FAILED;
|
ret = GGML_STATUS_FAILED;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
|
||||||
if (!res) {
|
if (!res) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||||
ret = GGML_STATUS_FAILED;
|
ret = GGML_STATUS_FAILED;
|
||||||
@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
// handle any pending defrags/shifts
|
// handle any pending defrags/shifts
|
||||||
kv_self_update(false);
|
kv_self_update(false);
|
||||||
|
|
||||||
llama_memory_state_ptr mstate;
|
llama_memory_context_ptr mctx;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||||
if (!mstate) {
|
if (!mctx) {
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (mstate->get_status()) {
|
switch (mctx->get_status()) {
|
||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
{
|
{
|
||||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
|
||||||
|
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
int64_t n_outputs_prev = 0;
|
int64_t n_outputs_prev = 0;
|
||||||
|
|
||||||
do {
|
do {
|
||||||
const auto & ubatch = mstate->get_ubatch();
|
const auto & ubatch = mctx->get_ubatch();
|
||||||
|
|
||||||
// count the outputs in this ubatch
|
// count the outputs in this ubatch
|
||||||
{
|
{
|
||||||
@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
|
|
||||||
ggml_status status;
|
ggml_status status;
|
||||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||||
@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
n_outputs_prev += n_outputs;
|
n_outputs_prev += n_outputs;
|
||||||
} while (mstate->next());
|
} while (mctx->next());
|
||||||
|
|
||||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||||
n_outputs = n_outputs_all;
|
n_outputs = n_outputs_all;
|
||||||
@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
|
|||||||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
||||||
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||||
|
|
||||||
if (n_tokens % n_seqs != 0) {
|
if (n_tokens % n_seqs != 0) {
|
||||||
@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|||||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
||||||
|
|
||||||
this->n_outputs = save_n_outputs;
|
this->n_outputs = save_n_outputs;
|
||||||
|
|
||||||
@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_result_ptr llama_context::graph_build(
|
llm_graph_result_ptr llama_context::graph_build(
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
const llama_memory_state_i * mstate) {
|
const llama_memory_context_i * mctx) {
|
||||||
return model.build_graph(
|
return model.build_graph(
|
||||||
{
|
{
|
||||||
/*.ctx =*/ ctx,
|
/*.ctx =*/ ctx,
|
||||||
@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|||||||
/*.backend_cpu =*/ backend_cpu,
|
/*.backend_cpu =*/ backend_cpu,
|
||||||
/*.cvec =*/ &cvec,
|
/*.cvec =*/ &cvec,
|
||||||
/*.loras =*/ &loras,
|
/*.loras =*/ &loras,
|
||||||
/*.mstate =*/ mstate,
|
/*.mctx =*/ mctx,
|
||||||
/*.cross =*/ &cross,
|
/*.cross =*/ &cross,
|
||||||
/*.n_outputs =*/ n_outputs,
|
/*.n_outputs =*/ n_outputs,
|
||||||
/*.cb =*/ graph_get_cb(),
|
/*.cb =*/ graph_get_cb(),
|
||||||
@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
|
|||||||
|
|
||||||
uint32_t n_outputs_all = n_tokens_all;
|
uint32_t n_outputs_all = n_tokens_all;
|
||||||
|
|
||||||
auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
||||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
|
|||||||
|
|
||||||
uint32_t pos_batch = 0;
|
uint32_t pos_batch = 0;
|
||||||
do {
|
do {
|
||||||
const auto & ubatch = mstate->get_ubatch();
|
const auto & ubatch = mctx->get_ubatch();
|
||||||
|
|
||||||
n_outputs = ubatch.n_tokens;
|
n_outputs = ubatch.n_tokens;
|
||||||
|
|
||||||
if (!mstate->apply()) {
|
if (!mctx->apply()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
|
||||||
|
|
||||||
struct ggml_context * ctx_compute_opt;
|
struct ggml_context * ctx_compute_opt;
|
||||||
{
|
{
|
||||||
@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
|
|||||||
ggml_free(ctx_compute_opt);
|
ggml_free(ctx_compute_opt);
|
||||||
|
|
||||||
pos_batch += ubatch.n_tokens;
|
pos_batch += ubatch.n_tokens;
|
||||||
} while (mstate->next());
|
} while (mctx->next());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class llama_io_read_i;
|
|||||||
class llama_io_write_i;
|
class llama_io_write_i;
|
||||||
|
|
||||||
struct llama_memory_i;
|
struct llama_memory_i;
|
||||||
struct llama_memory_state_i;
|
struct llama_memory_context_i;
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context {
|
||||||
// init scheduler and compute buffers, reserve worst-case graphs
|
// init scheduler and compute buffers, reserve worst-case graphs
|
||||||
@ -93,14 +93,14 @@ struct llama_context {
|
|||||||
int32_t il_end);
|
int32_t il_end);
|
||||||
|
|
||||||
// process a single ubatch with a specific graph type
|
// process a single ubatch with a specific graph type
|
||||||
// if memory_state is provided, it will be applied first to the context's memory
|
// if memory_context is provided, it will be applied first to the context's memory
|
||||||
// ret contains the status of the graph computation
|
// ret contains the status of the graph computation
|
||||||
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||||
llm_graph_result_ptr process_ubatch(
|
llm_graph_result_ptr process_ubatch(
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
llama_memory_state_i * mstate,
|
llama_memory_context_i * mctx,
|
||||||
ggml_status & ret);
|
ggml_status & ret);
|
||||||
|
|
||||||
int encode(const llama_batch & batch_inp);
|
int encode(const llama_batch & batch_inp);
|
||||||
int decode(const llama_batch & batch_inp);
|
int decode(const llama_batch & batch_inp);
|
||||||
@ -197,15 +197,15 @@ public:
|
|||||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||||
|
|
||||||
// reserve a graph with a dummy ubatch of the specified size
|
// reserve a graph with a dummy ubatch of the specified size
|
||||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llm_graph_result_ptr graph_build(
|
llm_graph_result_ptr graph_build(
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype,
|
llm_graph_type gtype,
|
||||||
const llama_memory_state_i * mstate);
|
const llama_memory_context_i * mctx);
|
||||||
|
|
||||||
llm_graph_cb graph_get_cb() const;
|
llm_graph_cb graph_get_cb() const;
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|||||||
|
|
||||||
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
||||||
if (pos_bucket) {
|
if (pos_bucket) {
|
||||||
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
|
mctx->set_input_pos_bucket(pos_bucket, ubatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|||||||
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||||
GGML_UNUSED(ubatch);
|
GGML_UNUSED(ubatch);
|
||||||
|
|
||||||
const int64_t n_rs = mem_state->get_n_rs();
|
const int64_t n_rs = mctx->get_n_rs();
|
||||||
|
|
||||||
if (s_copy) {
|
if (s_copy) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||||
@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|||||||
|
|
||||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||||
data[i] = mem_state->s_copy(i);
|
data[i] = mctx->s_copy(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|||||||
|
|
||||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||||
if (self_kq_mask) {
|
if (self_kq_mask) {
|
||||||
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||||
if (self_kq_mask) {
|
if (self_kq_mask) {
|
||||||
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (self_kq_mask_swa) {
|
if (self_kq_mask_swa) {
|
||||||
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|||||||
|
|
||||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||||
if (self_kq_mask) {
|
if (self_kq_mask) {
|
||||||
mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
|
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||||
|
|
||||||
if (s_copy) {
|
if (s_copy) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||||
@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|||||||
|
|
||||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||||
data[i] = mem_state->get_state_recr()->s_copy(i);
|
data[i] = mctx->get_recr()->s_copy(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|||||||
backend_cpu (params.backend_cpu),
|
backend_cpu (params.backend_cpu),
|
||||||
cvec (params.cvec),
|
cvec (params.cvec),
|
||||||
loras (params.loras),
|
loras (params.loras),
|
||||||
mstate (params.mstate),
|
mctx (params.mctx),
|
||||||
cross (params.cross),
|
cross (params.cross),
|
||||||
cb_func (params.cb),
|
cb_func (params.cb),
|
||||||
res (std::make_unique<llm_graph_result>()) {
|
res (std::make_unique<llm_graph_result>()) {
|
||||||
@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
|
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
||||||
|
|
||||||
const auto n_kv = kv_state->get_n_kv();
|
const auto n_kv = mctx_cur->get_n_kv();
|
||||||
|
|
||||||
auto & cur = inp->pos_bucket;
|
auto & cur = inp->pos_bucket;
|
||||||
|
|
||||||
@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||||
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
|
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
|
||||||
|
|
||||||
{
|
{
|
||||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
||||||
|
|
||||||
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
|
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||||
|
|
||||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||||
@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const auto n_rs = mem_state->get_state_recr()->get_n_rs();
|
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
|
||||||
|
|
||||||
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
||||||
ggml_set_input(inp->s_copy);
|
ggml_set_input(inp->s_copy);
|
||||||
@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
|
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
||||||
|
|
||||||
{
|
{
|
||||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||||
|
|
||||||
const auto n_kv = kv_state->get_n_kv();
|
const auto n_kv = mctx_cur->get_n_kv();
|
||||||
|
|
||||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||||
@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||||||
ggml_build_forward_expand(gf, k_cur);
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
ggml_build_forward_expand(gf, v_cur);
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||||
|
|
||||||
// store to KV cache
|
// store to KV cache
|
||||||
{
|
{
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & kq_mask = inp->get_kq_mask();
|
const auto & kq_mask = inp->get_kq_mask();
|
||||||
|
|
||||||
ggml_tensor * q = q_cur;
|
ggml_tensor * q = q_cur;
|
||||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||||||
ggml_build_forward_expand(gf, k_cur);
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
ggml_build_forward_expand(gf, v_cur);
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||||
|
|
||||||
const bool is_swa = hparams.is_swa(il);
|
const bool is_swa = hparams.is_swa(il);
|
||||||
|
|
||||||
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
|
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||||
|
|
||||||
// store to KV cache
|
// store to KV cache
|
||||||
{
|
{
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||||
|
|
||||||
ggml_tensor * q = q_cur;
|
ggml_tensor * q = q_cur;
|
||||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||||||
ggml_build_forward_expand(gf, k_cur);
|
ggml_build_forward_expand(gf, k_cur);
|
||||||
ggml_build_forward_expand(gf, v_cur);
|
ggml_build_forward_expand(gf, v_cur);
|
||||||
|
|
||||||
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
|
||||||
|
|
||||||
// store to KV cache
|
// store to KV cache
|
||||||
{
|
{
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto & kq_mask = inp->get_kq_mask();
|
const auto & kq_mask = inp->get_kq_mask();
|
||||||
|
|
||||||
ggml_tensor * q = q_cur;
|
ggml_tensor * q = q_cur;
|
||||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||||
|
|
||||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||||
cb(cur, "kqv_out", il);
|
cb(cur, "kqv_out", il);
|
||||||
@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||||
|
|
||||||
{
|
{
|
||||||
const auto n_kv = kv_state->get_base()->get_n_kv();
|
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||||
|
|
||||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||||
@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||||
|
|
||||||
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||||
|
|
||||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||||
@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
|
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
||||||
|
|
||||||
const auto n_rs = kv_state->get_n_rs();
|
const auto n_rs = mctx_cur->get_n_rs();
|
||||||
|
|
||||||
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
||||||
ggml_set_input(inp->s_copy);
|
ggml_set_input(inp->s_copy);
|
||||||
@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|||||||
int32_t state_size,
|
int32_t state_size,
|
||||||
int32_t n_seqs,
|
int32_t n_seqs,
|
||||||
bool avoid_copies) const {
|
bool avoid_copies) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
|
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_rs(
|
ggml_tensor * llm_graph_context::build_rs(
|
||||||
@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|||||||
int32_t state_size,
|
int32_t state_size,
|
||||||
int32_t n_seqs,
|
int32_t n_seqs,
|
||||||
bool avoid_copies) const {
|
bool avoid_copies) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
|
||||||
|
|
||||||
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
|
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||||
@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
const auto token_shift_count = hparams.token_shift_count;
|
const auto token_shift_count = hparams.token_shift_count;
|
||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
ggml_tensor * token_shift_all = kv_state->get_r_l(il);
|
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
||||||
|
|
||||||
ggml_tensor * token_shift = build_rs(
|
ggml_tensor * token_shift = build_rs(
|
||||||
inp, gf, token_shift_all,
|
inp, gf, token_shift_all,
|
||||||
@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|||||||
ggml_tensor * token_shift,
|
ggml_tensor * token_shift,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
const auto token_shift_count = hparams.token_shift_count;
|
const auto token_shift_count = hparams.token_shift_count;
|
||||||
const auto n_embd = hparams.n_embd;
|
const auto n_embd = hparams.n_embd;
|
||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
const auto kv_head = kv_state->get_head();
|
const auto kv_head = mctx_cur->get_head();
|
||||||
|
|
||||||
return ggml_cpy(
|
return ggml_cpy(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
||||||
ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
|
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,12 +17,12 @@ struct ggml_tensor;
|
|||||||
struct llama_ubatch;
|
struct llama_ubatch;
|
||||||
struct llama_cparams;
|
struct llama_cparams;
|
||||||
|
|
||||||
struct llama_memory_state_i;
|
struct llama_memory_context_i;
|
||||||
|
|
||||||
class llama_kv_cache_unified_state;
|
class llama_kv_cache_unified_context;
|
||||||
class llama_kv_cache_unified_iswa_state;
|
class llama_kv_cache_unified_iswa_context;
|
||||||
class llama_memory_recurrent_state;
|
class llama_memory_recurrent_context;
|
||||||
class llama_memory_hybrid_state;
|
class llama_memory_hybrid_context;
|
||||||
|
|
||||||
// certain models (typically multi-modal) can produce different types of graphs
|
// certain models (typically multi-modal) can produce different types of graphs
|
||||||
enum llm_graph_type {
|
enum llm_graph_type {
|
||||||
@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
|||||||
public:
|
public:
|
||||||
llm_graph_input_pos_bucket_kv(
|
llm_graph_input_pos_bucket_kv(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
|
||||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
@ -144,7 +144,8 @@ public:
|
|||||||
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||||
|
|
||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_kv_cache_unified_state * kv_state;
|
|
||||||
|
const llama_kv_cache_unified_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||||
@ -191,14 +192,14 @@ public:
|
|||||||
|
|
||||||
class llm_graph_input_rs : public llm_graph_input_i {
|
class llm_graph_input_rs : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
|
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
|
||||||
virtual ~llm_graph_input_rs() = default;
|
virtual ~llm_graph_input_rs() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
ggml_tensor * s_copy; // I32 [kv_size]
|
ggml_tensor * s_copy; // I32 [kv_size]
|
||||||
|
|
||||||
const llama_memory_recurrent_state * mem_state;
|
const llama_memory_recurrent_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||||
@ -238,10 +239,10 @@ public:
|
|||||||
llm_graph_input_attn_kv_unified(
|
llm_graph_input_attn_kv_unified(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache_unified_state * kv_state) :
|
const llama_kv_cache_unified_context * mctx) :
|
||||||
hparams(hparams),
|
hparams(hparams),
|
||||||
cparams(cparams),
|
cparams(cparams),
|
||||||
kv_state(kv_state) {
|
mctx(mctx) {
|
||||||
}
|
}
|
||||||
~llm_graph_input_attn_kv_unified() = default;
|
~llm_graph_input_attn_kv_unified() = default;
|
||||||
|
|
||||||
@ -255,7 +256,7 @@ public:
|
|||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_cparams & cparams;
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * kv_state;
|
const llama_kv_cache_unified_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||||
@ -263,10 +264,10 @@ public:
|
|||||||
llm_graph_input_attn_kv_unified_iswa(
|
llm_graph_input_attn_kv_unified_iswa(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
const llama_kv_cache_unified_iswa_state * kv_state) :
|
const llama_kv_cache_unified_iswa_context * mctx) :
|
||||||
hparams(hparams),
|
hparams(hparams),
|
||||||
cparams(cparams),
|
cparams(cparams),
|
||||||
kv_state(kv_state) {
|
mctx(mctx) {
|
||||||
}
|
}
|
||||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||||
|
|
||||||
@ -283,7 +284,7 @@ public:
|
|||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_cparams & cparams;
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
const llama_kv_cache_unified_iswa_state * kv_state;
|
const llama_kv_cache_unified_iswa_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||||
@ -306,10 +307,10 @@ public:
|
|||||||
llm_graph_input_mem_hybrid(
|
llm_graph_input_mem_hybrid(
|
||||||
const llama_hparams & hparams,
|
const llama_hparams & hparams,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
const llama_memory_hybrid_state * mem_state) :
|
const llama_memory_hybrid_context * mctx) :
|
||||||
hparams(hparams),
|
hparams(hparams),
|
||||||
cparams(cparams),
|
cparams(cparams),
|
||||||
mem_state(mem_state) {
|
mctx(mctx) {
|
||||||
}
|
}
|
||||||
virtual ~llm_graph_input_mem_hybrid() = default;
|
virtual ~llm_graph_input_mem_hybrid() = default;
|
||||||
|
|
||||||
@ -325,7 +326,7 @@ public:
|
|||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
const llama_cparams & cparams;
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
const llama_memory_hybrid_state * mem_state;
|
const llama_memory_hybrid_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -401,10 +402,10 @@ struct llm_graph_params {
|
|||||||
ggml_backend_sched_t sched;
|
ggml_backend_sched_t sched;
|
||||||
ggml_backend_t backend_cpu;
|
ggml_backend_t backend_cpu;
|
||||||
|
|
||||||
const llama_adapter_cvec * cvec;
|
const llama_adapter_cvec * cvec;
|
||||||
const llama_adapter_loras * loras;
|
const llama_adapter_loras * loras;
|
||||||
const llama_memory_state_i * mstate;
|
const llama_memory_context_i * mctx;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
uint32_t n_outputs;
|
uint32_t n_outputs;
|
||||||
|
|
||||||
@ -453,10 +454,10 @@ struct llm_graph_context {
|
|||||||
|
|
||||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||||
|
|
||||||
const llama_adapter_cvec * cvec;
|
const llama_adapter_cvec * cvec;
|
||||||
const llama_adapter_loras * loras;
|
const llama_adapter_loras * loras;
|
||||||
const llama_memory_state_i * mstate;
|
const llama_memory_context_i * mctx;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
const llm_graph_cb & cb_func;
|
const llm_graph_cb & cb_func;
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
return kv_swa->seq_pos_max(seq_id);
|
return kv_swa->seq_pos_max(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||||
GGML_UNUSED(embd_all);
|
GGML_UNUSED(embd_all);
|
||||||
|
|
||||||
// first try simple split
|
// first try simple split
|
||||||
@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
|
|||||||
|
|
||||||
assert(heads_base.size() == heads_swa.size());
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
} while (false);
|
} while (false);
|
||||||
|
|
||||||
@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
|
|||||||
|
|
||||||
assert(heads_base.size() == heads_swa.size());
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
} while (false);
|
} while (false);
|
||||||
|
|
||||||
// TODO: if we fail again, we should attempt different splitting strategies
|
// TODO: if we fail again, we should attempt different splitting strategies
|
||||||
// but to do that properly, we first have to refactor the batches to be more flexible
|
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||||
@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified_iswa_state
|
// llama_kv_cache_unified_iswa_context
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv) :
|
llama_kv_cache_unified_iswa * kv) :
|
||||||
state_base(kv->get_base()->init_full()),
|
ctx_base(kv->get_base()->init_full()),
|
||||||
state_swa (kv->get_swa ()->init_full()),
|
ctx_swa (kv->get_swa ()->init_full()),
|
||||||
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool optimize) :
|
bool optimize) :
|
||||||
state_base(kv->get_base()->init_update(lctx, optimize)),
|
ctx_base(kv->get_base()->init_update(lctx, optimize)),
|
||||||
state_swa (kv->get_swa ()->init_update(lctx, optimize)),
|
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
|
||||||
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
std::vector<uint32_t> heads_base,
|
std::vector<uint32_t> heads_base,
|
||||||
std::vector<uint32_t> heads_swa,
|
std::vector<uint32_t> heads_swa,
|
||||||
std::vector<llama_ubatch> ubatches) :
|
std::vector<llama_ubatch> ubatches) :
|
||||||
ubatches(std::move(ubatches)),
|
ubatches(std::move(ubatches)),
|
||||||
// note: here we copy the ubatches. not sure if this is ideal
|
// note: here we copy the ubatches. not sure if this is ideal
|
||||||
state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
|
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||||
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||||
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
|
||||||
|
|
||||||
bool llama_kv_cache_unified_iswa_state::next() {
|
bool llama_kv_cache_unified_iswa_context::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
state_base->next();
|
ctx_base->next();
|
||||||
state_swa ->next();
|
ctx_swa ->next();
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_next >= ubatches.size()) {
|
||||||
return false;
|
return false;
|
||||||
@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified_iswa_state::apply() {
|
bool llama_kv_cache_unified_iswa_context::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
bool res = true;
|
bool res = true;
|
||||||
|
|
||||||
res = res & state_base->apply();
|
res = res & ctx_base->apply();
|
||||||
res = res & state_swa ->apply();
|
res = res & ctx_swa ->apply();
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
|
||||||
}
|
}
|
||||||
|
@ -31,14 +31,14 @@ public:
|
|||||||
// llama_memory_i
|
// llama_memory_i
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_memory_state_ptr init_batch(
|
llama_memory_context_ptr init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) override;
|
bool embd_all) override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_context_ptr init_full() override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
@ -72,32 +72,32 @@ private:
|
|||||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
|
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_kv_cache_unified_iswa_state(llama_memory_status status);
|
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache context
|
||||||
llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv);
|
llama_kv_cache_unified_iswa * kv);
|
||||||
|
|
||||||
// used to create an update state
|
// used to create an update context
|
||||||
llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool optimize);
|
bool optimize);
|
||||||
|
|
||||||
// used to create a state from a batch
|
// used to create a batch processing context from a batch
|
||||||
llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_context(
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
std::vector<uint32_t> heads_base,
|
std::vector<uint32_t> heads_base,
|
||||||
std::vector<uint32_t> heads_swa,
|
std::vector<uint32_t> heads_swa,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_kv_cache_unified_iswa_state();
|
virtual ~llama_kv_cache_unified_iswa_context();
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_state_i
|
// llama_memory_context_i
|
||||||
//
|
//
|
||||||
|
|
||||||
bool next() override;
|
bool next() override;
|
||||||
@ -107,11 +107,11 @@ public:
|
|||||||
const llama_ubatch & get_ubatch() const override;
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified_iswa_state specific API
|
// llama_kv_cache_unified_iswa_context specific API
|
||||||
//
|
//
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * get_base() const;
|
const llama_kv_cache_unified_context * get_base() const;
|
||||||
const llama_kv_cache_unified_state * get_swa() const;
|
const llama_kv_cache_unified_context * get_swa() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//llama_kv_cache_unified_iswa * kv;
|
//llama_kv_cache_unified_iswa * kv;
|
||||||
@ -121,8 +121,8 @@ private:
|
|||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
const llama_memory_state_ptr state_base;
|
const llama_memory_context_ptr ctx_base;
|
||||||
const llama_memory_state_ptr state_swa;
|
const llama_memory_context_ptr ctx_swa;
|
||||||
|
|
||||||
const llama_memory_status status;
|
const llama_memory_status status;
|
||||||
};
|
};
|
||||||
|
@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
return cells.seq_pos_max(seq_id);
|
return cells.seq_pos_max(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) {
|
bool embd_all) {
|
||||||
@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(
|
return std::make_unique<llama_kv_cache_unified_context>(
|
||||||
this, std::move(heads), std::move(ubatches));
|
this, std::move(heads), std::move(ubatches));
|
||||||
} while (false);
|
} while (false);
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(this);
|
return std::make_unique<llama_kv_cache_unified_context>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||||
bool do_shift = get_has_shift();
|
bool do_shift = get_has_shift();
|
||||||
|
|
||||||
defrag_info dinfo;
|
defrag_info dinfo;
|
||||||
@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
|
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||||
@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified_state
|
// llama_kv_cache_unified_context
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||||
n_kv = kv->get_size();
|
n_kv = kv->get_size();
|
||||||
head = 0;
|
head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool do_shift,
|
bool do_shift,
|
||||||
@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
llama_kv_cache_unified::ubatch_heads heads,
|
llama_kv_cache_unified::ubatch_heads heads,
|
||||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
||||||
|
|
||||||
bool llama_kv_cache_unified_state::next() {
|
bool llama_kv_cache_unified_context::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_next >= ubatches.size()) {
|
||||||
@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified_state::apply() {
|
bool llama_kv_cache_unified_context::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
// no ubatches -> this is a KV cache update
|
// no ubatches -> this is a KV cache update
|
||||||
@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_kv_cache_unified_state::get_status() const {
|
llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
|
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified_state::get_n_kv() const {
|
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
||||||
return n_kv;
|
return n_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||||
return kv->get_k(ctx, il, n_kv);
|
return kv->get_k(ctx, il, n_kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||||
return kv->get_v(ctx, il, n_kv);
|
return kv->get_v(ctx, il, n_kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
||||||
return kv->cpy_k(ctx, k_cur, il, head);
|
return kv->cpy_k(ctx, k_cur, il, head);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
||||||
return kv->cpy_v(ctx, v_cur, il, head);
|
return kv->cpy_v(ctx, v_cur, il, head);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
|
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||||
kv->set_input_k_shift(dst);
|
kv->set_input_k_shift(dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||||
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||||
kv->set_input_pos_bucket(dst, ubatch);
|
kv->set_input_pos_bucket(dst, ubatch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,14 +56,14 @@ public:
|
|||||||
// llama_memory_i
|
// llama_memory_i
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_memory_state_ptr init_batch(
|
llama_memory_context_ptr init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) override;
|
bool embd_all) override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_context_ptr init_full() override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
@ -208,36 +208,36 @@ private:
|
|||||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
// some shorthands
|
// some shorthands
|
||||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||||
|
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_kv_cache_unified_state(llama_memory_status status);
|
llama_kv_cache_unified_context(llama_memory_status status);
|
||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache context
|
||||||
llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv);
|
llama_kv_cache_unified * kv);
|
||||||
|
|
||||||
// used to create an update state
|
// used to create an update context
|
||||||
llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool do_shift,
|
bool do_shift,
|
||||||
defrag_info dinfo);
|
defrag_info dinfo);
|
||||||
|
|
||||||
// used to create a decode state from a batch
|
// used to create a batch procesing context from a batch
|
||||||
llama_kv_cache_unified_state(
|
llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
ubatch_heads heads,
|
ubatch_heads heads,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_kv_cache_unified_state();
|
virtual ~llama_kv_cache_unified_context();
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_state_i
|
// llama_memory_context_i
|
||||||
//
|
//
|
||||||
|
|
||||||
bool next() override;
|
bool next() override;
|
||||||
@ -247,7 +247,7 @@ public:
|
|||||||
const llama_ubatch & get_ubatch() const override;
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_kv_cache_unified_state specific API
|
// llama_kv_cache_unified_context specific API
|
||||||
//
|
//
|
||||||
|
|
||||||
uint32_t get_n_kv() const;
|
uint32_t get_n_kv() const;
|
||||||
@ -272,7 +272,7 @@ private:
|
|||||||
llama_context * lctx;
|
llama_context * lctx;
|
||||||
|
|
||||||
//
|
//
|
||||||
// update state
|
// update context
|
||||||
//
|
//
|
||||||
|
|
||||||
bool do_shift = false;
|
bool do_shift = false;
|
||||||
@ -280,7 +280,7 @@ private:
|
|||||||
defrag_info dinfo;
|
defrag_info dinfo;
|
||||||
|
|
||||||
//
|
//
|
||||||
// batch processing state
|
// batch processing context
|
||||||
//
|
//
|
||||||
|
|
||||||
// the index of the next ubatch to process
|
// the index of the next ubatch to process
|
||||||
|
@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|||||||
n_seq_max
|
n_seq_max
|
||||||
)) {}
|
)) {}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||||
do {
|
do {
|
||||||
balloc.split_reset();
|
balloc.split_reset();
|
||||||
|
|
||||||
@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
|
|||||||
|
|
||||||
// prepare the recurrent batches first
|
// prepare the recurrent batches first
|
||||||
if (!mem_recr->prepare(ubatches)) {
|
if (!mem_recr->prepare(ubatches)) {
|
||||||
// TODO: will the recurrent cache be in an undefined state at this point?
|
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||||
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
||||||
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the attention cache
|
// prepare the attention cache
|
||||||
auto heads_attn = mem_attn->prepare(ubatches);
|
auto heads_attn = mem_attn->prepare(ubatches);
|
||||||
if (heads_attn.empty()) {
|
if (heads_attn.empty()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
||||||
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_memory_hybrid_state>(
|
return std::make_unique<llama_memory_hybrid_context>(
|
||||||
this, std::move(heads_attn), std::move(ubatches));
|
this, std::move(heads_attn), std::move(ubatches));
|
||||||
} while(false);
|
} while(false);
|
||||||
|
|
||||||
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_hybrid::init_full() {
|
llama_memory_context_ptr llama_memory_hybrid::init_full() {
|
||||||
return std::make_unique<llama_memory_hybrid_state>(this);
|
return std::make_unique<llama_memory_hybrid_context>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
||||||
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
|
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_hybrid::get_can_shift() const {
|
bool llama_memory_hybrid::get_can_shift() const {
|
||||||
@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
|
|||||||
return mem_recr.get();
|
return mem_recr.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
|
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
|
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
|
||||||
state_attn(mem->get_mem_attn()->init_full()),
|
ctx_attn(mem->get_mem_attn()->init_full()),
|
||||||
state_recr(mem->get_mem_recr()->init_full()),
|
ctx_recr(mem->get_mem_recr()->init_full()),
|
||||||
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
|
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_hybrid_state::llama_memory_hybrid_state(
|
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool optimize) :
|
bool optimize) :
|
||||||
state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
||||||
state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
||||||
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
|
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_hybrid_state::llama_memory_hybrid_state(
|
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
std::vector<uint32_t> heads_attn,
|
std::vector<uint32_t> heads_attn,
|
||||||
std::vector<llama_ubatch> ubatches) :
|
std::vector<llama_ubatch> ubatches) :
|
||||||
ubatches(std::move(ubatches)),
|
ubatches(std::move(ubatches)),
|
||||||
// note: here we copy the ubatches. not sure if this is ideal
|
// note: here we copy the ubatches. not sure if this is ideal
|
||||||
state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
||||||
state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
|
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||||
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
|
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_hybrid_state::next() {
|
bool llama_memory_hybrid_context::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
state_attn->next();
|
ctx_attn->next();
|
||||||
state_recr->next();
|
ctx_recr->next();
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_next >= ubatches.size()) {
|
||||||
return false;
|
return false;
|
||||||
@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_hybrid_state::apply() {
|
bool llama_memory_hybrid_context::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
bool res = true;
|
bool res = true;
|
||||||
|
|
||||||
res = res & state_attn->apply();
|
res = res & ctx_attn->apply();
|
||||||
res = res & state_recr->apply();
|
res = res & ctx_recr->apply();
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_memory_hybrid_state::get_status() const {
|
llama_memory_status llama_memory_hybrid_context::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
|
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
|
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
||||||
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
|
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
|
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
||||||
return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
|
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||||
}
|
}
|
||||||
|
@ -49,14 +49,14 @@ public:
|
|||||||
// llama_memory_i
|
// llama_memory_i
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_memory_state_ptr init_batch(
|
llama_memory_context_ptr init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) override;
|
bool embd_all) override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_context_ptr init_full() override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
@ -90,27 +90,27 @@ private:
|
|||||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_memory_hybrid_state : public llama_memory_state_i {
|
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
// init failure
|
// init failure
|
||||||
explicit llama_memory_hybrid_state(llama_memory_status status);
|
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||||
|
|
||||||
// init full
|
// init full
|
||||||
explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
|
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
|
||||||
|
|
||||||
// init update
|
// init update
|
||||||
explicit llama_memory_hybrid_state(
|
explicit llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
llama_context * lctx,
|
llama_context * lctx,
|
||||||
bool optimize);
|
bool optimize);
|
||||||
|
|
||||||
// init success
|
// init success
|
||||||
llama_memory_hybrid_state(
|
llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
std::vector<uint32_t> heads_attn,
|
std::vector<uint32_t> heads_attn,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
~llama_memory_hybrid_state() = default;
|
~llama_memory_hybrid_context() = default;
|
||||||
|
|
||||||
bool next() override;
|
bool next() override;
|
||||||
bool apply() override;
|
bool apply() override;
|
||||||
@ -119,11 +119,11 @@ public:
|
|||||||
const llama_ubatch & get_ubatch() const override;
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_hybrid_state
|
// llama_memory_hybrid_context
|
||||||
//
|
//
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * get_state_attn() const;
|
const llama_kv_cache_unified_context * get_attn() const;
|
||||||
const llama_memory_recurrent_state * get_state_recr() const;
|
const llama_memory_recurrent_context * get_recr() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// the index of the next ubatch to process
|
// the index of the next ubatch to process
|
||||||
@ -131,8 +131,8 @@ private:
|
|||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
const llama_memory_state_ptr state_attn;
|
const llama_memory_context_ptr ctx_attn;
|
||||||
const llama_memory_state_ptr state_recr;
|
const llama_memory_context_ptr ctx_recr;
|
||||||
|
|
||||||
const llama_memory_status status;
|
const llama_memory_status status;
|
||||||
};
|
};
|
||||||
|
@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!prepare(ubatches)) {
|
if (!prepare(ubatches)) {
|
||||||
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
|
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_recurrent::init_full() {
|
llama_memory_context_ptr llama_memory_recurrent::init_full() {
|
||||||
return std::make_unique<llama_memory_recurrent_state>(this);
|
return std::make_unique<llama_memory_recurrent_context>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||||
GGML_UNUSED(lctx);
|
GGML_UNUSED(lctx);
|
||||||
GGML_UNUSED(optimize);
|
GGML_UNUSED(optimize);
|
||||||
|
|
||||||
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||||
@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_recurrent_state
|
// llama_memory_recurrent_context
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
|
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
||||||
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_recurrent_state::llama_memory_recurrent_state(
|
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
||||||
llama_memory_recurrent * mem,
|
llama_memory_recurrent * mem,
|
||||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
|
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
|
||||||
|
|
||||||
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
|
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
|
||||||
|
|
||||||
bool llama_memory_recurrent_state::next() {
|
bool llama_memory_recurrent_context::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_next >= ubatches.size()) {
|
||||||
@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_recurrent_state::apply() {
|
bool llama_memory_recurrent_context::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
mem->find_slot(ubatches[i_next]);
|
mem->find_slot(ubatches[i_next]);
|
||||||
@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_status llama_memory_recurrent_state::get_status() const {
|
llama_memory_status llama_memory_recurrent_context::get_status() const {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
|
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_memory_recurrent_state::get_n_rs() const {
|
uint32_t llama_memory_recurrent_context::get_n_rs() const {
|
||||||
return is_full ? mem->size : mem->n;
|
return is_full ? mem->size : mem->n;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_memory_recurrent_state::get_head() const {
|
uint32_t llama_memory_recurrent_context::get_head() const {
|
||||||
return is_full ? 0 : mem->head;
|
return is_full ? 0 : mem->head;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_memory_recurrent_state::get_rs_z() const {
|
int32_t llama_memory_recurrent_context::get_rs_z() const {
|
||||||
return is_full ? 0 : mem->rs_z;
|
return is_full ? 0 : mem->rs_z;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_memory_recurrent_state::get_size() const {
|
uint32_t llama_memory_recurrent_context::get_size() const {
|
||||||
return mem->size;
|
return mem->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
|
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
|
||||||
return mem->r_l[il];
|
return mem->r_l[il];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
|
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||||
return mem->s_l[il];
|
return mem->s_l[il];
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_memory_recurrent_state::s_copy(int i) const {
|
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
||||||
return mem->cells[i + mem->head].src0;
|
return mem->cells[i + mem->head].src0;
|
||||||
}
|
}
|
||||||
|
@ -11,8 +11,8 @@
|
|||||||
// llama_memory_recurrent
|
// llama_memory_recurrent
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
|
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
|
||||||
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
|
||||||
class llama_memory_recurrent : public llama_memory_i {
|
class llama_memory_recurrent : public llama_memory_i {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@ -34,14 +34,14 @@ public:
|
|||||||
// llama_memory_i
|
// llama_memory_i
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_memory_state_ptr init_batch(
|
llama_memory_context_ptr init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) override;
|
bool embd_all) override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_context_ptr init_full() override;
|
||||||
|
|
||||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
void clear(bool data) override;
|
void clear(bool data) override;
|
||||||
|
|
||||||
@ -125,24 +125,24 @@ private:
|
|||||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_memory_recurrent_state : public llama_memory_state_i {
|
class llama_memory_recurrent_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_memory_recurrent_state(llama_memory_status status);
|
llama_memory_recurrent_context(llama_memory_status status);
|
||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache or update context
|
||||||
llama_memory_recurrent_state(
|
llama_memory_recurrent_context(
|
||||||
llama_memory_recurrent * mem);
|
llama_memory_recurrent * mem);
|
||||||
|
|
||||||
// used to create a state from a batch
|
// used to create a batch processing context from a batch
|
||||||
llama_memory_recurrent_state(
|
llama_memory_recurrent_context(
|
||||||
llama_memory_recurrent * mem,
|
llama_memory_recurrent * mem,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_memory_recurrent_state();
|
virtual ~llama_memory_recurrent_context();
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_state_i
|
// llama_memory_context_i
|
||||||
//
|
//
|
||||||
|
|
||||||
bool next() override;
|
bool next() override;
|
||||||
@ -152,7 +152,7 @@ public:
|
|||||||
const llama_ubatch & get_ubatch() const override;
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
//
|
//
|
||||||
// llama_memory_recurrent_state specific API
|
// llama_memory_recurrent_context specific API
|
||||||
//
|
//
|
||||||
|
|
||||||
uint32_t get_n_rs() const;
|
uint32_t get_n_rs() const;
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
struct llama_ubatch;
|
struct llama_ubatch;
|
||||||
|
|
||||||
@ -28,23 +27,21 @@ enum llama_memory_status {
|
|||||||
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
||||||
};
|
};
|
||||||
|
|
||||||
// helper function for combining the status of two memory states
|
// helper function for combining the status of two memory contexts
|
||||||
// useful for implementing hybrid memory types (e.g. iSWA)
|
// useful for implementing hybrid memory types (e.g. iSWA)
|
||||||
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
||||||
|
|
||||||
// the interface for managing the memory state during batch processing
|
// the interface for managing the memory context during batch processing
|
||||||
// this interface is implemented per memory type. see:
|
// this interface is implemented per memory type. see:
|
||||||
// - llama_kv_cache_unified_state
|
// - llama_kv_cache_unified_context
|
||||||
// - llama_kv_cache_unified_iswa_state
|
// - llama_kv_cache_unified_iswa_context
|
||||||
// ...
|
// ...
|
||||||
//
|
//
|
||||||
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
|
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
|
||||||
//
|
struct llama_memory_context_i {
|
||||||
// TODO: rename to llama_memory_context_i ?
|
virtual ~llama_memory_context_i() = default;
|
||||||
struct llama_memory_state_i {
|
|
||||||
virtual ~llama_memory_state_i() = default;
|
|
||||||
|
|
||||||
// consume the current ubatch from the state and proceed to the next one
|
// consume the current ubatch from the context and proceed to the next one
|
||||||
// return false if we are done
|
// return false if we are done
|
||||||
virtual bool next() = 0;
|
virtual bool next() = 0;
|
||||||
|
|
||||||
@ -55,11 +52,11 @@ struct llama_memory_state_i {
|
|||||||
// get the current ubatch
|
// get the current ubatch
|
||||||
virtual const llama_ubatch & get_ubatch() const = 0;
|
virtual const llama_ubatch & get_ubatch() const = 0;
|
||||||
|
|
||||||
// get the status of the memory state - used for error handling and checking if any updates would be applied
|
// get the status of the memory context - used for error handling and checking if any updates would be applied
|
||||||
virtual llama_memory_status get_status() const = 0;
|
virtual llama_memory_status get_status() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
|
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
|
||||||
|
|
||||||
// general concept of LLM memory
|
// general concept of LLM memory
|
||||||
// the KV cache is a type of LLM memory, but there can be other types
|
// the KV cache is a type of LLM memory, but there can be other types
|
||||||
@ -67,19 +64,19 @@ struct llama_memory_i {
|
|||||||
virtual ~llama_memory_i() = default;
|
virtual ~llama_memory_i() = default;
|
||||||
|
|
||||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||||
// return a state object containing the ubatches and KV cache state required to process them
|
// return a context object containing the ubatches and memory state required to process them
|
||||||
// check the llama_memory_state_i::get_status() for the result
|
// check the llama_memory_context_i::get_status() for the result
|
||||||
virtual llama_memory_state_ptr init_batch(
|
virtual llama_memory_context_ptr init_batch(
|
||||||
llama_batch_allocr & balloc,
|
llama_batch_allocr & balloc,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_all) = 0;
|
bool embd_all) = 0;
|
||||||
|
|
||||||
// simulate full cache, used for allocating worst-case compute buffers
|
// simulate full cache, used for allocating worst-case compute buffers
|
||||||
virtual llama_memory_state_ptr init_full() = 0;
|
virtual llama_memory_context_ptr init_full() = 0;
|
||||||
|
|
||||||
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||||
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||||
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||||
|
|
||||||
// getters
|
// getters
|
||||||
virtual bool get_can_shift() const = 0;
|
virtual bool get_can_shift() const = 0;
|
||||||
|
@ -9171,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
|
|||||||
ggml_tensor * cur,
|
ggml_tensor * cur,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
const auto kv_head = kv_state->get_head();
|
const auto kv_head = mctx_cur->get_head();
|
||||||
|
|
||||||
const int64_t d_conv = hparams.ssm_d_conv;
|
const int64_t d_conv = hparams.ssm_d_conv;
|
||||||
const int64_t d_inner = hparams.ssm_d_inner;
|
const int64_t d_inner = hparams.ssm_d_inner;
|
||||||
@ -9191,8 +9191,8 @@ struct llm_build_mamba : public llm_graph_context {
|
|||||||
GGML_ASSERT(ubatch.equal_seqs);
|
GGML_ASSERT(ubatch.equal_seqs);
|
||||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||||
|
|
||||||
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||||
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||||
|
|
||||||
// (ab)using the KV cache to store the states
|
// (ab)using the KV cache to store the states
|
||||||
ggml_tensor * conv = build_rs(
|
ggml_tensor * conv = build_rs(
|
||||||
@ -11916,7 +11916,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||||||
ggml_tensor * x_prev,
|
ggml_tensor * x_prev,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
const auto n_tokens = ubatch.n_tokens;
|
const auto n_tokens = ubatch.n_tokens;
|
||||||
const auto n_seqs = ubatch.n_seqs;
|
const auto n_seqs = ubatch.n_seqs;
|
||||||
@ -11926,7 +11926,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||||||
const auto n_head = n_embd / head_size;
|
const auto n_head = n_embd / head_size;
|
||||||
const auto n_head_kv = hparams.n_head_kv(il);
|
const auto n_head_kv = hparams.n_head_kv(il);
|
||||||
|
|
||||||
const auto kv_head = kv_state->get_head();
|
const auto kv_head = mctx_cur->get_head();
|
||||||
|
|
||||||
const auto & layer = model.layers[il];
|
const auto & layer = model.layers[il];
|
||||||
|
|
||||||
@ -12038,7 +12038,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * wkv_state = build_rs(
|
ggml_tensor * wkv_state = build_rs(
|
||||||
inp, gf, kv_state->get_s_l(il),
|
inp, gf, mctx_cur->get_s_l(il),
|
||||||
hparams.n_embd_s(), n_seqs);
|
hparams.n_embd_s(), n_seqs);
|
||||||
|
|
||||||
ggml_tensor * wkv_output;
|
ggml_tensor * wkv_output;
|
||||||
@ -12057,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||||||
wkv_state,
|
wkv_state,
|
||||||
ggml_view_1d(
|
ggml_view_1d(
|
||||||
ctx0,
|
ctx0,
|
||||||
kv_state->get_s_l(il),
|
mctx_cur->get_s_l(il),
|
||||||
hparams.n_embd_s() * n_seqs,
|
hparams.n_embd_s() * n_seqs,
|
||||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@ -12313,7 +12313,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||||||
ggml_tensor *& first_layer_value,
|
ggml_tensor *& first_layer_value,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
const auto n_tokens = ubatch.n_tokens;
|
const auto n_tokens = ubatch.n_tokens;
|
||||||
const auto n_seqs = ubatch.n_seqs;
|
const auto n_seqs = ubatch.n_seqs;
|
||||||
@ -12322,7 +12322,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||||||
const auto head_count = n_embd / head_size;
|
const auto head_count = n_embd / head_size;
|
||||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
|
||||||
const auto kv_head = kv_state->get_head();
|
const auto kv_head = mctx_cur->get_head();
|
||||||
|
|
||||||
const auto & layer = model.layers[il];
|
const auto & layer = model.layers[il];
|
||||||
|
|
||||||
@ -12393,7 +12393,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||||
|
|
||||||
ggml_tensor * wkv_state = build_rs(
|
ggml_tensor * wkv_state = build_rs(
|
||||||
inp, gf, kv_state->get_s_l(il),
|
inp, gf, mctx_cur->get_s_l(il),
|
||||||
hparams.n_embd_s(), n_seqs);
|
hparams.n_embd_s(), n_seqs);
|
||||||
|
|
||||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||||
@ -12407,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||||||
wkv_state,
|
wkv_state,
|
||||||
ggml_view_1d(
|
ggml_view_1d(
|
||||||
ctx0,
|
ctx0,
|
||||||
kv_state->get_s_l(il),
|
mctx_cur->get_s_l(il),
|
||||||
hparams.n_embd_s() * n_seqs,
|
hparams.n_embd_s() * n_seqs,
|
||||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
Reference in New Issue
Block a user