diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 611b3f6fc..ac6fc4e1e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13736,6 +13736,11 @@ struct llm_build_arcee : public llm_graph_context { struct llm_build_smollm3 : public llm_graph_context { llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + // collect layers for which RoPE is disabled (metadata key: "smollm3.no_rope_layers") std::vector no_rope_layers; if (arch == LLM_ARCH_SMOLLM3) { const int kid = gguf_find_key(model.meta, "smollm3.no_rope_layers"); @@ -13747,59 +13752,134 @@ struct llm_build_smollm3 : public llm_graph_context { } } - const int64_t n_tokens = params.n_tokens; - const int64_t n_layer = hparams.n_layer; + // token embeddings + ggml_tensor * inpL = build_inp_embd(model.tok_embd); - gf->n_threads = params.n_threads; + // positional ids + ggml_tensor * inp_pos = build_inp_pos(); - // build the graph - inp_tokens->set_input(ubatch); - inp_pos->set_input(ubatch); - inp_attn_temp->set_input(ubatch); + // attention helper (unified KV cache) + auto * inp_attn = build_attn_inp_kv_unified(); - struct ggml_tensor * cur = build_inp_embd(); - struct ggml_tensor * lay_out = nullptr; + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + ggml_tensor * cur = nullptr; for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inp_norm = build_norm(cur, hparams.f_norm_eps, il, tn(LLM_TENSOR_ATTN_NORM, il)); - struct ggml_tensor * qkv = build_attn(inp_norm, il); - struct ggml_tensor * q = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, 0, 0, 0); - struct ggml_tensor * k = ggml_view_4d(ctx, qkv, hparams.n_embd_head_k, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_k, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0); - struct ggml_tensor * v = ggml_view_4d(ctx, qkv, hparams.n_embd_head_v, hparams.n_head_kv(il), n_tokens, 1, ggml_element_size(qkv)*hparams.n_embd_head_v, ggml_element_size(qkv)*hparams.n_embd_k_gqa(il) + ggml_element_size(qkv)*hparams.n_embd_k_gqa(il), 0, 0); + ggml_tensor * inpSA = inpL; - ggml_set_name(q, "q"); - ggml_set_name(k, "k"); - ggml_set_name(v, "v"); + // attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - struct ggml_tensor * qcur = q; - struct ggml_tensor * kcur = k; - - bool apply_rope = true; - if (arch == LLM_ARCH_SMOLLM3) { - if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) != no_rope_layers.end()) { - apply_rope = false; + // ---- self-attention ---- + { + // fused QKV projection + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + if (model.layers[il].bqkv) { + qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv); + cb(qkv, "bqkv", il); } + + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0)); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, 
qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (std::find(no_rope_layers.begin(), no_rope_layers.end(), il) == no_rope_layers.end()) { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); } - if (apply_rope && get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il))) { - qcur = ggml_rope_ext(ctx, q, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul); - kcur = ggml_rope_ext(ctx, k, inp_pos->pos, get_tensor_meta(tn(LLM_TENSOR_ROPE_FREQS, il)), hparams.rope_type, 0, hparams.n_rot, hparams.n_gqa(il), hparams.rope_freq_base_train, hparams.rope_freq_scale_train, hparams.n_ctx_orig_yarn, hparams.rope_yarn_log_mul); + // skip padded tokens for final layer + if (il == n_layer - 1) { + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - struct ggml_tensor * attn_out = build_attn_out(inp_norm, qcur, kcur, v, il); - + // ---- feed-forward ---- if (hparams.use_par_res) { // parallel residual - 
lay_out = ggml_add(ctx, attn_out, build_ff_par(inp_norm, il)); + ggml_tensor * ffn_cur = build_norm(inpSA, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(ffn_cur, "ffn_norm", il); + + ffn_cur = build_ffn(ffn_cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_cur, "ffn_out", il); + + cur = ggml_add(ctx0, ggml_add(ctx0, cur, ffn_cur), inpSA); + cb(cur, "par_res", il); } else { // sequential residual - lay_out = ggml_add(ctx, cur, attn_out); - lay_out = build_ff_seq(lay_out, il); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); } - cur = lay_out; + + // post-processing + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; } - build_output(cur, lay_out); + // final RMSNorm + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); } };