mtmd : Support Pixtral 12B (#13065)

* add pixtral text model (vision is wip)

* cgraph ok, just missing 2D RoPE

* fix bad rebase

* first working version

* fix problem with img_break token

* support dynamic image size

* update docs

* update test script
Author: Xuan-Son Nguyen
Date: 2025-04-23 20:21:59 +02:00 (committed by GitHub)
Parent: eb1776b15a
Commit: ecda2ec4b3
14 changed files with 643 additions and 31 deletions


@@ -776,6 +776,9 @@ class TextModel(ModelBase):
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
# ref: https://huggingface.co/mistral-community/pixtral-12b
res = "pixtral"
if res is None:
logger.warning("\n")
@@ -1724,7 +1727,8 @@ class StableLMModel(TextModel):
"MistralForCausalLM",
"MixtralForCausalLM",
"Idefics3ForConditionalGeneration",
"SmolVLMForConditionalGeneration")
"SmolVLMForConditionalGeneration",
"LlavaForConditionalGeneration")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
@@ -1734,6 +1738,10 @@ class LlamaModel(TextModel):
# fix for SmolVLM2, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
# fix for Pixtral, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
and self.hparams.get("model_type") == "mistral":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def set_vocab(self):
try:
@@ -1797,12 +1805,17 @@ class LlamaModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
is_vision_tensor = "vision_tower" in name \
or "vision_model" in name \
or "model.connector" in name \
or "multi_modal_projector" in name
if is_vision_tensor:
return [] # skip vision tensors
elif name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
elif name.startswith("language_model."):
name = name.replace("language_model.", "") # for the rest
if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
@@ -1885,6 +1898,55 @@ class LlamaModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("LlavaForConditionalGeneration")
class LlavaVisionModel(VisionModel):
img_break_tok_id = -1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "pixtral":
# fix missing config.json values
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
# default values below are taken from HF transformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
n_head = self.hparams["num_attention_heads"]
n_kv_head = n_head
if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
# process vision tensors
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
# for pixtral model, we need to extract the [IMG_BREAK] token embedding
img_break_embd = data_torch[self.img_break_tok_id]
name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
return [(self.map_tensor_name(name), img_break_embd)]
return [] # skip other tensors
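The `img_break_tok_id = 12` above is hard-coded from the checkpoint's `tokenizer_config.json`. A minimal sketch (assuming the `transformers`-compatible checkpoint referenced in this PR) to double-check that id before relying on it:

```python
# Hedged sanity check (not part of the conversion script): confirm that [IMG_BREAK] really
# maps to id 12 in the transformers-compatible Pixtral checkpoint.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")
print(tok.convert_tokens_to_ids("[IMG_BREAK]"))  # expected: 12, matching img_break_tok_id above
```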
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
def __init__(self, *args, **kwargs):


@@ -115,6 +115,7 @@ models = [
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
]


@@ -11,15 +11,15 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
cmake --build build --target llama-mtmd-cli
# alternatively, install from brew (MacOS)
brew install llama.cpp
# run it
llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
# note: 1B model does not support vision
```
@@ -44,8 +44,8 @@ What you need:
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
cmake --build build --target llama-mtmd-cli
# run it
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
```


@@ -14,6 +14,28 @@ The naming and structure related to multimodal support have evolved, which might
- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
## Pre-quantized models
These are ready-to-use models; most of them come with `Q4_K_M` quantization by default:
```sh
# Gemma 3
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
# SmolVLM
llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
# Pixtral 12B
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
```
## How it works and what is `mmproj`?
Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
@@ -45,3 +67,9 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
- [Google Gemma 3](../../docs/multimodal/gemma3.md)
For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with a `transformers`-compatible checkpoint


@@ -60,6 +60,7 @@
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
@@ -73,6 +74,7 @@
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
// minicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -101,6 +103,7 @@ enum projector_type {
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -113,6 +116,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {


@@ -163,7 +163,8 @@ struct clip_hparams {
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
float eps;
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
@@ -187,11 +188,17 @@ struct clip_layer {
struct ggml_tensor * ln_1_b = nullptr;
// ff
struct ggml_tensor * ff_i_w = nullptr;
struct ggml_tensor * ff_i_b = nullptr;
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr;
struct ggml_tensor * ff_o_b = nullptr;
struct ggml_tensor * ff_up_w = nullptr;
struct ggml_tensor * ff_up_b = nullptr;
struct ggml_tensor * ff_gate_w = nullptr;
struct ggml_tensor * ff_gate_b = nullptr;
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
@@ -297,6 +304,9 @@ struct clip_vision_model {
// gemma3
struct ggml_tensor * mm_input_proj_w = nullptr;
struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
};
struct clip_ctx {
@@ -329,6 +339,7 @@ struct clip_ctx {
ggml_backend_t backend_cpu;
ggml_backend_buffer_ptr buf;
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_image_size load_image_size;
@@ -544,6 +555,218 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
return gf;
}
// implementation of the 2D RoPE without adding a new op in ggml
static ggml_tensor * build_rope_2d(
ggml_cgraph * gf,
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * pos_h,
ggml_tensor * pos_w,
const float freq_base
) {
ggml_tensor * tmp;
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];
// for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
// we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
// first half of cur will use 1e-0, 1e-2 (even)
// second half of cur will use 1e-1, 1e-3 (odd)
//
// for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
// first half
{
cur = ggml_rope_ext_inplace(
ctx0,
cur,
pos_h, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);
}
// second half
{
tmp = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
n_dim/2 * ggml_element_size(cur));
tmp = ggml_rope_ext_inplace(
ctx0,
tmp,
pos_w, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
freq_scale_odd,
0.0f, 1.0f, 0.0f, 0.0f
);
// calculate inplace (modify cur directly)
ggml_build_forward_expand(gf, tmp);
}
return cur;
}
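The comment block in `build_rope_2d` compresses a bit of algebra. Below is a small self-check of the identity it relies on, under the assumed values `n_dim = 8` and `freq_base = 10000` used in the comment's example; it only illustrates the math, it is not ggml code:

```python
# Self-check of the frequency split used above. Standard RoPE over n_dim dims uses
# inv_freq[i] = base**(-2*i/n_dim) for i = 0 .. n_dim/2 - 1.
import math

n_dim, base = 8, 10000.0                                                  # assumed example values
full = [base ** (-2.0 * i / n_dim) for i in range(n_dim // 2)]            # 1e0, 1e-1, 1e-2, 1e-3
half = [base ** (-2.0 * i / (n_dim // 2)) for i in range(n_dim // 4)]     # RoPE run with n_dims = n_dim/2
scale_odd = base ** (-2.0 / n_dim)                                        # freq_scale_odd in the code

assert all(math.isclose(h, e) for h, e in zip(half, full[0::2]))              # first half -> even inv_freq
assert all(math.isclose(h * scale_odd, o) for h, o in zip(half, full[1::2]))  # scaled second half -> odd inv_freq
```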
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
int image_size_width = imgs.entries[0]->nx;
int image_size_height = imgs.entries[0]->ny;
const int patch_size = hparams.patch_size;
const int n_patches_x = image_size_width / patch_size;
const int n_patches_y = image_size_height / patch_size;
const int num_patches = n_patches_x * n_patches_y;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx0_ptr(ggml_init(params));
auto ctx0 = ctx0_ptr.get();
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
// 2D input positions
struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
struct ggml_tensor * embeddings = inp;
// pre-layer norm
embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings;
// pre-attention norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
// self-attention
{
struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
}
// re-add the layer input, i.e., the residual connection
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// pre-ffn norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
// feed-forward
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// LlavaMultiModalProjector (with GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
// arrangement of the [IMG_BREAK] token
{
// not efficient, but works
// the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1);
embeddings = ggml_view_2d(ctx0, cur,
n_embd_text, n_tokens_output,
ggml_row_size(cur->type, n_embd_text), 0);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
return gf;
}
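To make the `[IMG_BREAK]` arrangement concrete, here is a purely illustrative Python sketch of the token layout produced by the reshape/concat/view sequence above, for an assumed 3x2 patch grid:

```python
# Illustrative layout only (assumed 3x2 patch grid): one [IMG_BREAK] embedding is appended
# after every row of patch embeddings, and the trailing break after the last row is dropped
# by the final view, leaving num_patches + n_patches_y - 1 output tokens.
n_patches_x, n_patches_y = 3, 2
patches = [f"p{r}{c}" for r in range(n_patches_y) for c in range(n_patches_x)]  # row-major patch embeddings

rows = [patches[r * n_patches_x:(r + 1) * n_patches_x] + ["[IMG_BREAK]"] for r in range(n_patches_y)]
flat = [tok for row in rows for tok in row]                                     # concat along the row dimension

n_tokens_output = n_patches_x * n_patches_y + n_patches_y - 1
print(flat[:n_tokens_output])  # ['p00', 'p01', 'p02', '[IMG_BREAK]', 'p10', 'p11', 'p12']
```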
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
@@ -1118,6 +1341,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = clip_image_build_graph_siglip(ctx, imgs);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
res = clip_image_build_graph_pixtral(ctx, imgs);
} break;
default:
{
// TODO: we should have one build_* function per model
@@ -1279,6 +1506,10 @@ struct clip_model_loader {
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
} break;
default:
break;
}
@@ -1350,16 +1581,26 @@ struct clip_model_loader {
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
// new naming
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
// legacy naming (the in and out are reversed! don't ask me why)
layer.ff_i_w = layer.ff_down_w;
layer.ff_o_w = layer.ff_up_w;
layer.ff_i_b = layer.ff_down_b;
layer.ff_o_b = layer.ff_up_b;
}
switch (ctx_clip.proj_type) {
@@ -1475,6 +1716,15 @@ struct clip_model_loader {
{
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@@ -1517,18 +1767,17 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
clip_image_size image_size;
image_size.width = clip_get_image_size(&ctx_clip);
image_size.height = clip_get_image_size(&ctx_clip);
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
img->nx = n_patches;
img->ny = n_patches;
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
image_size.width = ctx_clip.vision_model.hparams.image_size;
image_size.height = ctx_clip.vision_model.hparams.image_size;
img->nx = image_size.width;
img->ny = image_size.height;
img->buf.resize(image_size.width * image_size.height * 3);
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
@@ -1916,6 +2165,26 @@ struct image_manipulation {
}
}
// calculate the size of the **resized** image while preserving the aspect ratio
// the calculated size will be rounded up to the nearest multiple of align_size
// if H or W is larger than max_dimension, the image is scaled down so that the longer side is at most max_dimension
static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
return {0, 0};
}
float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
static_cast<float>(max_dimension) / inp_size.height));
float target_width_f = static_cast<float>(inp_size.width) * scale;
float target_height_f = static_cast<float>(inp_size.height) * scale;
int aligned_width = GGML_PAD((int)target_width_f, align_size);
int aligned_height = GGML_PAD((int)target_height_f, align_size);
return {aligned_width, aligned_height};
}
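For reference, a rough Python mirror of `calc_size_preserved_ratio` with two worked examples; the default arguments below are assumed Pixtral values (patch size 16, image size 1024), not something stated in this diff:

```python
# Rough Python mirror of calc_size_preserved_ratio (assumed defaults: align_size = patch_size = 16,
# max_dimension = image_size = 1024).
def calc_size_preserved_ratio(width: int, height: int, align_size: int = 16, max_dimension: int = 1024):
    scale = min(1.0, max_dimension / width, max_dimension / height)  # only ever scale down
    target_w, target_h = width * scale, height * scale

    def pad(x: float) -> int:
        # GGML_PAD: round up to the nearest multiple of align_size
        return ((int(x) + align_size - 1) // align_size) * align_size

    return pad(target_w), pad(target_h)

print(calc_size_preserved_ratio(2000, 1500))  # (1024, 768): longer side capped at 1024, then 16-aligned
print(calc_size_preserved_ratio(640, 480))    # (640, 480): already aligned, no scaling needed
```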
private:
static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper));
@@ -2247,8 +2516,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
if (ctx->has_glm_projector
else if (ctx->has_glm_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
@@ -2260,6 +2528,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
@@ -2387,6 +2664,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
n_patches = 256;
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size;
int n_patches_y = img->ny / params.patch_size;
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
}
return n_patches;
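Putting the preprocessing and the formula above together, a short worked example (assuming `patch_size = 16`) of how many embedding tokens a single Pixtral image ends up using:

```python
# Worked example (assumed patch_size = 16): token count for one Pixtral image, using the
# resized size from the preprocessing step and the clip_n_patches_by_img formula above.
patch_size = 16
nx, ny = 1024, 768                                             # e.g. calc_size_preserved_ratio(2000, 1500, ...)
n_patches_x, n_patches_y = nx // patch_size, ny // patch_size  # 64 x 48 patches
n_patches = n_patches_y * n_patches_x + n_patches_y - 1        # one [IMG_BREAK] per row, except the last
print(n_patches_x, n_patches_y, n_patches)                     # 64 48 3119
```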
@@ -2540,10 +2821,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
float * data = (float *)malloc(ggml_nbytes(inp_raw));
// TODO @ngxson : this whole code block is ugly, will need to be refactored
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
if (ctx->has_glm_projector
|| ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
@@ -2657,6 +2943,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
// do nothing
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
std::vector<int> pos_data(num_positions);
struct ggml_tensor * pos;
// dimension H
pos = ggml_graph_get_tensor(gf, "pos_h");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i / n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
// dimension W
pos = ggml_graph_get_tensor(gf, "pos_w");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i % n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@@ -2849,6 +3153,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_LDPV2:
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];


@@ -190,6 +190,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
}
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
@@ -219,7 +224,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
for (auto & entry : batch_f32.entries) {
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_patches(ctx->ctx_clip);
image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
image_tokens->ny = 1;
image_tokens->batch_f32.entries.push_back(std::move(entry));
image_tokens->id = id;
@@ -313,8 +318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
}
} else {
size_t n_tokens = 0;
for (const auto & entry : batch_f32.entries) {
n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
}
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
image_tokens->nx = n_tokens;
image_tokens->ny = 1; // TODO
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img].id; // optional
@@ -382,7 +392,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) {
int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
ok = clip_image_encode(
ctx->ctx_clip,
ctx->n_threads,


@@ -13,6 +13,14 @@ mkdir -p $SCRIPT_DIR/output
PROJ_ROOT="$SCRIPT_DIR/../.."
cd $PROJ_ROOT
# Check if the first argument is "big", then also run the tests with big models
# This is useful when running the script on a larger machine, so we can test the big models
RUN_BIG_TESTS=false
if [ "${1:-}" = "big" ]; then
RUN_BIG_TESTS=true
echo "Include BIG models..."
fi
###############
arr_bin=()
@@ -28,6 +36,12 @@ add_test() {
arr_tmpl+=("$tmpl")
}
add_test_big() {
if [ "$RUN_BIG_TESTS" = true ]; then
add_test "$@"
fi
}
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@@ -42,6 +56,9 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
# these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"


@@ -485,6 +485,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_OUTPUT = auto()
V_ENC_OUTPUT_NORM = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
@@ -501,6 +502,7 @@
V_RESMPL_Q_NORM = auto() # minicpmv
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -737,6 +739,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
@@ -753,6 +756,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -771,6 +775,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_OUTPUT,
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
@@ -787,6 +792,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_Q_NORM,
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -2129,6 +2135,7 @@ class GGUFValueType(IntEnum):
class VisionProjectorType:
GEMMA3 = "gemma3"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
# Items here are (block size, type size)


@@ -914,6 +914,7 @@ class TensorNameMap:
"vision_tower.vision_model.embeddings.patch_embedding",
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
"vision_tower.patch_conv", # pixtral
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -926,52 +927,65 @@ class TensorNameMap:
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
),
MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
),
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
),
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral
),
MODEL_TENSOR.V_POST_NORM: (
@@ -1030,6 +1044,10 @@ class TensorNameMap:
MODEL_TENSOR.V_RESMPL_QUERY: (
"resampler.query",
),
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
"v.token_embd.img_break", # for pixtral, this is a generated vector
),
}
# architecture-specific block mappings


@@ -111,6 +111,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
};
enum llama_rope_type {


@@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__


@@ -0,0 +1,46 @@
2014 1032 1052 1032 28504 6972
1070 7088 1258
1032
1256
1293
1009
1010
1267
4688
1009 1010
22177 4304
45383 4304
22177 5325
45383 5325
45383 5325 1033
22177 1044 4304 1033
45383 1044 4304 1033
1593 1395 119685 1166 1153 1046 51228
1119 1048 1052 1056 1032 1055 17391 23216 30203 7785 17279
3337 30757 1902 4200 63073 3671
1225 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1225 1158 1129 1225 1158 1155 1225 1158 1133 1225 21359 1225 1158 1137
1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 1319 11234 1873 26303 1455 1934 2246 3754 10835 1041
22177
45383
1032 45383
1256 45383
1293 45383
1293 45383 1010 1293 45383
1319
1010 1376
1039 4033
22177 1044 1404 48054 1033 3075 1584 1636 119685 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749
7290 7290 7290
1051
1051 1051
1051 1051 1051
1051 1051 1051 1051
1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051 1051 1051
1067 59503 28783
3724 4058
1010 1032 1267 1032 4688 1032 17152 1458 29356 1010 1256 1010 1293 1010 1260 1010 1652 1010 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 119685 1166 1153 1240 1159 1166 1153 1032 1051 1032 1051 1051 1032 1051 1051 1051 1032 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1051 1032 1051 1046 1051 1032 1051 1791 1051 1032 1051 2880 1051 71881 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1240 1159 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749 45577 1045 6626 43555 2843 30757 1902 4200 63073 3671 14931 20040 20040 1657 1657 1975 14135 14135 83923 7290 7290 7290 45509 45509 45509 1362 6483 2151 1576 1116 2189 1514 1681 2156 1044 1576 3609 1636 5257 1063 1576 1077 1605 5257 1362 7534 3180 1494 1044 1576 1068 1636 2479 2269 26883 1063 2837 1039 45654 1261 54297 1076
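The two new test files above are the Pixtral vocab test pair: the raw input strings (separated by `__ggml_vocab_test__`) and the expected token ids, one line per input. As a hedged cross-check, a single line can be reproduced with the Hugging Face tokenizer, assuming the expected output was generated with `add_special_tokens=False` (the usual convention for these `.inp`/`.out` pairs):

```python
# Hedged reproduction of one expected line from the Pixtral vocab test output above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")
ids = tok.encode("Hello world", add_special_tokens=False)
print(" ".join(map(str, ids)))  # expected to match the test line: 22177 4304
```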


@@ -1506,7 +1506,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "llama3" ||
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe"||
tokenizer_pre == "falcon3") {
tokenizer_pre == "falcon3" ||
tokenizer_pre == "pixtral") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
add_bos = true;