Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-27 03:55:20 +00:00)
mtmd : support SmolVLM (version 1 and 2) (#13050)
* mtmd : support SmolVLM (version 1 and 2)
* correct chat template
* fix n_patches
* scale_factor is an int
* add more models to test
@@ -419,8 +419,12 @@ class ModelBase:
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
+            architectures = hparams.get("architectures")
             if "text_config" in hparams:
                 hparams = {**hparams, **hparams["text_config"]}
+            if architectures is not None:
+                # preserve "architectures" from root level config
+                hparams["architectures"] = architectures
             return hparams
 
     @classmethod
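For context on why the root-level "architectures" value is preserved: SmolVLM-style configs nest the language-model settings under "text_config", and merging that dict into the root could otherwise drop or overwrite the value the converter keys on. A minimal sketch of the behaviour, using a made-up config:

hparams = {
    "architectures": ["SmolVLMForConditionalGeneration"],
    "text_config": {"hidden_size": 2048, "num_hidden_layers": 24},
}

architectures = hparams.get("architectures")
if "text_config" in hparams:
    hparams = {**hparams, **hparams["text_config"]}
if architectures is not None:
    hparams["architectures"] = architectures  # still ["SmolVLMForConditionalGeneration"]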
@@ -1061,6 +1065,8 @@ class TextModel(ModelBase):
 class VisionModel(ModelBase):
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
     n_text_embd = 0
+    preprocessor_config: dict[str, Any]
+    global_config: dict[str, Any]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1075,24 +1081,33 @@ class VisionModel(ModelBase):
 
         if "vision_config" not in self.hparams:
             raise ValueError("vision_config not found in hparams")
-        # move vision config to the top level
+        # move vision config to the top level, while preserving the original hparams in global_config
+        self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        # load preprocessor config
+        with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+            self.preprocessor_config = json.load(f)
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, self.n_embd_text)
-        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
+        self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
+        self.gguf_writer.add_vision_has_vision_encoder(True)
 
         # vision config
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, self.find_hparam(["image_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, self.find_hparam(["patch_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH, self.find_hparam(["hidden_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH, self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT, self.find_hparam(["num_hidden_layers"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
+        self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
+        self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
+        self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
+        self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
 
+        # preprocessor config
+        self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
+        self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_mean"])
+
     def write_vocab(self):
         raise ValueError("VisionModel does not support vocab writing")
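For reference, the preprocessor values written above come straight from the model's preprocessor_config.json. A minimal sketch of what the converter now reads (field names follow the HF SmolVLM/Idefics3 processors; the numbers are illustrative only):

import json

# hypothetical preprocessor_config.json contents
preprocessor_config = json.loads("""
{
  "image_mean": [0.5, 0.5, 0.5],
  "image_std":  [0.5, 0.5, 0.5],
  "size": {"longest_edge": 2048}
}
""")

image_mean = preprocessor_config["image_mean"]  # written as clip.vision.image_mean
image_std  = preprocessor_config["image_std"]   # written as clip.vision.image_std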
@@ -1703,11 +1718,23 @@ class StableLMModel(TextModel):
         raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@ModelBase.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@ModelBase.register(
+    "LLaMAForCausalLM",
+    "LlamaForCausalLM",
+    "MistralForCausalLM",
+    "MixtralForCausalLM",
+    "Idefics3ForConditionalGeneration",
+    "SmolVLMForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1770,6 +1797,12 @@ class LlamaModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [] # skip vision tensors
+        elif name.startswith("model.text_model"):
+            name = name.replace("text_model.", "") # for SmolVLM
 
         if self.undo_permute:
             if name.endswith(("q_proj.weight", "q_proj.bias")):
@@ -1852,6 +1885,41 @@ class LlamaModel(TextModel):
         raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
+class SmolVLMModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing some keys in config.json
+        # default values are taken from transformers code
+        if self.hparams["model_type"] == "smolvlm_vision":
+            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
+            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims # unused
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("Llama4ForConditionalGeneration")
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
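To make the split of responsibilities concrete: SmolVLMModel keeps only the vision and connector tensors, while LlamaModel (above) drops them and strips the model.text_model. prefix from the language-model weights. A rough illustration of how a few checkpoint names would be routed; the tensor names here are illustrative, modelled on SmolVLM-style checkpoints:

names = [
    "model.vision_model.encoder.layers.0.mlp.fc1.weight",  # vision -> kept by SmolVLMModel
    "model.connector.modality_projection.proj.weight",     # connector -> kept by SmolVLMModel
    "model.text_model.layers.0.self_attn.q_proj.weight",   # text -> handled by LlamaModel
]

for name in names:
    is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
    if is_vision_tensor:
        print("vision/connector:", name)
    elif name.startswith("model.text_model"):
        print("text:", name.replace("text_model.", ""))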
@@ -3591,12 +3659,10 @@ class Gemma3VisionModel(VisionModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
         # default values below are taken from HF tranformers code
-        self.gguf_writer.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, hparams.get("layer_norm_eps", 1e-6))
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_STD, [0.5, 0.5, 0.5])
-        self.gguf_writer.add_bool (gguf.Keys.ClipVision.USE_GELU, True)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims # unused
@@ -3614,10 +3680,6 @@ class Gemma3VisionModel(VisionModel):
                 or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
             # process vision tensors
             name = name.replace("_weight", ".weight")
-            if "fc1" in name:
-                name = name.replace("fc1", "fc2")
-            else:
-                name = name.replace("fc2", "fc1")
 
             # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
             # the other norm values are part of SigLIP model, and they are already correct
@@ -33,13 +33,13 @@
 #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
 #define KEY_PROJ_DIM "clip.%s.projection_dim"
 #define KEY_TOKENS "tokenizer.ggml.tokens"
-#define KEY_N_POSITIONS "clip.text.context_length"
 #define KEY_IMAGE_SIZE "clip.vision.image_size"
 #define KEY_PATCH_SIZE "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN "clip.vision.image_mean"
 #define KEY_IMAGE_STD "clip.vision.image_std"
-#define KEY_PROJ_TYPE "clip.projector_type"
 #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
+#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
+#define KEY_PROJ_TYPE "clip.projector_type"
 
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -72,6 +72,7 @@
 #define TN_IMAGE_NEWLINE "model.image_newline"
 #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
+#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
 
 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -99,6 +100,7 @@ enum projector_type {
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
     PROJECTOR_TYPE_GEMMA3,
+    PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -110,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GLM_EDGE, "adapter"},
    { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3, "gemma3"},
+    { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -159,6 +159,7 @@ struct clip_hparams {
     int32_t projection_dim;
     int32_t n_head;
     int32_t n_layer;
+    int32_t proj_scale_factor = 0; // idefics3
 
     patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 
@@ -506,6 +507,35 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
         embeddings = ggml_mul_mat(ctx0,
             ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
             embeddings);
+
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+
+        ggml_tensor * cur = embeddings;
+        const int scale_factor = model.hparams.proj_scale_factor;
+        const int n_embd = cur->ne[0];
+        const int seq    = cur->ne[1];
+        const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+        const int height = std::sqrt(seq);
+        const int width  = std::sqrt(seq);
+        GGML_ASSERT(scale_factor != 0);
+        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            height / scale_factor,
+            width / scale_factor,
+            bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            seq / (scale_factor * scale_factor),
+            bsz);
+
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        embeddings = cur;
+    } else {
+        GGML_ABORT("SigLIP: Unsupported projector type");
     }
 
     // build the graph
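The reshape/permute sequence above is the Idefics3 "pixel shuffle": a space-to-depth step that trades spatial resolution for channel width before the single projection matmul. A rough NumPy equivalent of what the ggml ops compute, following the transformers implementation linked in the comment (shapes only; an illustrative sketch, not the actual code path):

import numpy as np

def pixel_shuffle(x: np.ndarray, scale_factor: int = 2) -> np.ndarray:
    # x: (bsz, seq, embed_dim) patch embeddings; seq is assumed to be a perfect square
    bsz, seq, embed_dim = x.shape
    height = width = int(seq ** 0.5)
    x = x.reshape(bsz, height, width, embed_dim)
    # fold scale_factor neighbouring columns into the channel dimension
    x = x.reshape(bsz, height, width // scale_factor, embed_dim * scale_factor)
    x = x.transpose(0, 2, 1, 3)
    # fold scale_factor neighbouring rows as well
    x = x.reshape(bsz, width // scale_factor, height // scale_factor, embed_dim * scale_factor ** 2)
    x = x.transpose(0, 2, 1, 3)
    # result: (bsz, seq / scale_factor^2, embed_dim * scale_factor^2); the projection matmul follows
    return x.reshape(bsz, seq // scale_factor ** 2, embed_dim * scale_factor ** 2)

# e.g. (1, 1024, 1152) patches with scale_factor=2 -> (1, 256, 4608) before the projection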
@ -1081,12 +1111,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
}
|
}
|
||||||
|
|
||||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
ggml_cgraph * res;
|
||||||
return clip_image_build_graph_siglip(ctx, imgs);
|
switch (ctx->proj_type) {
|
||||||
} else {
|
case PROJECTOR_TYPE_GEMMA3:
|
||||||
|
case PROJECTOR_TYPE_IDEFICS3:
|
||||||
|
{
|
||||||
|
res = clip_image_build_graph_siglip(ctx, imgs);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
// TODO: we should have one build_* function per model
|
// TODO: we should have one build_* function per model
|
||||||
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
res = clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct clip_model_loader {
|
struct clip_model_loader {
|
||||||
@ -1147,6 +1185,8 @@ struct clip_model_loader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void load_hparams() {
|
void load_hparams() {
|
||||||
|
auto & hparams = ctx_clip.vision_model.hparams;
|
||||||
|
|
||||||
// projector type
|
// projector type
|
||||||
{
|
{
|
||||||
std::string proj_type;
|
std::string proj_type;
|
||||||
@ -1177,7 +1217,6 @@ struct clip_model_loader {
|
|||||||
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
||||||
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
||||||
|
|
||||||
auto & hparams = ctx_clip.vision_model.hparams;
|
|
||||||
get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
|
get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
|
||||||
get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head);
|
get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head);
|
||||||
get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate);
|
get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate);
|
||||||
@ -1233,6 +1272,16 @@ struct clip_model_loader {
|
|||||||
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||||
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// model-specific params
|
||||||
|
switch (ctx_clip.proj_type) {
|
||||||
|
case PROJECTOR_TYPE_IDEFICS3:
|
||||||
|
{
|
||||||
|
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void load_tensors() {
|
void load_tensors() {
|
||||||
@ -1422,6 +1471,10 @@ struct clip_model_loader {
|
|||||||
vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
||||||
vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
|
vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
|
||||||
} break;
|
} break;
|
||||||
|
case PROJECTOR_TYPE_IDEFICS3:
|
||||||
|
{
|
||||||
|
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown projector type");
|
GGML_ASSERT(false && "unknown projector type");
|
||||||
}
|
}
|
||||||
@ -2195,10 +2248,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
if (ctx->has_glm_projector
|
||||||
|
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|
||||||
|
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||||
clip_image_u8 resized_image;
|
clip_image_u8 resized_image;
|
||||||
int sz = params.image_size;
|
int sz = params.image_size;
|
||||||
image_manipulation::bicubic_resize(*img, resized_image, sz, sz);
|
image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
|
||||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||||
//clip_image_save_to_bmp(resized_image, "resized.bmp");
|
//clip_image_save_to_bmp(resized_image, "resized.bmp");
|
||||||
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
|
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
|
||||||
@@ -2330,6 +2385,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     }
 
     return n_patches;
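A quick sanity check of what the new branch does to the token count, with made-up numbers not tied to any particular checkpoint:

image_size, patch_size = 512, 16
n_patches = (image_size // patch_size) ** 2  # 1024 patches from the SigLIP encoder
proj_scale_factor = 4                        # hypothetical value read from clip.vision.projector.scale_factor
n_patches //= proj_scale_factor              # 256 embeddings handed to the language model, per the branch above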
@@ -2597,6 +2654,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         // do nothing
     }
+    else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        // do nothing
+    }
     else {
         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
@@ -2783,37 +2843,34 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 }
 
 int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
-    if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+    switch (ctx->proj_type) {
+        case PROJECTOR_TYPE_LDP:
             return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
+        case PROJECTOR_TYPE_LDPV2:
             return ctx->vision_model.mm_model_peg_0_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+        case PROJECTOR_TYPE_MLP:
             return ctx->vision_model.mm_2_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+        case PROJECTOR_TYPE_MLP_NORM:
            return ctx->vision_model.mm_3_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+        case PROJECTOR_TYPE_RESAMPLER:
             if (ctx->minicpmv_version == 2) {
                 return 4096;
-            }
-            else if (ctx->minicpmv_version == 3) {
+            } else if (ctx->minicpmv_version == 3) {
+                return 3584;
+            } else if (ctx->minicpmv_version == 4) {
                 return 3584;
             }
-            else if (ctx->minicpmv_version == 4) {
-                return 3584;
-            }
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
+            break; // Should not happen if version is valid
+        case PROJECTOR_TYPE_GLM_EDGE:
             return ctx->vision_model.mm_model_mlp_3_w->ne[1];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+        case PROJECTOR_TYPE_MERGER:
            return ctx->vision_model.mm_1_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        case PROJECTOR_TYPE_GEMMA3:
            return ctx->vision_model.mm_input_proj_w->ne[0];
+        case PROJECTOR_TYPE_IDEFICS3:
+            return ctx->vision_model.projection->ne[1];
+        default:
+            break; // Fall through to throw
     }
 
     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
@@ -176,6 +176,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     std::string prompt_modified(text.text);
     std::string marker_modified(ctx->image_marker);
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+
     // a bit hacky here, but works for now
     // for some models, we need to add prefix and suffix to the image embeddings
     if (clip_is_gemma3(ctx->ctx_clip)) {
@@ -183,6 +185,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         // <start_of_image> ... (image embeddings) ... <end_of_image>
         marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
+        marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
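In other words, for an Idefics3/SmolVLM projector the user-visible image marker is wrapped in the tokens the HF processor would emit around the single global image. A small Python sketch of the substitution; the default marker string is an assumption about mtmd's defaults:

image_marker = "<__image__>"  # assumed default mtmd marker
prompt = "What is in this picture? <__image__>"

marker_modified = "<fake_token_around_image><global-img>" + image_marker + "<fake_token_around_image>"
prompt_modified = prompt.replace(image_marker, marker_modified)
# -> "What is in this picture? <fake_token_around_image><global-img><__image__><fake_token_around_image>"
# the remaining <__image__> span is later replaced by the actual image embeddings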
@@ -28,6 +28,9 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
+add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
+add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
 add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
@@ -39,7 +42,13 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 
-# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" # this model has broken chat template, not usable
+# these models always give the wrong answer, not sure why
+# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
+# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
+# add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
+
+# this model has broken chat template, not usable
+# add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
 
 ###############
 
@@ -231,11 +231,15 @@ class Keys:
         IMAGE_MEAN = "clip.vision.image_mean"
         IMAGE_STD = "clip.vision.image_std"
         USE_GELU = "clip.use_gelu"
+        USE_SILU = "clip.use_silu"
 
         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"
             LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
+
+        class Projector:
+            SCALE_FACTOR = "clip.vision.projector.scale_factor"
 
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -2122,6 +2126,11 @@ class GGUFValueType(IntEnum):
         raise ValueError(f"Unknown type: {type(val)}")
 
 
+class VisionProjectorType:
+    GEMMA3 = "gemma3"
+    IDEFICS3 = "idefics3"
+
+
 # Items here are (block size, type size)
 QK_K = 256
 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
@@ -931,6 +931,53 @@ class GGUFWriter:
     def add_eom_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOM_ID, id)
 
+    # for vision models
+
+    def add_vision_projection_dim(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
+
+    def add_vision_has_vision_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
+
+    def add_vision_patch_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
+
+    def add_vision_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
+
+    def add_vision_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
+
+    def add_vision_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
+
+    def add_vision_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
+
+    def add_vision_projector_type(self, value: str) -> None:
+        self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
+
+    def add_vision_attention_layernorm_eps(self, value: float) -> None:
+        self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
+
+    def add_vision_image_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
+
+    def add_vision_image_mean(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
+
+    def add_vision_image_std(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_STD, values)
+
+    def add_vision_use_gelu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_GELU, value)
+
+    def add_vision_use_silu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_SILU, value)
+
+    def add_vision_projector_scale_factor(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
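These helpers are thin wrappers over the typed add_* primitives, so a converter deals with values rather than raw key strings. A hedged usage sketch; the writer construction is a standalone assumption (in practice the converter creates the writer itself), and the file name and values are made up:

import gguf

writer = gguf.GGUFWriter("mmproj-example.gguf", "clip")

writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
writer.add_vision_image_size(512)
writer.add_vision_patch_size(16)
writer.add_vision_projector_scale_factor(2)
writer.add_vision_image_mean([0.5, 0.5, 0.5])
writer.add_vision_image_std([0.5, 0.5, 0.5])

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()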
@@ -961,13 +961,13 @@ class TensorNameMap:
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
+            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
+            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
@@ -62,6 +62,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
+    { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -81,6 +82,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     if (tmpl_contains("<|im_start|>")) {
         return tmpl_contains("<|im_sep|>")
             ? LLM_CHAT_TEMPLATE_PHI_4
+            : tmpl_contains("<end_of_utterance>")
+                ? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
             : LLM_CHAT_TEMPLATE_CHATML;
     } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
         if (tmpl_contains("[SYSTEM_PROMPT]")) {
@@ -622,6 +625,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|header_start|>assistant<|header_end|>\n\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
+        // SmolVLM
+        ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "<end_of_utterance>\n";
+            } else {
+                ss << "Assistant: " << message->content << "<end_of_utterance>\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
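For a concrete feel of the new template, the branch above would render a short exchange roughly like this (whitespace as emitted by the code; the conversation itself is invented):

<|im_start|>You are a helpful assistant.

User: What is in the image?<end_of_utterance>
Assistant: A cat on a sofa.<end_of_utterance>
Assistant: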
@@ -41,6 +41,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_YANDEX,
     LLM_CHAT_TEMPLATE_BAILING,
     LLM_CHAT_TEMPLATE_LLAMA4,
+    LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 