llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git)
recurrent : rework graph inputs + add TODOs
ggml-ci
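In short: the per-memory-type graph inputs are consolidated. The dedicated llm_graph_input_rs_hybrid_recurrent and llm_graph_input_attn_kv_hybrid_recurrent inputs are removed in favor of a single llm_graph_input_mem_hybrid, build_rs_inp_recurrent() is renamed to build_rs_inp(), and the low-level build_recurrent_state() helper becomes a build_rs() overload that takes the cache geometry (n_kv, kv_head, kv_size, rs_zero) explicitly instead of a llama_kv_cache_recurrent_state pointer. A condensed, hedged view of the API change for model-building code (a summary, not literal lines from the diff):

// graph inputs (old -> new)
//   build_rs_inp_recurrent()             -> build_rs_inp()
//   build_rs_inp_hybrid_recurrent()       \
//   build_attn_inp_kv_hybrid_recurrent()  /-> build_inp_mem_hybrid()
//
// low-level recurrent-state helper (old -> new)
//   build_recurrent_state(kv_state, gf, s, state_copy, state_size, n_seqs)
//   build_rs(gf, s, state_copy, state_size, n_seqs, n_kv, kv_head, kv_size, rs_zero)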
src/llama-graph.cpp

@@ -255,11 +255,6 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_rs(kv_state->get_state_recurrent()) {
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -365,13 +360,6 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -416,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            data[i] = kv_state->get_state_recurrent()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -1043,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
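The new build_inp_mem_hybrid() gives a hybrid model one graph input that carries both the attention KQ mask and the recurrent s_copy tensor. A minimal sketch of how a hybrid builder might consume it is shown below; Qcur/Kcur/Vcur, conv_states_all, n_embd_head, n_seqs are hypothetical placeholders, and hparams.n_embd_k_s() is assumed here as the recurrent state-size helper - none of this is literal code from the commit.

// illustrative sketch only (inside a llm_graph_context-derived build function)
auto * inp = build_inp_mem_hybrid();

// attention layer: same call shape as the unified-KV overload of build_attn()
cur = build_attn(inp, gf, model.layers[il].wo, model.layers[il].bo,
        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);

// recurrent layer: the very same input also drives the recurrent-state copies
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs);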
@@ -1287,105 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
-        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
-
-    {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-        const auto n_kv = inp->kv_state->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
-
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
-
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1494,20 +1428,100 @@ ggml_tensor * llm_graph_context::build_attn(
 
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        const llama_kv_cache_recurrent_state * kv_state,
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,
         int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
         bool avoid_copies) const {
 
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
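The net effect of this hunk is that the low-level state-copy helper no longer reaches into a llama_kv_cache_recurrent_state: callers pass the cache geometry explicitly. A before/after sketch of the call site, using the same values the wrappers in this commit forward:

// before (removed): the helper pulled everything from the state object
// build_recurrent_state(kv_state, gf, s, state_copy, state_size, n_seqs, avoid_copies);

// after: the caller forwards the geometry it already knows
build_rs(gf, s, state_copy, state_size, n_seqs,
        kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
        avoid_copies);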
@@ -1538,17 +1552,15 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
     const auto n_kv = kv_state->get_n_kv();
 
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(inp->s_copy);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1560,35 +1572,21 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
 
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
-}
-
-llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
-    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
-        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
-
-    const auto n_kv = inp->kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs_hybrid_recurrent * inp,
+        llm_graph_input_mem_hybrid * inp,
         ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
 
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
 
-    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
src/llama-graph.h

@@ -201,12 +201,6 @@ public:
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs {
-public:
-    llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state);
-    virtual ~llm_graph_input_rs_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -264,15 +258,6 @@ public:
     const llama_kv_cache_unified_state * kv_state;
 };
 
-class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
-public:
-    llm_graph_input_attn_kv_hybrid_recurrent(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state);
-    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_unified_iswa(
@@ -316,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_state(kv_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_hybrid_recurrent_state * kv_state;
+};
+
 //
 // llm_graph_result
 //
@@ -530,6 +542,8 @@ struct llm_graph_context {
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -604,10 +618,8 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
-
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_hybrid_recurrent * inp,
+            llm_graph_input_mem_hybrid * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
@@ -622,16 +634,25 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_recurrent_state(
-            const llama_kv_cache_recurrent_state * kv_state,
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_kv_cache_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_kv_cache_recurrent`
+    ggml_tensor * build_rs(
            ggml_cgraph * gf,
            ggml_tensor * s,
            ggml_tensor * state_copy,
            int32_t state_size,
            int32_t n_seqs,
+           uint32_t n_kv,
+           uint32_t kv_head,
+           uint32_t kv_size,
+           int32_t rs_zero,
            bool avoid_copies = false) const;
 
-    llm_graph_input_rs * build_rs_inp_recurrent() const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
            llm_graph_input_rs * inp,
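One way the split described by the TODO above could look, sketched with hypothetical method names (nothing below is part of this commit): the recurrent cache would expose a node-building helper that only takes a ggml_context, and the single ggml_build_forward_expand() call would stay in llm_graph_context.

// hypothetical sketch of the TODO - names and signatures are illustrative only
//   // in llama_kv_cache_recurrent (analogous to llama_kv_cache_unified::cpy_k / cpy_v):
//   ggml_tensor * build_state_gather(ggml_context * ctx, ggml_tensor * s, ggml_tensor * state_copy, int32_t state_size) const;
//
//   // in llm_graph_context, which keeps ownership of the graph:
//   ggml_tensor * states = kv_recurrent->build_state_gather(ctx0, s, inp->s_copy, state_size);
//   ggml_build_forward_expand(gf, states);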
@@ -641,10 +662,8 @@ struct llm_graph_context {
            int32_t n_seqs,
            bool avoid_copies = false) const;
 
-    llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const;
-
     ggml_tensor * build_rs(
-           llm_graph_input_rs_hybrid_recurrent * inp,
+           llm_graph_input_mem_hybrid * inp,
            ggml_cgraph * gf,
            ggml_tensor * s,
            int32_t state_size,
src/llama-kv-cache-hybrid-recurrent.cpp

@@ -99,9 +99,7 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 }
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        static_cast<llama_kv_cache_unified_state *>( kv_attn ->init_update(lctx, optimize).release()),
-        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
@@ -171,35 +169,38 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
     return kv_recurrent.get();
 }
 
-llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status),
-      state_attn(new llama_kv_cache_unified_state(status)),
-      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+    : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_full();
+    state_recurrent = kv->get_kv_recurrent()->init_full();
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-    llama_kv_cache_unified_state * state_unified,
-    llama_kv_cache_recurrent_state * state_recurrent)
-    : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-      state_attn(state_unified),
-      state_recurrent(state_recurrent) {}
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_update(lctx, optimize);
+    state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
      sbatch(std::move(sbatch)),
-      ubatches(std::move(ubatches)),
+      ubatches(std::move(ubatches)) {
     // note: here we copy the ubatches. not sure if this is ideal
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
-      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
+    state_attn     .reset(new llama_kv_cache_unified_state  (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
+    state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
+}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
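Both init_full() and init_update() now build the two child states themselves and fold their statuses into one via llama_memory_status_combine. A brief, hedged caller-side sketch (kv_hybrid is an illustrative variable name, not from the diff):

// illustrative only: callers see a single combined status
llama_memory_state_ptr mstate = kv_hybrid->init_full();
if (mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    // either the attention sub-state or the recurrent sub-state failed to initialize
    return nullptr;
}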
src/llama-kv-cache-hybrid-recurrent.h

@@ -4,7 +4,6 @@
 #include "llama-graph.h"
 #include "llama-kv-cache-recurrent.h"
 #include "llama-kv-cache-unified.h"
-#include "llama-kv-cells.h"
 #include "llama-memory.h"
 
 #include <memory>
@@ -12,6 +11,7 @@
 
 //
 // llama_kv_cache_hybrid_recurrent
+// TODO: rename to llama_memory_hybrid
 //
 
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
@@ -93,9 +93,6 @@ private:
 
 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
-    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
-    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
-
     // init failure
     explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);
 
@@ -104,8 +101,9 @@ public:
 
     // init update
     explicit llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_unified_state * state_unified,
-        llama_kv_cache_recurrent_state * state_recurrent);
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_context * lctx,
+        bool optimize);
 
     // init success
     llama_kv_cache_hybrid_recurrent_state(
@@ -132,7 +130,7 @@ public:
     const llama_kv_cache_recurrent_state * get_state_recurrent() const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_sbatch sbatch;
 
@@ -141,6 +139,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recurrent;
+    llama_memory_state_ptr state_attn;
+    llama_memory_state_ptr state_recurrent;
 };
src/llama-kv-cache-recurrent.cpp

@@ -384,11 +384,11 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
         return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_recurrent_state>(this);
 }
 
 llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
@@ -1043,15 +1043,13 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
+        llama_kv_cache_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), is_full(true) {
 }
 
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
         llama_kv_cache_recurrent * kv,
         llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
 
 llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
src/llama-kv-cache-recurrent.h

@@ -11,8 +11,10 @@
 // llama_kv_cache_recurrent
 //
 
-// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+// TODO: extract the cache state used for graph computation into llama_kv_cache_recurrent_state_i
 //       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+// TODO: avoid the notion of "KV cache" / "KV cells", etc.
+// TODO: rename to llama_recurrent_state / llama_recurrent_cache
 class llama_kv_cache_recurrent : public llama_memory_i {
 public:
 
@@ -131,12 +133,10 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_recurrent_state(
-            llama_memory_status status,
             llama_kv_cache_recurrent * kv);
 
     // used to create a state from a batch
     llama_kv_cache_recurrent_state(
-            llama_memory_status status,
             llama_kv_cache_recurrent * kv,
             llama_sbatch sbatch,
             std::vector<llama_ubatch> ubatches);
src/llama-model.cpp

@@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12187,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12441,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12532,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
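For recurrent-only models the calling pattern is unchanged apart from the shorter name. A hedged sketch of a layer loading a recurrent state through the renamed helpers; the accessor get_k_l(), the state-size helper hparams.n_embd_k_s(), and the surrounding variables (kv_state, n_layer, n_seqs) are assumptions based on the general shape of the Mamba/RWKV builders, not lines from this diff:

// illustrative sketch only - names marked "assumed" are not taken from this diff
auto * rs_inp = build_rs_inp();

for (int il = 0; il < n_layer; ++il) {
    // the per-layer recurrent state tensor (assumed accessor name)
    ggml_tensor * conv_states_all = kv_state->get_k_l(il);

    // gather the states of the sequences in this ubatch, clearing freshly started ones
    ggml_tensor * conv = build_rs(rs_inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs);

    // ... the rest of the layer consumes `conv` as before ...
}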