From 6c6e397affc4fac717e718364fb4b635cec6433a Mon Sep 17 00:00:00 2001
From: Dongliang Wei <121270393+wdl339@users.noreply.github.com>
Date: Mon, 28 Jul 2025 19:47:00 +0800
Subject: [PATCH] model : add support for SmallThinker series (#14898)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* support smallthinker

* support 20b softmax, 4b no sliding window

* new build_moe_ffn_from_probs, and can run 4b

* fix 4b rope bug

* fix python type check

* remove is_moe check

* remove set_dense_start_swa_pattern function and modify set_swa_pattern function

* trim trailing whitespace

* remove get_vocab_base of SmallThinkerModel in convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret

* better whitespace

Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret

* use GGML_ASSERT for expert count validation

Co-authored-by: Sigbjørn Skjæret

* Improve null pointer check for probs

Co-authored-by: Sigbjørn Skjæret

* use template parameter for SWA attention logic

* better whitespace

Co-authored-by: Georgi Gerganov

* move the creation of inp_out_ids before the layer loop

* remove redundant check for probs

---------

Co-authored-by: Sigbjørn Skjæret
Co-authored-by: Georgi Gerganov
---
 convert_hf_to_gguf.py          |  82 +++++++++++++++
 gguf-py/gguf/constants.py      |  20 ++++
 gguf-py/gguf/tensor_mapping.py |   7 ++
 src/llama-arch.cpp             |  22 ++++
 src/llama-arch.h               |   1 +
 src/llama-graph.cpp            |  94 +++++++++++++++++
 src/llama-graph.h              |  12 +++
 src/llama-hparams.cpp          |  12 ++-
 src/llama-hparams.h            |  13 ++-
 src/llama-model.cpp            | 186 +++++++++++++++++++++++++++++++++
 10 files changed, 443 insertions(+), 6 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2c51b6c42..9461206cf 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -7589,6 +7589,88 @@ class LFM2Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SmallThinkerForCausalLM")
+class SmallThinkerModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if self.hparams.get('moe_primary_router_apply_softmax'):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
self.hparams.get("sliding_window_layout") + if sliding_window_layout: + for i in sliding_window_layout: + if i != 0: + sliding_window = self.hparams.get("sliding_window_size") + if sliding_window: + self.gguf_writer.add_sliding_window(sliding_window) + break + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("experts") != -1: + n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts")) + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down", "gate", "up"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 680210db7..7ac1a1971 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = auto() LFM2 = auto() DREAM = auto() + SMALLTHINKER = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -695,6 +696,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", + MODEL_ARCH.SMALLTHINKER: "smallthinker", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2483,6 +2485,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, ], + MODEL_ARCH.SMALLTHINKER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7fbda422f..bfd4fd37a 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -317,6 +317,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.router", # llama4 jamba "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan + "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -362,6 +363,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc_1", # exaone "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid 
"transformer_encoder.{bid}.ffn.w12", # neobert + "model.layers.{bid}.block_sparse_moe.up", # smallthinker ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -372,6 +374,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -401,6 +404,7 @@ class TensorNameMap: "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid + "model.layers.{bid}.block_sparse_moe.gate", # smallthinker ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -410,6 +414,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -448,6 +453,7 @@ class TensorNameMap: "model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w3", # neobert + "model.layers.{bid}.block_sparse_moe.down", # smallthinker ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -459,6 +465,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 062a99776..dbf977443 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -88,6 +88,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, + { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1933,6 +1934,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, } }, + { + LLM_ARCH_SMALLTHINKER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } + }, + }, { LLM_ARCH_DREAM, { diff --git a/src/llama-arch.h b/src/llama-arch.h index d09b7d781..8267a8d3a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -92,6 +92,7 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, LLM_ARCH_DREAM, + LLM_ARCH_SMALLTHINKER, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b63a41053..1b9cc4aec 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn( return moe_out; } +ggml_tensor * 
+        ggml_tensor * cur,
+        ggml_tensor * probs,
+        ggml_tensor * up_exps,
+        ggml_tensor * gate_exps,
+        ggml_tensor * down_exps,
+        ggml_tensor * exp_probs_b,
+        int64_t       n_expert,
+        int64_t       n_expert_used,
+        llama_expert_gating_func_type gating_op,
+        int           il) const {
+    const int64_t n_embd   = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx0,
+            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+    if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+        weights = ggml_soft_max(ctx0, weights);
+    } else {
+        weights = ggml_sigmoid(ctx0, weights);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+    }
+
+    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * experts = nullptr;
+    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(cur, "ffn_moe_gate", il);
+
+    cur = ggml_reglu_split(ctx0, cur, up);
+    cb(cur, "ffn_moe_reglu", il);
+
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(experts, "ffn_moe_weighted", il);
+
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
+    // aggregate experts
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    // to avoid potentially a large number of add nodes during warmup
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
+
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    cb(moe_out, "ffn_moe_out", il);
+
+    return moe_out;
+}
+
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
diff --git a/src/llama-graph.h b/src/llama-graph.h
index a28a8c4bd..d4d565754 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -625,6 +625,18 @@ struct llm_graph_context {
            llama_expert_gating_func_type gating_op,
                      int   il) const;
 
+    ggml_tensor * build_moe_ffn_from_probs(
+            ggml_tensor * cur,
+            ggml_tensor * probs,
+            ggml_tensor * up_exps,
+            ggml_tensor * gate_exps,
+            ggml_tensor * down_exps,
+            ggml_tensor * exp_probs_b,
+            int64_t       n_expert,
+            int64_t       n_expert_used,
+            llama_expert_gating_func_type gating_op,
+                      int   il) const;
+
     //
     // inputs
     //
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index c6c67d26f..7a06368dc 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -2,9 +2,15 @@
 
 #include "ggml.h"
 
-void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
+    if (dense_first) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
+        }
+    } else {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+        }
     }
 }
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index ec7fd6a42..8b7e2a113 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -140,7 +140,7 @@ struct llama_hparams {
     // for Classifiers
     uint32_t n_cls_out = 1;
 
-    // llama4
+    // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
@@ -161,9 +161,10 @@ struct llama_hparams {
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // dense_first means whether the pattern starts with a dense layer
     // note that if n_pattern == 0, all layers are SWA
     //           if n_pattern == 1, all layers are dense
-    // example: n_pattern = 3
+    // example 1: n_pattern = 3, dense_first = false
     //   il == 0: swa
     //   il == 1: swa
     //   il == 2: dense
@@ -172,7 +173,13 @@ struct llama_hparams {
     //   il == 5: dense
     //   il == 6: swa
     //   etc ...
-    void set_swa_pattern(uint32_t n_pattern);
+    // example 2: n_pattern = 2, dense_first = true
+    //   il == 0: dense
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern, bool dense_first = false);
 
     // return true if one of the layers is SWA
     bool is_swa_any() const;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 71f89e190..e3aa9e6f9 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1768,6 +1768,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SMALLTHINKER:
+            {
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type             = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.n_swa                = 4096;
+                    hparams.set_swa_pattern(4, true);
+                } else {
+                    hparams.swa_type             = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer;
+                }
+
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_4B;  break;
+                    case 52: type = LLM_TYPE_20B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -5165,6 +5188,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_SMALLTHINKER:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
+                        GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER");
+                        GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER");
+
+                        // MoE branch
+                        const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), { n_embd, n_expert }, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -5490,6 +5549,11 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
     }
 
+    if (arch == LLM_ARCH_SMALLTHINKER) {
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+    }
+
     vocab.print_info();
 }
 
@@ -17011,6 +17075,119 @@ struct llm_build_lfm2 : public llm_graph_context {
     }
 };
 
+template <bool iswa>
+struct llm_build_smallthinker : public llm_graph_context {
+    llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+            ggml_tensor * probs = nullptr;
+
+            probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens]
+            cb(probs, "ffn_moe_logits", il);
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
+                    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                         ext_factor, attn_factor, beta_fast, beta_slow);
+
+                    Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                         ext_factor, attn_factor, beta_fast, beta_slow);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                probs = ggml_get_rows(ctx0, probs, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
+                                        model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                                        nullptr, n_expert, n_expert_used,
+                                        static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+
+            cb(ffn_out, "ffn_out", il);
+            cur = ffn_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -17449,6 +17626,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_lfm2>(*this, params);
             } break;
+        case LLM_ARCH_SMALLTHINKER:
+            {
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_smallthinker<true>> (*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
+                }
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -17647,6 +17832,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_SMALLTHINKER:
            return LLAMA_ROPE_TYPE_NEOX;
 
        case LLM_ARCH_QWEN2VL:
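
For reference, the layer layout produced by the new set_swa_pattern(n_pattern, dense_first) can be sanity-checked with a small standalone sketch. The swa_pattern helper below is hypothetical (it only restates the two loops from llama-hparams.cpp outside of llama.cpp); with SWA enabled, SmallThinker calls set_swa_pattern(4, true), i.e. layer 0 and every 4th layer after it are dense and the rest use the sliding window:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical standalone restatement of llama_hparams::set_swa_pattern:
    // swa_layers[il] == true means layer il uses sliding-window attention.
    static std::vector<bool> swa_pattern(uint32_t n_layer, uint32_t n_pattern, bool dense_first) {
        std::vector<bool> swa_layers(n_layer);
        for (uint32_t il = 0; il < n_layer; ++il) {
            if (dense_first) {
                // layers 0, n_pattern, 2*n_pattern, ... are dense, the rest are SWA
                swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
            } else {
                // the last layer of every group of n_pattern layers is dense
                swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
            }
        }
        return swa_layers;
    }

    int main() {
        // SmallThinker with SWA enabled: set_swa_pattern(4, true)
        const std::vector<bool> layers = swa_pattern(8, 4, true);
        for (size_t il = 0; il < layers.size(); ++il) {
            std::printf("il == %zu: %s\n", il, layers[il] ? "swa" : "dense");
        }
        // prints: dense, swa, swa, swa, dense, swa, swa, swa
        return 0;
    }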
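
Likewise, the two gating modes accepted by build_moe_ffn_from_probs can be illustrated with plain scalar code. The gate_weights helper below is hypothetical and not ggml: with LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX the selected router outputs are softmaxed against each other, while the sigmoid path squashes each output independently and then renormalizes by the row sum, mirroring the ggml_sigmoid / ggml_sum_rows / ggml_div sequence above:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Hypothetical scalar version of the weight post-processing in
    // build_moe_ffn_from_probs, applied to the top-k selected router
    // outputs of a single token.
    static std::vector<float> gate_weights(std::vector<float> topk, bool softmax) {
        float sum = 0.0f;
        if (softmax) {
            // softmax over the selected experts only
            const float mx = *std::max_element(topk.begin(), topk.end());
            for (float & w : topk) { w = std::exp(w - mx); sum += w; }
        } else {
            // sigmoid each weight, then normalize so the weights sum to 1
            for (float & w : topk) { w = 1.0f / (1.0f + std::exp(-w)); sum += w; }
        }
        for (float & w : topk) { w /= sum; }
        return topk;
    }

    int main() {
        for (bool softmax : { true, false }) {
            const std::vector<float> w = gate_weights({ 2.0f, 0.5f }, softmax);
            std::printf("%s: %.3f %.3f\n", softmax ? "softmax" : "sigmoid", w[0], w[1]);
        }
        return 0;
    }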