mtmd : add ultravox audio input (#13623)

* convert ok, load ok

* warmup ok

* test

* still does not work?

* fix padding

* temporary give up

* fix merge conflict

* build_ultravox()

* rm test

* fix merge conflict

* add necessary mtmd APIs

* first working version (only 4s of audio)

* will this monster compile?

* fix compile

* please compile

* fPIC

* fix windows

* various fixes

* clean up audio_helpers

* fix conversion

* add some debug stuff

* long audio input ok

* adapt the api

* add --audio arg

* final touch UX

* add miniaudio to readme

* fix typo

* refactor kv metadata

* mtmd_default_marker()
This commit is contained in:
Xuan-Son Nguyen
2025-05-22 20:42:48 +02:00
committed by GitHub
parent ab86335760
commit 797990c4bc
21 changed files with 95401 additions and 259 deletions

View File

@ -48,3 +48,7 @@ end_of_line = unset
charset = unset
trim_trailing_whitespace = unset
insert_final_newline = unset
[tools/mtmd/miniaudio.h]
trim_trailing_whitespace = unset
insert_final_newline = unset

View File

@ -580,3 +580,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain

View File

@ -39,7 +39,7 @@
using json = nlohmann::ordered_json;
std::initializer_list<enum llama_example> mmproj_examples = {
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_SERVER,
};
@ -2233,12 +2233,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image"}, "FILE",
"path to an image file. use with multimodal models. Specify multiple times for batching",
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_LLAVA}));
).set_examples({LLAMA_EXAMPLE_MTMD}));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
@ -2868,7 +2868,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(

View File

@ -76,7 +76,7 @@ enum llama_example {
LLAMA_EXAMPLE_SERVER,
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
LLAMA_EXAMPLE_EXPORT_LORA,
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_MTMD,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,

View File

@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum):
class ModelType(IntEnum):
TEXT = 1
VISION = 2
MMPROJ = 2
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
class ModelBase:
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
ModelType.TEXT: {},
ModelType.VISION: {},
ModelType.MMPROJ: {},
}
dir_model: Path
@ -88,7 +88,7 @@ class ModelBase:
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is VisionModel:
type(self) is MmprojModel:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
@ -309,6 +309,7 @@ class ModelBase:
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
)
)
or not new_name.endswith(".weight")
@ -438,7 +439,7 @@ class ModelBase:
assert names
def func(modelcls: AnyModel) -> AnyModel:
model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT
for name in names:
cls._model_classes[model_type][name] = modelcls
return modelcls
@ -1114,60 +1115,87 @@ class TextModel(ModelBase):
self.gguf_writer.add_pooling_type(pooling_type)
class VisionModel(ModelBase):
model_type = ModelType.VISION
model_arch = gguf.MODEL_ARCH.CLIP_VISION
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
model_arch = gguf.MODEL_ARCH.MMPROJ
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
if self.has_vision_encoder and self.has_audio_encoder:
raise NotImplementedError("both vision + audio not supported yet")
# get n_embd of the text model
if "text_config" not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert self.n_embd_text > 0, "n_embd not found in hparams"
if "vision_config" not in self.hparams:
raise ValueError("vision_config not found in hparams")
# move vision config to the top level, while preserving the original hparams in global_config
self.global_config = self.hparams
self.hparams = self.hparams["vision_config"]
if "vision_config" in self.hparams:
self.hparams = self.hparams["vision_config"]
elif "audio_config" in self.hparams:
self.hparams = self.hparams["audio_config"]
else:
raise ValueError("vision_config / audio_config not found in hparams")
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
self.gguf_writer.add_vision_has_vision_encoder(True)
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
if self.has_vision_encoder:
self.gguf_writer.add_clip_has_vision_encoder(True)
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
# vision config
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
elif self.has_audio_encoder:
self.gguf_writer.add_clip_has_audio_encoder(True)
self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
# audio config
self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_audio_block_count(self.block_count)
self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
else:
raise ValueError("MmprojModel must have either vision or audio encoder")
def write_vocab(self):
raise ValueError("VisionModel does not support vocab writing")
raise ValueError("MmprojModel does not support vocab writing")
@ModelBase.register("GPTNeoXForCausalLM")
@ -1951,7 +1979,7 @@ class LlamaModel(TextModel):
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
)
class LlavaVisionModel(VisionModel):
class LlavaVisionModel(MmprojModel):
img_break_tok_id = -1
def __init__(self, *args, **kwargs):
@ -1977,7 +2005,7 @@ class LlavaVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
@ -2016,7 +2044,7 @@ class LlavaVisionModel(VisionModel):
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
class SmolVLMModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "smolvlm_vision":
@ -2028,7 +2056,7 @@ class SmolVLMModel(VisionModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
self.gguf_writer.add_vision_use_gelu(True)
@ -2094,10 +2122,10 @@ class Llama4Model(LlamaModel):
@ModelBase.register("Llama4ForConditionalGeneration")
class Llama4VisionModel(VisionModel):
class Llama4VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
assert self.hparams["hidden_act"] == "gelu"
@ -2670,7 +2698,7 @@ class Qwen2VLModel(TextModel):
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLVisionModel(VisionModel):
class Qwen2VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["image_size"] = self.hparams.get("image_size", 560)
@ -2685,9 +2713,9 @@ class Qwen2VLVisionModel(VisionModel):
super().set_gguf_parameters()
hparams = self.hparams
if self.global_config['model_type'] == 'qwen2_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
elif self.global_config['model_type'] == 'qwen2_5_vl':
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
self.gguf_writer.add_vision_use_silu(True)
# find n_wa_pattern (window attention pattern)
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@ -2746,11 +2774,11 @@ class Qwen2VLVisionModel(VisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(VisionModel):
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
# hidden_act
if hparams["hidden_act"] == "silu":
@ -4008,11 +4036,11 @@ class Gemma3Model(TextModel):
@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(VisionModel):
class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3)
# default values below are taken from HF tranformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_vision_use_gelu(True)
@ -5959,6 +5987,52 @@ class ChameleonModel(TextModel):
return data_torch
@ModelBase.register("UltravoxModel")
class UltravoxModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA # dummy
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
@ModelBase.register("UltravoxModel")
class UltravoxAudioModel(MmprojModel):
has_vision_encoder = False # no vision encoder
has_audio_encoder = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
def tensor_force_quant(self, name, new_name, bid, n_dims):
del bid, new_name, n_dims # unused
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F16
return False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# prevent clash naming with vision tensors
if name.startswith("multi_modal_projector"):
name = "audio." + name
if "conv1.bias" in name or "conv2.bias" in name:
# transpose conv1 and conv2 bias
data_torch = data_torch.unsqueeze(-1)
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######
@ -6134,13 +6208,15 @@ def split_str_to_n_bytes(split_str: str) -> int:
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
return arch
@ -6203,7 +6279,7 @@ def main() -> None:
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
hparams = ModelBase.load_hparams(dir_model)
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")

View File

@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
To enable it, can use use one of the 2 methods below:
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
To enable it, you can use one of the 2 methods below:
- Use `-hf` option with a supported model (see a list of pre-quantized model below)
- To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
@ -37,6 +39,8 @@ Replaces the `(tool_name)` with the name of binary you want to use. For example,
NOTE: some models may require large context window, for example: `-c 8192`
**Vision models**:
```sh
# Gemma 3
(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
@ -78,3 +82,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Llama 4 Scout
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
```
**Audio models**:
```sh
# Ultravox 0.5
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
```

View File

@ -219,10 +219,13 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class ClipVision:
class Clip:
PROJECTOR_TYPE = "clip.projector_type"
HAS_VISION_ENCODER = "clip.has_vision_encoder"
HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@ -243,19 +246,33 @@ class Keys:
class Projector:
SCALE_FACTOR = "clip.vision.projector.scale_factor"
class ClipAudio:
NUM_MEL_BINS = "clip.audio.num_mel_bins"
EMBEDDING_LENGTH = "clip.audio.embedding_length"
FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
PROJECTION_DIM = "clip.audio.projection_dim"
BLOCK_COUNT = "clip.audio.block_count"
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
class Projector:
STACK_FACTOR = "clip.audio.projector.stack_factor"
#
# recommended mapping of model tensor names for storage in gguf
#
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
CLIP_VISION = "clip-vision"
MODEL = "model"
ADAPTER = "adapter"
MMPROJ = "mmproj" # dummy, unused for now
class MODEL_ARCH(IntEnum):
CLIP_VISION = auto() # dummy arch for clip.cpp
MMPROJ = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
@ -514,10 +531,27 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
A_ENC_ATTN_Q = auto()
A_ENC_ATTN_K = auto()
A_ENC_ATTN_V = auto()
A_ENC_INPUT_NORM = auto()
A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto()
A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto()
A_MMPROJ = auto()
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
@ -776,10 +810,27 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.CLIP_VISION: [
MODEL_ARCH.MMPROJ: [
MODEL_TENSOR.V_MMPROJ,
MODEL_TENSOR.V_MMPROJ_FC,
MODEL_TENSOR.V_MMPROJ_MLP,
@ -819,6 +870,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
MODEL_TENSOR.A_ENC_ATTN_Q,
MODEL_TENSOR.A_ENC_ATTN_K,
MODEL_TENSOR.A_ENC_ATTN_V,
MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@ -2186,6 +2254,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"

View File

@ -936,12 +936,18 @@ class GGUFWriter:
# for vision models
def add_clip_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
def add_clip_has_audio_encoder(self, value: bool) -> None:
self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
def add_clip_projector_type(self, value: str) -> None:
self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
def add_vision_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
def add_vision_has_vision_encoder(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
def add_vision_patch_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
@ -957,9 +963,6 @@ class GGUFWriter:
def add_vision_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
def add_vision_projector_type(self, value: str) -> None:
self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
def add_vision_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
@ -987,6 +990,32 @@ class GGUFWriter:
def add_vision_n_wa_pattern(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
# audio models
def add_audio_projection_dim(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
def add_audio_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
def add_audio_feed_forward_length(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
def add_audio_block_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
def add_audio_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
def add_audio_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
def add_audio_num_mel_bins(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
def add_audio_stack_factor(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:

View File

@ -1110,6 +1110,68 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
),
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: (
"audio_tower.embed_positions", # ultravox
),
MODEL_TENSOR.A_ENC_CONV1D: (
"audio_tower.conv{bid}", # ultravox
),
MODEL_TENSOR.A_PRE_NORM: (),
MODEL_TENSOR.A_POST_NORM: (
"audio_tower.layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_Q: (
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_K: (
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
),
MODEL_TENSOR.A_ENC_ATTN_V: (
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
),
MODEL_TENSOR.A_ENC_INPUT_NORM: (
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT: (
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
),
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
"audio_tower.layers.{bid}.final_layer_norm", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_UP: (
"audio_tower.layers.{bid}.fc1", # ultravox
),
MODEL_TENSOR.A_ENC_FFN_GATE: (),
MODEL_TENSOR.A_ENC_FFN_DOWN: (
"audio_tower.layers.{bid}.fc2", # ultravox
),
MODEL_TENSOR.A_MMPROJ: (
"audio.multi_modal_projector.linear_{bid}", # ultravox
),
MODEL_TENSOR.A_MM_NORM_PRE: (
"audio.multi_modal_projector.ln_pre", # ultravox
),
MODEL_TENSOR.A_MM_NORM_MID: (
"audio.multi_modal_projector.ln_mid", # ultravox
),
}
# architecture-specific block mappings

View File

@ -1,5 +1,15 @@
# mtmd
# compile mtmd-audio separately to avoid long compile times with miniaudio.h
# TODO @ngxson : move miniaudio.h and stb_image.h to mtmd-helper.cpp, then compile the helper as a separate library
add_library(mtmd_audio STATIC mtmd-audio.cpp mtmd-audio.h)
if (BUILD_SHARED_LIBS)
set_target_properties(mtmd_audio PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
target_link_libraries(mtmd_audio PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(mtmd_audio PRIVATE cxx_std_17)
target_include_directories(mtmd_audio PRIVATE .)
add_library(mtmd OBJECT
mtmd.cpp
mtmd-helper.cpp
@ -9,7 +19,7 @@ add_library(mtmd OBJECT
clip-impl.h
)
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(mtmd PUBLIC .)
target_include_directories(mtmd PRIVATE ../..)
@ -22,12 +32,13 @@ if (BUILD_SHARED_LIBS)
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(mtmd_shared PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS mtmd_shared LIBRARY)
endif()
if (NOT MSVC)
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
target_compile_options(mtmd_audio PRIVATE -Wno-cast-qual) # miniaudio.h
endif()
if(TARGET BUILD_INFO)

View File

@ -16,22 +16,26 @@
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_USE_SILU "clip.use_silu"
#define KEY_N_EMBD "clip.vision.embedding_length"
#define KEY_N_FF "clip.vision.feed_forward_length"
#define KEY_N_BLOCK "clip.vision.block_count"
#define KEY_N_HEAD "clip.vision.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
#define KEY_PROJ_DIM "clip.vision.projection_dim"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
#define KEY_N_BLOCK "clip.%s.block_count"
#define KEY_PROJ_DIM "clip.%s.projection_dim"
#define KEY_N_HEAD "clip.%s.attention.head_count"
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@ -39,13 +43,18 @@
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
// audio-specific
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
//
// tensor name constants
//
#define TN_POS_EMBD "v.position_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
@ -95,6 +104,12 @@
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
// ultravox
#define TN_CONV1D "a.conv1d.%d.%s"
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
@ -110,6 +125,7 @@ enum projector_type {
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_ULTRAVOX,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_UNKNOWN,
@ -126,6 +142,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
};
@ -147,8 +164,10 @@ struct clip_image_u8 {
std::vector<uint8_t> buf;
};
// RGB float32 image (NHWC)
// Memory layout: RGBRGBRGB...
// For images, buf.size() == nx*ny*3
// Memory layout: RGBRGBRGB...
// For audio, only one channel is used, buf.size() == nx*ny
// nx will be n_frames and ny will be n_mel
struct clip_image_f32 {
int nx;
int ny;
@ -242,6 +261,7 @@ struct clip_image_u8_batch {
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
bool is_audio = false;
// for llava-uhd style models, we need to know the grid size
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
@ -249,7 +269,12 @@ struct clip_image_f32_batch {
int grid_y = 0;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
clip_image_f32_batch new_batch{
/* entries */ {},
/* is_audio */ is_audio,
/* grid_x */ grid_x,
/* grid_y */ grid_y,
};
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));

View File

@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
enum ffn_op_type {
FFN_GELU,
FFN_GELU_ERF,
FFN_SILU,
FFN_GELU_QUICK,
};
@ -165,6 +166,9 @@ enum patch_merge_type {
};
struct clip_hparams {
bool has_vision = false;
bool has_audio = false;
int32_t image_size;
int32_t patch_size;
int32_t n_embd;
@ -191,6 +195,10 @@ struct clip_hparams {
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
// audio
int32_t n_mel_bins = 0; // whisper preprocessor
int32_t proj_stack_factor = 0; // ultravox
};
struct clip_layer {
@ -332,6 +340,14 @@ struct clip_vision_model {
// pixtral
ggml_tensor * token_embd_img_break = nullptr;
ggml_tensor * mm_patch_merger_w = nullptr;
// ultravox / whisper encoder
ggml_tensor * conv1d_1_w = nullptr;
ggml_tensor * conv1d_1_b = nullptr;
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
};
struct clip_ctx {
@ -1408,6 +1424,104 @@ struct clip_graph {
return gf;
}
// whisper encoder with custom projector
ggml_cgraph * build_whisper_enc() {
const int n_frames = img.nx;
const int n_pos = n_frames / 2;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
ggml_tensor * inp = build_inp_raw(1);
// conv1d block
{
// convolution + gelu
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
cur = ggml_gelu_erf(ctx0, cur);
// transpose
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cb(inp, "after_conv1d", -1);
}
// sanity check (only check one layer, but it should be the same for all)
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
model.position_embeddings->ne[0], n_pos,
model.position_embeddings->nb[1], 0
);
ggml_tensor * cur = build_vit(
inp, n_pos,
NORM_TYPE_NORMAL,
hparams.ffn_op,
pos_embd_selected,
nullptr);
cb(cur, "after_transformer", -1);
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
{
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
}
cb(cur, "after_stacked", -1);
// UltravoxProjector
{
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
{
int64_t split_point = cur->ne[0] / 2;
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
x1 = ggml_silu(ctx0, x1);
cur = ggml_mul(ctx0, x0, x1);
}
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
cb(cur, "projected", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
private:
//
// utility functions
@ -1562,8 +1676,8 @@ private:
return inp;
}
ggml_tensor * build_inp_raw() {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
ggml_tensor * build_inp_raw(int channels = 3) {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
return inp_raw;
@ -1641,6 +1755,11 @@ private:
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
{
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ggml_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
{
cur = ggml_gelu_quick(ctx0, cur);
@ -1832,6 +1951,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_llama4();
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
res = graph.build_whisper_enc();
} break;
default:
{
res = graph.build_llava();
@ -1915,18 +2038,30 @@ struct clip_model_loader {
// other hparams
{
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
get_u32(KEY_N_EMBD, hparams.n_embd);
get_u32(KEY_N_HEAD, hparams.n_head);
get_u32(KEY_N_FF, hparams.n_ff);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
const char * prefix = hparams.has_vision ? "vision" : "audio";
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
if (hparams.has_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
} else if (hparams.has_audio) {
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
} else {
throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
}
// default warmup value
hparams.warmup_image_size = hparams.image_size;
@ -1964,7 +2099,7 @@ struct clip_model_loader {
}
}
{
if (hparams.has_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@ -2050,30 +2185,43 @@ struct clip_model_loader {
isize, isize*3, // 336, 1008
};
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
}
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
} break;
default:
break;
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("\n");
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
LOG_INF("\n");
if (hparams.has_vision) {
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
} else if (hparams.has_audio) {
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
}
LOG_INF("\n");
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
}
}
@ -2082,6 +2230,9 @@ struct clip_model_loader {
std::map<std::string, size_t> tensor_offset;
std::vector<ggml_tensor *> tensors_to_load;
// TODO @ngxson : support both audio and video in the future
const char * prefix = hparams.has_audio ? "a" : "v";
// get offsets
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
@ -2119,47 +2270,47 @@ struct clip_model_loader {
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
// layers
vision_model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = vision_model.layers[il];
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
// ffn
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
@ -2301,6 +2452,17 @@ struct clip_model_loader {
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
case PROJECTOR_TYPE_ULTRAVOX:
{
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@ -2358,13 +2520,19 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
const auto & hparams = ctx_clip.vision_model.hparams;
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
if (hparams.has_vision) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
} else {
img->nx = 1024; // TODO @ngxson : use a better default
img->ny = hparams.n_mel_bins;
}
img->buf.resize(img->nx * img->ny * 3);
batch.entries.push_back(std::move(img));
@ -3278,6 +3446,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
n_patches /= (scale_factor * scale_factor);
} else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches = n_len / proj_stack_factor / 2;
}
return n_patches;
@ -3435,7 +3607,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
};
// set input pixel values
{
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx * img->ny * 3;
@ -3472,6 +3644,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
}
set_input_f32("inp_raw", inp_raw);
} else {
// audio input
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const int n_step = mel_inp->nx;
const int n_mel = mel_inp->ny;
std::vector<float> inp_raw(n_step * n_mel);
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
set_input_f32("inp_raw", inp_raw);
}
// set input per projector
@ -3668,6 +3850,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_ULTRAVOX:
{
// do nothing
} break;
@ -3766,6 +3949,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
case PROJECTOR_TYPE_LLAMA4:
@ -3798,6 +3983,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_vision;
}
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.has_audio;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
@ -3818,3 +4011,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
return ctx->proj_type;
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
clip_image_f32 * audio = new clip_image_f32;
audio->nx = n_frames;
audio->ny = n_mel;
audio->buf.resize(n_frames * n_mel);
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}

View File

@ -93,3 +93,9 @@ bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
// use by audio input
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
bool clip_has_audio_encoder(const struct clip_ctx * ctx);

93468
tools/mtmd/miniaudio.h Normal file

File diff suppressed because it is too large Load Diff

855
tools/mtmd/mtmd-audio.cpp Normal file
View File

@ -0,0 +1,855 @@
// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include "mtmd-audio.h"
//#define MTMD_AUDIO_DEBUG
#define MINIAUDIO_IMPLEMENTATION
#ifndef MTMD_AUDIO_DEBUG
# define MA_NO_ENCODING
#endif
#define MA_NO_DEVICE_IO
#define MA_NO_RESOURCE_MANAGER
#define MA_NO_NODE_GRAPH
#define MA_NO_ENGINE
#define MA_NO_GENERATION
#define MA_API static
#include "miniaudio.h"
#define _USE_MATH_DEFINES // for M_PI
#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>
#include <fstream>
#include <algorithm>
// most of the code here is copied from whisper.cpp
// align x to upper multiple of n
#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
namespace whisper_preprocessor {
#define SIN_COS_N_COUNT WHISPER_N_FFT
namespace {
struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];
// Hann window (Use cosf to eliminate difference)
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
float hann_window[WHISPER_N_FFT];
whisper_global_cache() {
fill_sin_cos_table();
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
}
void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
} global_cache;
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
static void dft(const float* in, int N, float* out) {
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N; k++) {
float re = 0;
float im = 0;
for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
}
out[k*2 + 0] = re;
out[k*2 + 1] = im;
}
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
static void fft(float* in, int N, float* out) {
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
const int half_N = N / 2;
if (N - half_N*2 == 1) {
dft(in, N, out);
return;
}
float* even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
}
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
float* odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
}
float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads,
const whisper_filters & filters, whisper_mel & mel) {
std::vector<float> fft_in(frame_size * 2, 0.0);
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
int n_fft = filters.n_fft;
int i = ith;
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
// calculate FFT only when fft_in are not all zero
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step;
// apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j];
}
// fill the rest with zeros
if (n_samples - offset < frame_size) {
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
}
// FFT
fft(fft_in.data(), frame_size, fft_out.data());
// Calculate modulus^2 of complex numbers
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
for (int j = 0; j < n_fft; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;
// unroll loop (suggested by GH user @lunixbochs)
int k = 0;
for (k = 0; k < n_fft - 3; k += 4) {
sum +=
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
}
// handle n_fft remainder
for (; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
sum = log10(std::max(sum, 1e-10));
mel.data[j * mel.n_len + i] = sum;
}
}
// Otherwise fft_out are all zero
double sum = log10(1e-10);
for (; i < mel.n_len; i += n_threads) {
for (int j = 0; j < mel.n_mel; j++) {
mel.data[j * mel.n_len + i] = sum;
}
}
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
//const int64_t t_start_us = ggml_time_us();
// Hann window
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
const float * hann = global_cache.hann_window;
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
// Calculate number of frames + remove the last frame
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
// Calculate semi-padded sample length to ensure compatibility
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
mel.data.resize(mel.n_mel * mel.n_len);
{
std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
std::cref(filters), std::ref(mel));
}
// main thread
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
}
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}
return true;
}
bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output) {
if (n_samples == 0) {
// empty audio
return false;
}
whisper_mel out_full;
bool ok = log_mel_spectrogram(
samples,
n_samples,
COMMON_SAMPLE_RATE,
WHISPER_N_FFT,
WHISPER_HOP_LENGTH,
filters.n_mel,
4, // n_threads
filters,
false, // debug
out_full);
if (!ok) {
return false;
}
// because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
// we always expect the mel to have 3000 silent frames at the end
// printf("n_len %d\n", out_full.n_len);
const size_t frames_per_chunk = 3000;
GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
if ((size_t)n_len < frames_per_chunk) {
break; // last uncomplete chunk will always be a padded chunk, safe to ignore
}
whisper_mel out_chunk;
out_chunk.n_len = n_len;
out_chunk.n_mel = out_full.n_mel;
out_chunk.n_len_org = out_full.n_mel; // unused
out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
for (int i = 0; i < out_full.n_mel; i++) {
auto src = out_full.data.begin() + i*out_full.n_len + off;
out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
}
output.push_back(std::move(out_chunk));
}
return true;
}
} // namespace whisper_preprocessor
namespace audio_helpers {
bool is_audio_file(const char * buf, size_t len) {
if (len < 12) {
return false;
}
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
bool is_mp3 = len >= 3 && (
memcmp(buf, "ID3", 3) == 0 ||
// Check for MPEG sync word (simplified check)
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
);
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
return is_wav || is_mp3 || is_flac;
}
// returns true if the buffer is a valid audio file
bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
ma_result result;
const int channels = 1;
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
ma_decoder decoder;
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
if (result != MA_SUCCESS) {
return false;
}
ma_uint64 frame_count;
ma_uint64 frames_read;
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
pcmf32_mono.resize(frame_count);
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
if (result != MA_SUCCESS) {
ma_decoder_uninit(&decoder);
return false;
}
#ifdef MTMD_AUDIO_DEBUG
// save audio to wav file
ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
ma_encoder encoder;
ma_encoder_init_file("output.wav", &config, &encoder);
ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
ma_encoder_uninit(&encoder);
#endif
ma_decoder_uninit(&decoder);
return true;
}
} // namespace wav_utils
// precalculated mel filter banks
// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
//
// generated from python code:
//
// from numpy import load
// data = load('mel_filters.npz')
// lst = data.files
// for item in lst:
// print(item)
// print(data[item].shape)
// n_mel = data[item].shape[0]
// n_fft = data[item].shape[1]
// for i, row in enumerate(data[item]):
// for j, val in enumerate(row):
// val = val * 1000.0
// if val != 0:
// print(f"data[{i*n_fft + j}] = {val:.6f};")
namespace whisper_precalc_filters {
whisper_preprocessor::whisper_filters get_128_bins() {
whisper_preprocessor::whisper_filters filters;
filters.n_mel = 128;
filters.n_fft = 201;
std::vector data(filters.n_mel * filters.n_fft, 0.0f);
data[1] = 12.37398665;
data[202] = 30.39256483;
data[404] = 24.74797331;
data[605] = 18.01857911;
data[807] = 37.12195903;
data[1008] = 5.64459199;
data[1009] = 6.72939420;
data[1210] = 36.03715822;
data[1412] = 19.10337992;
data[1613] = 23.66316877;
data[1815] = 31.47736564;
data[2016] = 11.28918398;
data[2017] = 1.08480197;
data[2218] = 41.68175161;
data[2420] = 13.45878839;
data[2621] = 29.30776216;
data[2823] = 25.83277412;
data[3024] = 16.93377644;
data[3226] = 38.20675984;
data[3427] = 4.55979025;
data[3428] = 7.81419594;
data[3629] = 34.95235741;
data[3831] = 20.18818259;
data[4032] = 22.57836796;
data[4234] = 32.56217018;
data[4435] = 10.20438317;
data[4436] = 2.16960395;
data[4637] = 40.59694707;
data[4839] = 14.54358920;
data[5040] = 28.22295949;
data[5242] = 26.91757679;
data[5443] = 15.84897563;
data[5645] = 39.29156065;
data[5846] = 3.47498828;
data[5847] = 8.89899861;
data[6048] = 33.86755288;
data[6250] = 21.27298526;
data[6451] = 21.49356715;
data[6653] = 33.64697099;
data[6854] = 9.11958050;
data[6855] = 3.25440569;
data[7056] = 39.51214626;
data[7258] = 15.62839188;
data[7459] = 27.13815868;
data[7661] = 28.00237760;
data[7862] = 14.76417296;
data[8064] = 40.37636518;
data[8265] = 2.38068704;
data[8266] = 10.20263787;
data[8467] = 31.61146119;
data[8669] = 24.54700135;
data[8870] = 15.32919332;
data[8871] = 1.66583748;
data[9072] = 36.72905266;
data[9274] = 20.09709924;
data[9475] = 16.93102531;
data[9476] = 2.90265540;
data[9677] = 32.84499049;
data[9879] = 23.52004871;
data[10080] = 11.03894413;
data[10081] = 10.72582975;
data[10282] = 22.71829173;
data[10484] = 32.27872774;
data[10685] = 0.11626833;
data[10686] = 22.85348251;
data[10887] = 8.56344029;
data[10888] = 14.97978810;
data[11089] = 15.51398356;
data[11090] = 8.51490628;
data[11291] = 21.10680379;
data[11292] = 3.32652032;
data[11493] = 25.47064796;
data[11695] = 27.35907957;
data[11896] = 0.65853616;
data[11897] = 23.83812517;
data[12098] = 3.44359246;
data[12099] = 21.22455277;
data[12300] = 5.35842171;
data[12301] = 19.42555793;
data[12502] = 6.49324711;
data[12503] = 18.35542172;
data[12704] = 6.93138083;
data[12705] = 17.93504693;
data[12906] = 6.74968259;
data[12907] = 18.09151843;
data[13108] = 6.01899112;
data[13109] = 18.75767298;
data[13310] = 4.80452832;
data[13311] = 19.87172849;
data[13512] = 3.16627859;
data[13513] = 21.37690969;
data[13514] = 1.25317345;
data[13714] = 1.15934468;
data[13715] = 20.80361731;
data[13716] = 4.04486805;
data[13917] = 17.55363122;
data[13918] = 7.08320038;
data[14119] = 14.07538634;
data[14120] = 10.32655034;
data[14321] = 10.40921453;
data[14322] = 13.73696327;
data[14523] = 6.59187697;
data[14524] = 17.27988198;
data[14525] = 1.46804214;
data[14725] = 2.65681883;
data[14726] = 18.09193194;
data[14727] = 5.85655728;
data[14928] = 13.34277913;
data[14929] = 10.28267574;
data[15130] = 8.56800377;
data[15131] = 14.72230814;
data[15132] = 1.04039861;
data[15332] = 3.79085587;
data[15333] = 17.14678481;
data[15334] = 6.11609267;
data[15535] = 11.75929047;
data[15536] = 11.13393717;
data[15737] = 6.43857848;
data[15738] = 16.07806236;
data[15739] = 4.23917221;
data[15939] = 1.19989377;
data[15940] = 12.75671553;
data[15941] = 9.65298992;
data[16142] = 7.06935255;
data[16143] = 14.94054683;
data[16144] = 4.19024844;
data[16344] = 1.51483389;
data[16345] = 12.00899947;
data[16346] = 9.84823331;
data[16547] = 6.10224018;
data[16548] = 15.33857174;
data[16549] = 5.57676842;
data[16749] = 0.36827257;
data[16750] = 9.89749376;
data[16751] = 11.35340426;
data[16752] = 2.05122307;
data[16952] = 3.89297144;
data[16953] = 12.97352277;
data[16954] = 8.06631614;
data[17155] = 6.74493238;
data[17156] = 13.85874674;
data[17157] = 5.41190524;
data[17357] = 0.74220158;
data[17358] = 8.98779090;
data[17359] = 11.37871388;
data[17360] = 3.32958088;
data[17560] = 2.82313535;
data[17561] = 10.68049297;
data[17562] = 9.43340641;
data[17563] = 1.76325557;
data[17763] = 4.39018616;
data[17764] = 11.87758986;
data[17765] = 7.97005836;
data[17766] = 0.66104700;
data[17966] = 5.49466675;
data[17967] = 12.62953598;
data[17968] = 6.93987962;
data[18169] = 6.18401915;
data[18170] = 12.93473132;
data[18171] = 6.29778765;
data[18371] = 0.02325210;
data[18372] = 6.50206627;
data[18373] = 12.32661773;
data[18374] = 6.00216538;
data[18574] = 0.31548753;
data[18575] = 6.48925547;
data[18576] = 12.04130240;
data[18577] = 6.01462880;
data[18777] = 0.29979556;
data[18778] = 6.18288014;
data[18779] = 12.04272825;
data[18780] = 6.29981188;
data[18781] = 0.55689598;
data[18980] = 0.01120471;
data[18981] = 5.61729167;
data[18982] = 11.22337859;
data[18983] = 6.82516303;
data[18984] = 1.35264499;
data[19184] = 4.82410006;
data[19185] = 10.16623247;
data[19186] = 7.56075513;
data[19187] = 2.34590308;
data[19387] = 3.83235747;
data[19388] = 8.92296247;
data[19389] = 8.47910438;
data[19390] = 3.50978645;
data[19590] = 2.66873185;
data[19591] = 7.51965167;
data[19592] = 9.55500547;
data[19593] = 4.81966138;
data[19594] = 0.08431751;
data[19793] = 1.35767367;
data[19794] = 5.98019501;
data[19795] = 10.60271543;
data[19796] = 6.25298498;
data[19797] = 1.74059917;
data[19997] = 4.32644226;
data[19998] = 8.73131864;
data[19999] = 7.78916525;
data[20000] = 3.48923868;
data[20200] = 2.57835095;
data[20201] = 6.77582854;
data[20202] = 9.40941647;
data[20203] = 5.31194592;
data[20204] = 1.21447595;
data[20403] = 0.75411191;
data[20404] = 4.75395704;
data[20405] = 8.75380263;
data[20406] = 7.19209015;
data[20407] = 3.28754401;
data[20607] = 2.68179690;
data[20608] = 6.49331464;
data[20609] = 9.11457930;
data[20610] = 5.39387390;
data[20611] = 1.67316827;
data[20810] = 0.57394296;
data[20811] = 4.20600036;
data[20812] = 7.83805829;
data[20813] = 7.52023002;
data[20814] = 3.97470826;
data[20815] = 0.42918732;
data[21014] = 1.90464477;
data[21015] = 5.36569161;
data[21016] = 8.82673822;
data[21017] = 6.27609482;
data[21018] = 2.89750961;
data[21218] = 2.89885257;
data[21219] = 6.19694078;
data[21220] = 8.56699049;
data[21221] = 5.34748193;
data[21222] = 2.12797290;
data[21421] = 0.44750227;
data[21422] = 3.59030394;
data[21423] = 6.73310598;
data[21424] = 7.77023612;
data[21425] = 4.70231380;
data[21426] = 1.63439126;
data[21625] = 1.01536023;
data[21626] = 4.01018746;
data[21627] = 7.00501446;
data[21628] = 7.23442994;
data[21629] = 4.31095669;
data[21630] = 1.38748321;
data[21829] = 1.33348850;
data[21830] = 4.18730825;
data[21831] = 7.04112789;
data[21832] = 6.93188375;
data[21833] = 4.14605811;
data[21834] = 1.36023236;
data[22033] = 1.42879714;
data[22034] = 4.14824858;
data[22035] = 6.86769979;
data[22036] = 6.83705276;
data[22037] = 4.18239459;
data[22038] = 1.52773573;
data[22237] = 1.32610439;
data[22238] = 3.91751388;
data[22239] = 6.50892360;
data[22240] = 6.92639686;
data[22241] = 4.39672917;
data[22242] = 1.86706171;
data[22441] = 1.04827771;
data[22442] = 3.51767405;
data[22443] = 5.98707050;
data[22444] = 7.17824046;
data[22445] = 4.76767914;
data[22446] = 2.35711760;
data[22645] = 0.61636406;
data[22646] = 2.96949223;
data[22647] = 5.32262027;
data[22648] = 7.57265091;
data[22649] = 5.27558755;
data[22650] = 2.97852419;
data[22651] = 0.68146095;
data[22849] = 0.04971400;
data[22850] = 2.29204819;
data[22851] = 4.53438237;
data[22852] = 6.77671656;
data[22853] = 5.90240723;
data[22854] = 3.71349836;
data[22855] = 1.52458926;
data[23054] = 1.50285335;
data[23055] = 3.63961048;
data[23056] = 5.77636715;
data[23057] = 6.63159089;
data[23058] = 4.54574358;
data[23059] = 2.45989650;
data[23060] = 0.37404924;
data[23258] = 0.61795861;
data[23259] = 2.65410915;
data[23260] = 4.69025923;
data[23261] = 6.72641024;
data[23262] = 5.46034705;
data[23263] = 3.47270933;
data[23264] = 1.48507138;
data[23463] = 1.59233576;
data[23464] = 3.53261665;
data[23465] = 5.47289755;
data[23466] = 6.44368259;
data[23467] = 4.54962999;
data[23468] = 2.65557761;
data[23469] = 0.76152512;
data[23667] = 0.46749352;
data[23668] = 2.31641904;
data[23669] = 4.16534441;
data[23670] = 6.01426978;
data[23671] = 5.67844696;
data[23672] = 3.87357362;
data[23673] = 2.06870004;
data[23674] = 0.26382666;
data[23872] = 1.05349103;
data[23873] = 2.81536230;
data[23874] = 4.57723346;
data[23875] = 6.33910485;
data[23876] = 5.12815686;
data[23877] = 3.40826320;
data[23878] = 1.68837002;
data[24077] = 1.43350090;
data[24078] = 3.11241671;
data[24079] = 4.79133241;
data[24080] = 6.40943693;
data[24081] = 4.77052201;
data[24082] = 3.13160778;
data[24083] = 1.49269309;
data[24281] = 0.02932359;
data[24282] = 1.62918994;
data[24283] = 3.22905602;
data[24284] = 4.82892245;
data[24285] = 6.14671456;
data[24286] = 4.58496623;
data[24287] = 3.02321767;
data[24288] = 1.46146910;
data[24486] = 0.13601698;
data[24487] = 1.66055572;
data[24488] = 3.18509457;
data[24489] = 4.70963307;
data[24490] = 6.04072399;
data[24491] = 4.55250870;
data[24492] = 3.06429295;
data[24493] = 1.57607743;
data[24494] = 0.08786193;
data[24691] = 0.09328097;
data[24692] = 1.54603878;
data[24693] = 2.99879676;
data[24694] = 4.45155473;
data[24695] = 5.90431225;
data[24696] = 4.65566106;
data[24697] = 3.23751615;
data[24698] = 1.81937125;
data[24699] = 0.40122634;
data[24897] = 1.30262633;
data[24898] = 2.68698297;
data[24899] = 4.07133950;
data[24900] = 5.45569602;
data[24901] = 4.87832492;
data[24902] = 3.52695142;
data[24903] = 2.17557792;
data[24904] = 0.82420459;
data[25102] = 0.94595028;
data[25103] = 2.26512621;
data[25104] = 3.58430226;
data[25105] = 4.90347855;
data[25106] = 5.20569785;
data[25107] = 3.91795207;
data[25108] = 2.63020652;
data[25109] = 1.34246063;
data[25110] = 0.05471494;
data[25307] = 0.49037894;
data[25308] = 1.74744334;
data[25309] = 3.00450763;
data[25310] = 4.26157191;
data[25311] = 5.51863620;
data[25312] = 4.39707236;
data[25313] = 3.16995848;
data[25314] = 1.94284460;
data[25315] = 0.71573065;
data[25513] = 1.14698056;
data[25514] = 2.34485767;
data[25515] = 3.54273478;
data[25516] = 4.74061165;
data[25517] = 4.95198462;
data[25518] = 3.78264743;
data[25519] = 2.61331047;
data[25520] = 1.44397374;
data[25521] = 0.27463681;
data[25718] = 0.47569509;
data[25719] = 1.61717169;
data[25720] = 2.75864848;
data[25721] = 3.90012516;
data[25722] = 5.04160160;
data[25723] = 4.45712078;
data[25724] = 3.34284059;
data[25725] = 2.22856039;
data[25726] = 1.11428020;
for (auto & val : data) {
val /= 1000.0f;
}
filters.data = std::move(data);
return filters;
}
} // namespace whisper_precalc_filters

62
tools/mtmd/mtmd-audio.h Normal file
View File

@ -0,0 +1,62 @@
#pragma once
#include "ggml.h"
#include <cstdint>
#include <vector>
#include <string>
#define WHISPER_ASSERT GGML_ASSERT
#define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400
#define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30
#define COMMON_SAMPLE_RATE 16000
namespace whisper_preprocessor {
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
extern bool preprocess_audio(
const float * samples,
size_t n_samples,
const whisper_filters & filters,
std::vector<whisper_mel> & output);
} // namespace whisper_preprocessor
// TODO @ngxson : move this helper to mtmd-helpers.cpp
namespace audio_helpers {
extern bool is_audio_file(const char * buf, size_t len);
extern bool decode_audio_from_buf(
const unsigned char * buf_in,
size_t len,
int target_sampler_rate,
std::vector<float> & pcmf32_mono);
} // namespace audio_helpers
namespace whisper_precalc_filters {
extern whisper_preprocessor::whisper_filters get_128_bins();
} // namespace whisper_precalc_filters

View File

@ -37,10 +37,10 @@ static volatile bool g_is_interrupted = false;
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for multimodal\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" -hf user/repo can replace both -m and --mmproj in most cases\n"
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
" to disable using GPU for mmproj model, add --no-mmproj-offload\n",
argv[0]
);
@ -142,7 +142,7 @@ struct mtmd_cli_context {
);
}
bool load_image(const std::string & fname) {
bool load_media(const std::string & fname) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
if (!bmp.ptr) {
return false;
@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
return 1;
}
@ -283,14 +283,14 @@ int main(int argc, char ** argv) {
if (is_single_turn) {
g_is_generating = true;
if (params.prompt.find("<__image__>") == std::string::npos) {
params.prompt += " <__image__>";
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
params.prompt += mtmd_default_marker();
}
common_chat_msg msg;
msg.role = "user";
msg.content = params.prompt;
for (const auto & image : params.image) {
if (!ctx.load_image(image)) {
if (!ctx.load_media(image)) {
return 1; // error is already printed by libmtmd
}
}
@ -303,7 +303,12 @@ int main(int argc, char ** argv) {
} else {
LOG("\n Running in chat mode, available commands:");
LOG("\n /image <path> load an image");
if (mtmd_support_vision(ctx.ctx_vision.get())) {
LOG("\n /image <path> load an image");
}
if (mtmd_support_audio(ctx.ctx_vision.get())) {
LOG("\n /audio <path> load an audio");
}
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
@ -333,15 +338,17 @@ int main(int argc, char ** argv) {
continue;
}
g_is_generating = true;
if (line == "/image" || line.find("/image ") == 0) {
bool is_image = line == "/image" || line.find("/image ") == 0;
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
if (is_image || is_audio) {
if (line.size() < 8) {
LOG_ERR("ERR: Missing image filename\n");
LOG_ERR("ERR: Missing media filename\n");
continue;
}
std::string image = line.substr(7);
if (ctx.load_image(image)) {
LOG("Image %s loaded\n", image.c_str());
content += "<__image__>";
std::string media_path = line.substr(7);
if (ctx.load_media(media_path)) {
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
content += mtmd_default_marker();
}
// else, error is already printed by libmtmd
continue;

View File

@ -149,13 +149,10 @@ int32_t mtmd_helper_decode_image_chunk(
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
return -1;
}
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (!image_tokens) {
LOG_ERR("failed to decode image chunk: image tokens are null\n");
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
return -1;
}
@ -163,15 +160,23 @@ int32_t mtmd_helper_decode_image_chunk(
int n_mmproj_embd = llama_model_n_embd(model);
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
int32_t i_batch = 0;
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
if (mtmd_decode_use_mrope(ctx)) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
return -1;
}
if (!image_tokens) {
LOG_ERR("failed to decode chunk: image tokens are null\n");
return -1;
}
const int nx = mtmd_image_tokens_get_nx(image_tokens);
const int ny = mtmd_image_tokens_get_ny(image_tokens);
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
} else {
batch_embd.set_position_normal(n_past, seq_id);
@ -187,22 +192,22 @@ int32_t mtmd_helper_decode_image_chunk(
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);
int64_t t1 = ggml_time_ms();
int32_t ret = llama_decode(lctx, batch_embd_view);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_set_causal_attn(lctx, true); // restore causal attn
return ret;
}
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
i_batch++;
}
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
n_past += mtmd_input_chunk_get_n_pos(chunk);
*new_n_past = n_past;
if (mtmd_decode_use_non_causal(ctx)) {
@ -253,25 +258,25 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
*new_n_past += text_batch.n_tokens;
}
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
int64_t t0 = ggml_time_ms();
LOG_INF("encoding image or slice...\n");
LOG_INF("encoding %s slice...\n", name);
ret = mtmd_encode(ctx, image_tokens);
ret = mtmd_encode_chunk(ctx, chunk);
if (ret != 0) {
LOG_ERR("failed to encode image\n");
LOG_ERR("failed to encode %s slice\n", name);
llama_batch_free(text_batch);
return ret;
}
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
float * embd = mtmd_get_output_embd(ctx);
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
if (ret != 0) {
LOG_ERR("failed to decode image\n");
LOG_ERR("failed to decode %s\n", name);
llama_batch_free(text_batch);
return ret;
}

View File

@ -1,6 +1,7 @@
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"
#include "mtmd-audio.h"
#include "llama.h"
@ -19,17 +20,49 @@ struct mtmd_bitmap {
uint32_t ny;
std::vector<unsigned char> data;
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
bool is_audio = false; // true if the bitmap is audio
};
struct mtmd_image_tokens_deleter {
void operator()(mtmd_image_tokens * val); // forward declaration
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens>;
struct mtmd_audio_tokens {
uint32_t n_tokens; // number of tokens
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_audio_tokens clone() {
return mtmd_audio_tokens{
n_tokens,
batch_f32.clone(),
id
};
}
};
using mtmd_audio_tokens_ptr = std::unique_ptr<mtmd_audio_tokens>;
struct mtmd_input_chunk {
mtmd_input_chunk_type type;
std::vector<llama_token> tokens_text;
mtmd_image_tokens_ptr tokens_image;
mtmd_audio_tokens_ptr tokens_audio;
};
struct mtmd_input_chunks {
@ -46,6 +79,10 @@ enum mtmd_slice_tmpl {
// TODO @ngxson : add support for idefics (SmolVLM)
};
const char * mtmd_default_marker() {
return "<__media__>";
}
mtmd_context_params mtmd_context_params_default() {
mtmd_context_params params;
params.use_gpu = true;
@ -53,6 +90,7 @@ mtmd_context_params mtmd_context_params_default() {
params.n_threads = 4;
params.verbosity = GGML_LOG_LEVEL_INFO;
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
params.media_marker = mtmd_default_marker();
return params;
}
@ -63,7 +101,9 @@ struct mtmd_context {
bool print_timings;
int n_threads;
std::string image_marker;
std::string media_marker;
bool has_vision;
bool has_audio;
// for llava-uhd style models, we need special tokens in-between slices
// minicpmv calls them "slices", llama 4 calls them "tiles"
@ -81,6 +121,9 @@ struct mtmd_context {
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
// for whisper, we pre-calculate the mel filter bank
whisper_preprocessor::whisper_filters w_filters;
// TODO @ngxson : add timings
mtmd_context(const char * mmproj_fname,
@ -89,8 +132,12 @@ struct mtmd_context {
text_model (text_model),
print_timings(ctx_params.print_timings),
n_threads (ctx_params.n_threads),
image_marker (ctx_params.image_marker)
media_marker (ctx_params.media_marker)
{
if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
}
clip_context_params ctx_clip_params;
ctx_clip_params.use_gpu = ctx_params.use_gpu;
ctx_clip_params.verbosity = ctx_params.verbosity;
@ -99,7 +146,9 @@ struct mtmd_context {
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
}
use_mrope = clip_is_qwen2vl(ctx_clip);
has_vision = clip_has_vision_encoder(ctx_clip);
has_audio = clip_has_audio_encoder(ctx_clip);
use_mrope = clip_is_qwen2vl(ctx_clip);
projector_type proj = clip_get_projector_type(ctx_clip);
int minicpmv_version = clip_is_minicpmv(ctx_clip);
@ -146,6 +195,21 @@ struct mtmd_context {
tok_row_end_trail = true; // add trailing end-of-row token
ov_img_first = false; // overview image is last
}
if (proj == PROJECTOR_TYPE_ULTRAVOX) {
// TODO @ngxson : check if model n_mel is 128 or 80
w_filters = whisper_precalc_filters::get_128_bins();
}
// warning messages
if (proj == PROJECTOR_TYPE_LLAMA4) {
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
}
if (has_audio) {
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
" https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
}
}
~mtmd_context() {
@ -179,29 +243,6 @@ private:
}
};
struct mtmd_image_tokens_data {
clip_image_f32_batch batch_f32; // preprocessed image patches
};
struct mtmd_image_tokens {
uint32_t nx; // number of tokens in x direction
uint32_t ny; // number of tokens in y direction
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
uint32_t n_tokens() const { return nx * ny; }
clip_image_f32_batch batch_f32; // preprocessed image patches
std::string id; // optional user-defined ID, useful for KV cache tracking
mtmd_image_tokens clone() {
return mtmd_image_tokens{
nx,
ny,
use_mrope_pos,
batch_f32.clone(),
id
};
}
};
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
const struct llama_model * text_model,
const struct mtmd_context_params ctx_params) {
@ -247,59 +288,63 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
auto vocab = llama_model_get_vocab(ctx->text_model);
std::string prompt_modified(text->text);
std::string marker_modified(ctx->image_marker);
std::string marker_modified(ctx->media_marker);
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
// for compatibility, we convert image marker to media marker
string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
// a bit hacky here, but works for now
// for some models, we need to add prefix and suffix to the image embeddings
if (clip_is_gemma3(ctx->ctx_clip)) {
// gemma 3
// <start_of_image> ... (image embeddings) ... <end_of_image>
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = ctx->media_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
// (more details in mtmd_context constructor)
marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
// <img> ... (image embeddings) ... </img>
marker_modified = "<img>" + ctx->image_marker + "</img>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
marker_modified = "<img>" + ctx->media_marker + "</img>";
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
}
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
// for glm-edge, BOI and EOI token's embeddings are not present in the text model
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
output->entries.clear();
output->entries.reserve(parts.size());
size_t i_img = 0;
size_t i_bm = 0;
// utility for adding raw tokens
auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
};
@ -317,8 +362,9 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks.emplace_back(std::move(chunk));
}
@ -336,24 +382,36 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
if (&parts.back() != &part) {
// add image token to middle of 2 parts
// only add image/audio tokens to middle of 2 parts
// therefore, we skip handling image/audio if this is the last part
if (&parts.back() == &part) {
continue;
}
if (i_img >= n_bitmaps) {
if (!bitmaps[i_bm]->is_audio) {
// handle image
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_vision) {
LOG_ERR("%s: error: model does not support vision input\n", __func__);
return 2;
}
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->nx = bitmaps[i_img]->nx;
img_u8->ny = bitmaps[i_img]->ny;
img_u8->buf.resize(bitmaps[i_img]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
img_u8->nx = bitmaps[i_bm]->nx;
img_u8->ny = bitmaps[i_bm]->ny;
img_u8->buf.resize(bitmaps[i_bm]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
// preprocess image
clip_image_f32_batch batch_f32;
@ -370,7 +428,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
) {
// split batch into chunks of single images
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
GGML_ASSERT(chunks.size() > 0);
auto ov_chunk = std::move(chunks.front());
@ -446,7 +504,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
image_tokens->ny = 1;
}
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img]->id; // optional
image_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@ -454,23 +512,101 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
output->entries.emplace_back(std::move(chunk));
}
i_img++; // move to next image
i_bm++; // move to next image
continue;
} else {
// handle audio
if (i_bm >= n_bitmaps) {
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
return 1;
}
if (!ctx->has_audio) {
LOG_ERR("%s: error: model does not support audio input\n", __func__);
return 2;
}
if (bitmaps[i_bm]->data.size() == 0) {
LOG_ERR("%s: error: empty audio data\n", __func__);
return 2;
}
// preprocess audio
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
const float * samples = (const float *)bitmaps[i_bm]->data.data();
size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
if (!ok) {
LOG_ERR("Unable to preprocess audio\n");
return 2;
}
// consider each mel_spec as a separate audio chunk
// TODO: maybe support batching, but this may come with memory cost
for (auto & mel_spec : mel_spec_chunks) {
clip_image_f32_ptr mel_f32(clip_image_f32_init());
mel_f32->nx = mel_spec.n_len;
mel_f32->ny = mel_spec.n_mel;
mel_f32->buf = std::move(mel_spec.data);
size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
clip_image_f32_batch batch_f32;
batch_f32.is_audio = true;
batch_f32.entries.push_back(std::move(mel_f32));
mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
audio_tokens->n_tokens = n_tokens;
audio_tokens->batch_f32 = std::move(batch_f32);
audio_tokens->id = bitmaps[i_bm]->id; // optional
LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
mtmd_input_chunk chunk{
MTMD_INPUT_CHUNK_TYPE_AUDIO,
{}, // text tokens
nullptr, // image tokens
std::move(audio_tokens),
};
output->entries.emplace_back(std::move(chunk));
}
i_bm++;
continue;
}
}
return 0;
}
static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
if (image_tokens) {
delete image_tokens;
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
return 0;
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_encode(ctx, chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
bool ok = clip_image_batch_encode(
ctx->ctx_clip,
ctx->n_threads,
&chunk->tokens_audio->batch_f32,
ctx->image_embd_v.data());
return ok ? 0 : 1;
}
LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
return 1;
}
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@ -516,8 +652,12 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
return ctx->use_mrope;
}
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
mtmd_image_tokens_free(val);
bool mtmd_support_vision(mtmd_context * ctx) {
return ctx->has_vision;
}
bool mtmd_support_audio(mtmd_context * ctx) {
return ctx->has_audio;
}
// these 2 helpers below use internal clip_image_u8_ptr,
@ -526,6 +666,15 @@ void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
if (audio_helpers::is_audio_file((const char *)buf, len)) {
std::vector<float> pcmf32;
if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
LOG_ERR("Unable to read WAV audio file from buffer\n");
return nullptr;
}
return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
}
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
if (!ok) {
@ -538,15 +687,26 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t
}
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
clip_image_u8_ptr img_u8(clip_image_u8_init());
bool ok = clip_image_load_from_file(fname, img_u8.get());
if (!ok) {
LOG_ERR("Unable to load image %s\n", fname);
std::vector<unsigned char> buf;
FILE * f = fopen(fname, "rb");
if (!f) {
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
return nullptr;
}
uint32_t nx, ny;
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
return mtmd_bitmap_init(nx, ny, data);
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
buf.resize(file_size);
size_t n_read = fread(buf.data(), 1, file_size, f);
fclose(f);
if (n_read != (size_t)file_size) {
LOG_ERR("Failed to read entire file %s", fname);
return nullptr;
}
return mtmd_helper_bitmap_init_from_buf(buf.data(), buf.size());
}
//
@ -567,6 +727,18 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
return bitmap;
}
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
const float * data) {
mtmd_bitmap * bitmap = new mtmd_bitmap;
bitmap->nx = n_samples;
bitmap->ny = 1;
bitmap->is_audio = true;
size_t data_size = n_samples * sizeof(float);
bitmap->data.resize(data_size);
std::memcpy(bitmap->data.data(), data, data_size);
return bitmap;
}
uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
return bitmap->nx;
}
@ -579,6 +751,10 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
return bitmap->data.data();
}
bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
return bitmap->is_audio;
}
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
return bitmap->id.c_str();
}
@ -642,17 +818,56 @@ const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chu
return nullptr;
}
size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_tokens(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
return chunk->tokens_text.size();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return mtmd_image_tokens_get_n_pos(chunk->tokens_image.get());
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->n_tokens;
} else {
GGML_ABORT("invalid chunk type");
}
}
const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk) {
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
return chunk->tokens_image->id.c_str();
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
return chunk->tokens_audio->id.c_str();
}
return nullptr;
}
mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
mtmd_input_chunk * copy = new mtmd_input_chunk{
chunk->type,
chunk->tokens_text,
mtmd_image_tokens_ptr(),
nullptr,
nullptr,
};
if (chunk->tokens_image) {
// copy the image tokens
copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
*copy->tokens_image = chunk->tokens_image->clone();
}
if (chunk->tokens_audio) {
// copy the audio tokens
copy->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
*copy->tokens_audio = chunk->tokens_audio->clone();
}
return copy;
}
@ -700,7 +915,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
mtmd_input_chunk chunk_text{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::move(tokens_text),
{},
nullptr, // image tokens
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_text));
@ -712,8 +928,9 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
image_tokens->id = "image_1";
mtmd_input_chunk chunk_image{
MTMD_INPUT_CHUNK_TYPE_IMAGE,
{},
{}, // text tokens
std::move(image_tokens),
nullptr, // audio tokens
};
chunks->entries.emplace_back(std::move(chunk_image));

