mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
update: support Qwen2-57B-A14B (#7835)
* update: convert-hf-to-gguf.py to support Qwen2-57B-A14B * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shared_exp are now properly calculated * update: convert-hf-to-gguf.py cleanup for Qwen2MoeForCausalLM * fix: QWEN2MOE support for expert_feed_forward_length previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shexp are now properly calculated
This commit is contained in:
committed by
GitHub
parent
5b6da18750
commit
a94e6ff877
@ -1632,6 +1632,12 @@ class Qwen2MoeModel(Model):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
|
||||
if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
|
||||
logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
|
Reference in New Issue
Block a user