mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 04:15:21 +00:00
memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i ggml-ci * cont : fix comments * cont : use "mctx" for referencing a memory context ggml-ci
This commit is contained in:
@ -9171,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
|
||||
ggml_tensor * cur,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
const int64_t d_conv = hparams.ssm_d_conv;
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
@ -9191,8 +9191,8 @@ struct llm_build_mamba : public llm_graph_context {
|
||||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
// (ab)using the KV cache to store the states
|
||||
ggml_tensor * conv = build_rs(
|
||||
@ -11916,7 +11916,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
ggml_tensor * x_prev,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
@ -11926,7 +11926,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
const auto n_head = n_embd / head_size;
|
||||
const auto n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
@ -12038,7 +12038,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
}
|
||||
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_s_l(il),
|
||||
inp, gf, mctx_cur->get_s_l(il),
|
||||
hparams.n_embd_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output;
|
||||
@ -12057,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_state->get_s_l(il),
|
||||
mctx_cur->get_s_l(il),
|
||||
hparams.n_embd_s() * n_seqs,
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
@ -12313,7 +12313,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
@ -12322,7 +12322,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
const auto head_count = n_embd / head_size;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
@ -12393,7 +12393,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_s_l(il),
|
||||
inp, gf, mctx_cur->get_s_l(il),
|
||||
hparams.n_embd_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
@ -12407,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_state->get_s_l(il),
|
||||
mctx_cur->get_s_l(il),
|
||||
hparams.n_embd_s() * n_seqs,
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
Reference in New Issue
Block a user