View File

@ -39,6 +39,7 @@
# define MTMD_API
#endif
// deprecated marker, use mtmd_default_marker() instead
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
#ifdef __cplusplus
@ -48,6 +49,7 @@ extern "C" {
enum mtmd_input_chunk_type {
MTMD_INPUT_CHUNK_TYPE_TEXT,
MTMD_INPUT_CHUNK_TYPE_IMAGE,
MTMD_INPUT_CHUNK_TYPE_AUDIO,
};
// opaque types
@ -79,9 +81,12 @@ struct mtmd_context_params {
bool print_timings;
int n_threads;
enum ggml_log_level verbosity;
const char * image_marker;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
};
MTMD_API const char * mtmd_default_marker(void);
MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
// initialize the mtmd context
@ -98,17 +103,26 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
// whether the current model use M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
// mtmd_bitmap
//
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx,
uint32_t ny,
const unsigned char * data);
// if bitmap is image:
// length of data must be nx * ny * 3
// the data is in RGBRGBRGB... format
// if bitmap is audio:
// length of data must be n_samples * sizeof(float)
// the data is in float format (PCM F32)
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
MTMD_API bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
// bitmap ID is optional, but useful for KV cache tracking
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
@ -132,6 +146,11 @@ MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chu
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk);
MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
// you can move the chunk ownership to your own code by copying it
@ -144,27 +163,28 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunk
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens);
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// tokenize an input text prompt and an image
// the prompt must have the input image marker (default: "<__image__>") in it
// the marker will be replaced with the image tokens
// tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it
// the default marker is defined by mtmd_default_marker()
// the marker will be replaced with the image/audio chunk
// for example:
// "here is an image: <__image__>\ndescribe it in detail."
// "here is an image: <__media__>\ndescribe it in detail."
// this will gives 3 chunks:
// 1. "here is an image: <start_of_image>"
// 2. (image tokens)
// 2. (image/audio tokens)
// 3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// number of bitmaps must be equal to the number of markers in the prompt
// this function is thread-safe (shared ctx)
// return values:
// 0 on success
// 1 on number of images not matching the number of markers
// 1 on number of bitmaps not matching the number of markers
// 2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
mtmd_input_chunks * output,
@ -173,9 +193,14 @@ MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
size_t n_bitmaps);
// returns 0 on success
// TODO: deprecate
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
const mtmd_image_tokens * image_tokens);
// returns 0 on success
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
const mtmd_input_chunk * chunk);
// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
@ -189,12 +214,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
//
// helper function to construct a mtmd_bitmap from a file
// it calls mtmd_helper_bitmap_init_from_buf() internally
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
// helper function to construct a mtmd_bitmap from a buffer containing a file
// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
// supported formats:
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
// audio: formats supported by miniaudio: wav, mp3, flac
// note: audio files will be auto-detected based on magic bytes
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);

View File

@ -710,7 +710,7 @@ static json oaicompat_completion_params_parse(
// replace this chunk with a marker
p["type"] = "text";
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
p["text"] = mtmd_default_marker();
p.erase("image_url");
}
}