llama: Add support for RWKV v7 architecture (#12412)

* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code-format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* fix MUSA build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia
2025-03-18 07:27:50 +08:00
committed by GitHub
parent 60c902926c
commit 7dfad387e3
35 changed files with 2948 additions and 438 deletions

View File

@ -32,6 +32,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_109M: return "109M";
case LLM_TYPE_137M: return "137M";
case LLM_TYPE_160M: return "160M";
case LLM_TYPE_190M: return "190M";
case LLM_TYPE_220M: return "220M";
case LLM_TYPE_250M: return "250M";
case LLM_TYPE_270M: return "270M";
@ -48,6 +49,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_1_6B: return "1.6B";
case LLM_TYPE_2B: return "2B";
case LLM_TYPE_2_8B: return "2.8B";
case LLM_TYPE_2_9B: return "2.9B";
case LLM_TYPE_3B: return "3B";
case LLM_TYPE_4B: return "4B";
case LLM_TYPE_6B: return "6B";
@ -1250,6 +1252,36 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay);
ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr);
ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false);
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
switch (hparams.n_layer) {
case 12: type = LLM_TYPE_190M; break;
case 24:
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_450M; break;
case 2048: type = LLM_TYPE_1_5B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 28:
switch (hparams.n_embd) {
case 1536: type = LLM_TYPE_1_5B; break;
case 3584: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
@ -3366,6 +3398,146 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_RWKV7:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// Block 0, LN0
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
const int n_lora_decay = hparams.n_lora_decay;
const int n_lora_iclr = hparams.n_lora_iclr;
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
const int n_lora_gate = hparams.n_lora_gate;
const int attn_hidden_size = n_embd;
const int ffn_size = hparams.n_ff_arr[0];
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0);
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
if (i == 0) {
// actually not used
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
} else {
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
}
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0);
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0);
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
}
} break;
case LLM_ARCH_ARWKV7:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
const int n_lora_decay = hparams.n_lora_decay;
const int n_lora_iclr = hparams.n_lora_iclr;
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
const int n_lora_gate = hparams.n_lora_gate;
const int attn_hidden_size = n_embd;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
if (i == 0) {
// actually not used
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
} else {
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
}
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
try {
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
} catch(std::runtime_error & e) {
// ARWKV models may not have gate tensors
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
}
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_CHAMELEON:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -10212,6 +10384,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_embd = hparams.n_embd;
const auto head_size = hparams.wkv_head_size;
const auto n_head = n_embd / head_size;
@ -10224,6 +10397,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
bool is_qrwkv = layer.time_mix_first == nullptr;
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
xxx = ggml_reshape_4d(
@ -10366,7 +10543,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
cur = ggml_mul(ctx0, cur, g);
cur = build_lora_mm(layer.time_mix_output, cur);
return cur;
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}
};
@ -10389,6 +10566,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, state_mask, ubatch, il
@ -10422,9 +10600,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
1
);
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
cur = ggml_add(ctx0, cur, ffn_inp);
token_shift = ggml_concat(ctx0,
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
@ -10432,6 +10607,18 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
);
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
}
cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
cur = ggml_add(ctx0, cur, ffn_inp);
if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
cur = ggml_scale(ctx0, cur, 0.5F);
}
@ -10444,12 +10631,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
}
cur = inpL;
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
cb(cur, "result_norm", -1);
@ -10481,10 +10662,9 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_seqs = ubatch.n_seqs;
inpL = build_inp_embd(model.tok_embd);
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, state_mask, ubatch, il
@ -10508,6 +10688,13 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
}
// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
@ -10532,10 +10719,358 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
}
cur = inpL;
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_rwkv7_base : public llm_graph_context {
const llama_model & model;
llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
}
ggml_tensor * build_rwkv7_channel_mix(
const llama_layer * layer,
ggml_tensor * cur,
ggml_tensor * x_prev,
llm_arch arch) const {
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
switch (arch) {
case LLM_ARCH_RWKV7:
{
ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
ggml_tensor * k = ggml_sqr(
ctx0,
ggml_relu(
ctx0,
build_lora_mm(layer->channel_mix_key, xk)
)
);
cur = build_lora_mm(layer->channel_mix_value, k);
} break;
default:
GGML_ABORT("fatal error");
}
return cur;
}
ggml_tensor * build_rwkv7_time_mix(
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
const auto n_embd = hparams.n_embd;
const auto head_size = hparams.wkv_head_size;
const auto head_count = n_embd / head_size;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = kv_self->head;
const auto & layer = model.layers[il];
bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
sx = ggml_repeat(ctx0, sx, dummy);
ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);
ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
ggml_tensor * xg = has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
ggml_tensor * w = ggml_add(
ctx0,
ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
layer.time_mix_w0
);
w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));
ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
if (first_layer_value == nullptr) {
first_layer_value = v;
} else {
// Add the first layer value as a residual connection.
v = ggml_add(ctx0, v,
ggml_mul(ctx0,
ggml_sub(ctx0, first_layer_value, v),
ggml_sigmoid(ctx0, ggml_add(ctx0,
ggml_mul_mat(ctx0, layer.time_mix_v2, ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
layer.time_mix_v0
)
)
)
);
}
ggml_tensor * g = nullptr;
if (layer.time_mix_g1 && layer.time_mix_g2) {
g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
}
ggml_tensor * a = ggml_sigmoid(ctx0,
ggml_add(
ctx0,
ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
layer.time_mix_a0
)
);
ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
kk = ggml_l2_norm(ctx0, kk, 1e-12);
ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));
r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_copy_mask_state(
gf, kv_self->v_l[il], state_copy, state_mask,
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
ggml_build_forward_expand(
gf,
ggml_cpy(
ctx0,
wkv_state,
ggml_view_1d(
ctx0,
kv_self->v_l[il],
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
)
)
);
if (layer.time_mix_ln && layer.time_mix_ln_b) {
// group norm with head_count groups
cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
cur = ggml_norm(ctx0, cur, 64e-5f);
// Convert back to regular vectors.
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
} else {
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
}
ggml_tensor * rk = ggml_sum_rows(ctx0,
ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));
if (has_gating) {
cur = ggml_mul(ctx0, cur, g);
}
cur = build_lora_mm(layer.time_mix_output, cur);
return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}
};
struct llm_build_rwkv7 : public llm_build_rwkv7_base {
llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
GGML_ASSERT(hparams.token_shift_count == 2);
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * v_first = nullptr;
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_seqs = ubatch.n_seqs;
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, state_mask, ubatch, il
);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
cb(att_norm, "attn_norm", il);
ggml_tensor * x_prev = ggml_concat(
ctx0,
att_shift,
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
cb(ffn_norm, "ffn_norm", il);
x_prev = ggml_concat(
ctx0,
ffn_shift,
ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
1
);
token_shift = ggml_concat(ctx0,
ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
1
);
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
ffn_norm = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens), inp_out_ids);
x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
}
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * v_first = nullptr;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto n_seqs = ubatch.n_seqs;
for (int il = 0; il < n_layer; ++il) {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, state_mask, ubatch, il
);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
ggml_tensor * x_prev = ggml_concat(
ctx0,
token_shift,
ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens), inp_out_ids);
}
// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
@ -10883,9 +11418,11 @@ llama_memory_i * llama_model::create_memory() const {
llama_memory_i * res;
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_unified(hparams, {
/*.get_rope_factors =*/ nullptr
@ -11132,6 +11669,14 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
} break;
case LLM_ARCH_RWKV7:
{
llm = std::make_unique<llm_build_rwkv7>(*this, params, gf);
} break;
case LLM_ARCH_ARWKV7:
{
llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
} break;
case LLM_ARCH_CHAMELEON:
{
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@ -11245,6 +11790,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_JAIS:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_WAVTOKENIZER_DEC:
return LLAMA_ROPE_TYPE_NONE;
@ -11399,6 +11946,8 @@ bool llama_model_is_recurrent(const llama_model * model) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV6: return true;
case LLM_ARCH_RWKV6QWEN2: return true;
case LLM_ARCH_RWKV7: return true;
case LLM_ARCH_ARWKV7: return true;
default: return false;
}
}