mtmd : add ultravox audio input (#13623)
* convert ok, load ok
* warmup ok
* test
* still does not work?
* fix padding
* temporary give up
* fix merge conflict
* build_ultravox()
* rm test
* fix merge conflict
* add necessary mtmd APIs
* first working version (only 4s of audio)
* will this monster compile?
* fix compile
* please compile
* fPIC
* fix windows
* various fixes
* clean up audio_helpers
* fix conversion
* add some debug stuff
* long audio input ok
* adapt the api
* add --audio arg
* final touch UX
* add miniaudio to readme
* fix typo
* refactor kv metadata
* mtmd_default_marker()
--- a/.editorconfig
+++ b/.editorconfig
@@ -48,3 +48,7 @@ end_of_line = unset
 charset = unset
 trim_trailing_whitespace = unset
 insert_final_newline = unset
+
+[tools/mtmd/miniaudio.h]
+trim_trailing_whitespace = unset
+insert_final_newline = unset
--- a/README.md
+++ b/README.md
@@ -580,3 +580,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
 - [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
 - [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
 - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
+- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -39,7 +39,7 @@
 using json = nlohmann::ordered_json;
 
 std::initializer_list<enum llama_example> mmproj_examples = {
-    LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_MTMD,
     LLAMA_EXAMPLE_SERVER,
 };
 
@@ -2233,12 +2233,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
     add_opt(common_arg(
-        {"--image"}, "FILE",
-        "path to an image file. use with multimodal models. Specify multiple times for batching",
+        {"--image", "--audio"}, "FILE",
+        "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
         [](common_params & params, const std::string & value) {
             params.image.emplace_back(value);
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples({LLAMA_EXAMPLE_MTMD}));
     if (llama_supports_rpc()) {
         add_opt(common_arg(
             {"--rpc"}, "SERVERS",
@@ -2868,7 +2868,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.chat_template = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
     add_opt(common_arg(
         {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
         string_format(
--- a/common/common.h
+++ b/common/common.h
@@ -76,7 +76,7 @@ enum llama_example {
     LLAMA_EXAMPLE_SERVER,
     LLAMA_EXAMPLE_CVECTOR_GENERATOR,
     LLAMA_EXAMPLE_EXPORT_LORA,
-    LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_MTMD,
    LLAMA_EXAMPLE_LOOKUP,
    LLAMA_EXAMPLE_PARALLEL,
    LLAMA_EXAMPLE_TTS,
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum):
 
 class ModelType(IntEnum):
     TEXT = 1
-    VISION = 2
+    MMPROJ = 2
 
 
 AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
@@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
 class ModelBase:
     _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
         ModelType.TEXT: {},
-        ModelType.VISION: {},
+        ModelType.MMPROJ: {},
     }
 
     dir_model: Path
@@ -88,7 +88,7 @@ class ModelBase:
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
-                type(self) is VisionModel:
+                type(self) is MmprojModel:
            raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
         self.dir_model = dir_model
@@ -309,6 +309,7 @@ class ModelBase:
                         gguf.MODEL_TENSOR.POSNET_NORM1,
                         gguf.MODEL_TENSOR.POSNET_NORM2,
                         gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
+                        gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
                     )
                 )
                 or not new_name.endswith(".weight")
@@ -438,7 +439,7 @@ class ModelBase:
         assert names
 
         def func(modelcls: AnyModel) -> AnyModel:
-            model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
+            model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT
             for name in names:
                 cls._model_classes[model_type][name] = modelcls
             return modelcls
@@ -1114,60 +1115,87 @@ class TextModel(ModelBase):
         self.gguf_writer.add_pooling_type(pooling_type)
 
 
-class VisionModel(ModelBase):
-    model_type = ModelType.VISION
-    model_arch = gguf.MODEL_ARCH.CLIP_VISION
+class MmprojModel(ModelBase):
+    model_type = ModelType.MMPROJ
+    model_arch = gguf.MODEL_ARCH.MMPROJ
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
+    has_vision_encoder: bool = True # by default
+    has_audio_encoder: bool = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
-            raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
+        if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
+            raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
+
+        if self.has_vision_encoder and self.has_audio_encoder:
+            raise NotImplementedError("both vision + audio not supported yet")
 
         # get n_embd of the text model
         if "text_config" not in self.hparams:
             self.hparams["text_config"] = {}
+        if "audio_config" not in self.hparams:
+            self.hparams["audio_config"] = {}
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
-        if "vision_config" not in self.hparams:
-            raise ValueError("vision_config not found in hparams")
         # move vision config to the top level, while preserving the original hparams in global_config
         self.global_config = self.hparams
-        self.hparams = self.hparams["vision_config"]
+
+        if "vision_config" in self.hparams:
+            self.hparams = self.hparams["vision_config"]
+        elif "audio_config" in self.hparams:
+            self.hparams = self.hparams["audio_config"]
+        else:
+            raise ValueError("vision_config / audio_config not found in hparams")
 
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
 
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
 
     def set_type(self):
-        self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
+        self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
-        self.gguf_writer.add_vision_has_vision_encoder(True)
 
-        # vision config
-        self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
-        self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
-        self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
-        self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.block_count)
-        self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+        if self.has_vision_encoder:
+            self.gguf_writer.add_clip_has_vision_encoder(True)
+            self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
 
-        # preprocessor config
-        self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
-        self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
+            # vision config
+            self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
+            self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
+            self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
+            self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
+            self.gguf_writer.add_vision_block_count(self.block_count)
+            self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+
+            # preprocessor config
+            self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
+            self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
+
+        elif self.has_audio_encoder:
+            self.gguf_writer.add_clip_has_audio_encoder(True)
+            self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
+
+            # audio config
+            self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
+            self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
+            self.gguf_writer.add_audio_block_count(self.block_count)
+            self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
+
+        else:
+            raise ValueError("MmprojModel must have either vision or audio encoder")
 
     def write_vocab(self):
-        raise ValueError("VisionModel does not support vocab writing")
+        raise ValueError("MmprojModel does not support vocab writing")
 
 
 @ModelBase.register("GPTNeoXForCausalLM")
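Side note: the sub-config hoisting in the new `MmprojModel.__init__` above can be seen in isolation with a tiny standalone sketch (the hparams dict here is a made-up, ultravox-style toy, not a real checkpoint config):

```python
# Minimal sketch of the vision/audio config hoisting above (toy hparams).
hparams = {"text_config": {}, "audio_config": {"d_model": 1024}, "stack_factor": 8}

global_config = hparams  # keep the full dict for top-level keys like "stack_factor"
if "vision_config" in hparams:
    hparams = hparams["vision_config"]
elif "audio_config" in hparams:
    hparams = hparams["audio_config"]
else:
    raise ValueError("vision_config / audio_config not found in hparams")

print(hparams["d_model"], global_config["stack_factor"])  # 1024 8
```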
@@ -1951,7 +1979,7 @@ class LlamaModel(TextModel):
     "LlavaForConditionalGeneration", # pixtral
     "Mistral3ForConditionalGeneration", # mistral small 3.1
 )
-class LlavaVisionModel(VisionModel):
+class LlavaVisionModel(MmprojModel):
     img_break_tok_id = -1
 
     def __init__(self, *args, **kwargs):
@@ -1977,7 +2005,7 @@ class LlavaVisionModel(VisionModel):
         super().set_gguf_parameters()
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
-            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
+            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
 
             # hidden_act
@@ -2016,7 +2044,7 @@ class LlavaVisionModel(VisionModel):
 
 
 @ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
-class SmolVLMModel(VisionModel):
+class SmolVLMModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "smolvlm_vision":
@@ -2028,7 +2056,7 @@ class SmolVLMModel(VisionModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3)
         self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
         self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
         self.gguf_writer.add_vision_use_gelu(True)
@@ -2094,10 +2122,10 @@ class Llama4Model(LlamaModel):
 
 
 @ModelBase.register("Llama4ForConditionalGeneration")
-class Llama4VisionModel(VisionModel):
+class Llama4VisionModel(MmprojModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4)
         self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
         self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
         assert self.hparams["hidden_act"] == "gelu"
@@ -2670,7 +2698,7 @@ class Qwen2VLModel(TextModel):
 
 
 @ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
-class Qwen2VLVisionModel(VisionModel):
+class Qwen2VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.hparams["image_size"] = self.hparams.get("image_size", 560)
@@ -2685,9 +2713,9 @@ class Qwen2VLVisionModel(VisionModel):
         super().set_gguf_parameters()
         hparams = self.hparams
         if self.global_config['model_type'] == 'qwen2_vl':
-            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
         elif self.global_config['model_type'] == 'qwen2_5_vl':
-            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
             self.gguf_writer.add_vision_use_silu(True)
             # find n_wa_pattern (window attention pattern)
             fullatt_block_indexes = hparams.get("fullatt_block_indexes")
@@ -2746,11 +2774,11 @@ class Qwen2VLVisionModel(VisionModel):
 
 
 @ModelBase.register("InternVisionModel")
-class InternVisionModel(VisionModel):
+class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
         self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
         # hidden_act
         if hparams["hidden_act"] == "silu":
@@ -4008,11 +4036,11 @@ class Gemma3Model(TextModel):
 
 
 @ModelBase.register("Gemma3ForConditionalGeneration")
-class Gemma3VisionModel(VisionModel):
+class Gemma3VisionModel(MmprojModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3)
         # default values below are taken from HF tranformers code
         self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
         self.gguf_writer.add_vision_use_gelu(True)
@@ -5959,6 +5987,52 @@ class ChameleonModel(TextModel):
         return data_torch
 
 
+@ModelBase.register("UltravoxModel")
+class UltravoxModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA # dummy
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
+
+
+@ModelBase.register("UltravoxModel")
+class UltravoxAudioModel(MmprojModel):
+    has_vision_encoder = False # no vision encoder
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["hidden_size"] = self.hparams["d_model"]
+        self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+        self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims # unused
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+
+        # prevent clash naming with vision tensors
+        if name.startswith("multi_modal_projector"):
+            name = "audio." + name
+
+        if "conv1.bias" in name or "conv2.bias" in name:
+            # transpose conv1 and conv2 bias
+            data_torch = data_torch.unsqueeze(-1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 ###### CONVERSION LOGIC ######
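Side note: the bias "transpose" in `modify_tensors` above is just a trailing-axis unsqueeze; a minimal numpy sketch (the channel count is a stand-in, not read from a real checkpoint):

```python
import numpy as np

# conv1/conv2 biases come in as 1-D (out_channels,); to broadcast over the
# per-frame conv output they need a trailing axis, which is what
# data_torch.unsqueeze(-1) does in the converter above.
bias = np.zeros(1280, dtype=np.float32)   # stand-in out_channels
bias_2d = bias[..., None]                 # numpy equivalent of unsqueeze(-1)
print(bias.shape, "->", bias_2d.shape)    # (1280,) -> (1280, 1)
```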
@@ -6134,13 +6208,15 @@ def split_str_to_n_bytes(split_str: str) -> int:
 
 
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
+    # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
+    # maybe we should fallback to text model's arch in that case, since not many models have both
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
     arch = hparams["architectures"][0]
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
-    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+    elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
     return arch
 
@@ -6203,7 +6279,7 @@ def main() -> None:
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
         hparams = ModelBase.load_hparams(dir_model)
         model_architecture = get_model_architecture(hparams, model_type)
         logger.info(f"Model architecture: {model_architecture}")
--- a/docs/multimodal.md
+++ b/docs/multimodal.md
@@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools
 - [llama-mtmd-cli](../tools/mtmd/README.md)
 - [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
 
-To enable it, can use use one of the 2 methods below:
+Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
+
+To enable it, you can use one of the 2 methods below:
 
 - Use `-hf` option with a supported model (see a list of pre-quantized model below)
 - To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
@@ -37,6 +39,8 @@ Replaces the `(tool_name)` with the name of binary you want to use. For example,
 
 NOTE: some models may require large context window, for example: `-c 8192`
 
+**Vision models**:
+
 ```sh
 # Gemma 3
 (tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
@@ -78,3 +82,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # Llama 4 Scout
 (tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
 ```
+
+**Audio models**:
+
+```sh
+# Ultravox 0.5
+(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
+(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
+```
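Side note: together with the `--audio` flag added in `common/arg.cpp` above, a run can be scripted like this — a hypothetical invocation sketch (the wav file and prompt are placeholders; requires a llama.cpp build with mtmd):

```python
import subprocess

# Placeholder audio file and prompt; -hf downloads the model on first use.
subprocess.run([
    "llama-mtmd-cli",
    "-hf", "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF",
    "--audio", "sample.wav",   # the new flag; an alias of --image internally
    "-p", "What is being said in this audio?",
], check=True)
```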
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -219,10 +219,13 @@ class Keys:
         TYPE       = "adapter.type"
         LORA_ALPHA = "adapter.lora.alpha"
 
-    class ClipVision:
+    class Clip:
         PROJECTOR_TYPE      = "clip.projector_type"
         HAS_VISION_ENCODER  = "clip.has_vision_encoder"
+        HAS_AUDIO_ENCODER   = "clip.has_audio_encoder"
         HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
+
+    class ClipVision:
         IMAGE_SIZE          = "clip.vision.image_size"
         PATCH_SIZE          = "clip.vision.patch_size"
         EMBEDDING_LENGTH    = "clip.vision.embedding_length"
@@ -243,19 +246,33 @@ class Keys:
         class Projector:
             SCALE_FACTOR = "clip.vision.projector.scale_factor"
 
+    class ClipAudio:
+        NUM_MEL_BINS        = "clip.audio.num_mel_bins"
+        EMBEDDING_LENGTH    = "clip.audio.embedding_length"
+        FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
+        PROJECTION_DIM      = "clip.audio.projection_dim"
+        BLOCK_COUNT         = "clip.audio.block_count"
+
+        class Attention:
+            HEAD_COUNT    = "clip.audio.attention.head_count"
+            LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
+
+        class Projector:
+            STACK_FACTOR = "clip.audio.projector.stack_factor"
+
 #
 # recommended mapping of model tensor names for storage in gguf
 #
 
 
 class GGUFType:
     MODEL   = "model"
     ADAPTER = "adapter"
-    CLIP_VISION = "clip-vision"
+    MMPROJ  = "mmproj" # dummy, unused for now
 
 
 class MODEL_ARCH(IntEnum):
-    CLIP_VISION = auto() # dummy arch for clip.cpp
+    MMPROJ = auto() # dummy arch for clip.cpp
     LLAMA  = auto()
     LLAMA4 = auto()
     DECI   = auto()
@@ -514,10 +531,27 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_QUERY       = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
     V_MM_PATCH_MERGER    = auto() # mistral small 3.1
+    # audio (mtmd)
+    A_ENC_EMBD_POS       = auto()
+    A_ENC_CONV1D         = auto()
+    A_PRE_NORM           = auto()
+    A_POST_NORM          = auto()
+    A_ENC_ATTN_Q         = auto()
+    A_ENC_ATTN_K         = auto()
+    A_ENC_ATTN_V         = auto()
+    A_ENC_INPUT_NORM     = auto()
+    A_ENC_OUTPUT         = auto()
+    A_ENC_OUTPUT_NORM    = auto()
+    A_ENC_FFN_UP         = auto()
+    A_ENC_FFN_GATE       = auto()
+    A_ENC_FFN_DOWN       = auto()
+    A_MMPROJ             = auto()
+    A_MM_NORM_PRE        = auto()
+    A_MM_NORM_MID        = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
-    MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
+    MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
     MODEL_ARCH.LLAMA:  "llama",
     MODEL_ARCH.LLAMA4: "llama4",
     MODEL_ARCH.DECI:   "deci",
@@ -776,10 +810,27 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_QUERY:       "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
     MODEL_TENSOR.V_MM_PATCH_MERGER:    "mm.patch_merger", # mistral small 3.1
+    # audio (mtmd)
+    MODEL_TENSOR.A_ENC_EMBD_POS:       "a.position_embd",
+    MODEL_TENSOR.A_ENC_CONV1D:         "a.conv1d.{bid}",
+    MODEL_TENSOR.A_PRE_NORM:           "a.pre_ln",
+    MODEL_TENSOR.A_POST_NORM:          "a.post_ln",
+    MODEL_TENSOR.A_ENC_ATTN_Q:         "a.blk.{bid}.attn_q",
+    MODEL_TENSOR.A_ENC_ATTN_K:         "a.blk.{bid}.attn_k",
+    MODEL_TENSOR.A_ENC_ATTN_V:         "a.blk.{bid}.attn_v",
+    MODEL_TENSOR.A_ENC_INPUT_NORM:     "a.blk.{bid}.ln1",
+    MODEL_TENSOR.A_ENC_OUTPUT:         "a.blk.{bid}.attn_out",
+    MODEL_TENSOR.A_ENC_OUTPUT_NORM:    "a.blk.{bid}.ln2",
+    MODEL_TENSOR.A_ENC_FFN_UP:         "a.blk.{bid}.ffn_up",
+    MODEL_TENSOR.A_ENC_FFN_GATE:       "a.blk.{bid}.ffn_gate",
+    MODEL_TENSOR.A_ENC_FFN_DOWN:       "a.blk.{bid}.ffn_down",
+    MODEL_TENSOR.A_MMPROJ:             "mm.a.mlp.{bid}",
+    MODEL_TENSOR.A_MM_NORM_PRE:        "mm.a.norm_pre",
+    MODEL_TENSOR.A_MM_NORM_MID:        "mm.a.norm_mid",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
-    MODEL_ARCH.CLIP_VISION: [
+    MODEL_ARCH.MMPROJ: [
         MODEL_TENSOR.V_MMPROJ,
         MODEL_TENSOR.V_MMPROJ_FC,
         MODEL_TENSOR.V_MMPROJ_MLP,
@@ -819,6 +870,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
         MODEL_TENSOR.V_MM_PATCH_MERGER,
+        # audio
+        MODEL_TENSOR.A_ENC_EMBD_POS,
+        MODEL_TENSOR.A_ENC_CONV1D,
+        MODEL_TENSOR.A_PRE_NORM,
+        MODEL_TENSOR.A_POST_NORM,
+        MODEL_TENSOR.A_ENC_ATTN_Q,
+        MODEL_TENSOR.A_ENC_ATTN_K,
+        MODEL_TENSOR.A_ENC_ATTN_V,
+        MODEL_TENSOR.A_ENC_INPUT_NORM,
+        MODEL_TENSOR.A_ENC_OUTPUT,
+        MODEL_TENSOR.A_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.A_ENC_FFN_UP,
+        MODEL_TENSOR.A_ENC_FFN_GATE,
+        MODEL_TENSOR.A_ENC_FFN_DOWN,
+        MODEL_TENSOR.A_MMPROJ,
+        MODEL_TENSOR.A_MM_NORM_PRE,
+        MODEL_TENSOR.A_MM_NORM_MID,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -2186,6 +2254,7 @@ class VisionProjectorType:
     LLAMA4   = "llama4"
    QWEN2VL  = "qwen2vl_merger"
    QWEN25VL = "qwen2.5vl_merger"
+    ULTRAVOX = "ultravox"
    INTERNVL = "internvl"
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -936,12 +936,18 @@ class GGUFWriter:
 
     # for vision models
 
+    def add_clip_has_vision_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
+
+    def add_clip_has_audio_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
+
+    def add_clip_projector_type(self, value: str) -> None:
+        self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
+
     def add_vision_projection_dim(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
 
-    def add_vision_has_vision_encoder(self, value: bool) -> None:
-        self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
-
     def add_vision_patch_size(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
 
@@ -957,9 +963,6 @@ class GGUFWriter:
     def add_vision_head_count(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
 
-    def add_vision_projector_type(self, value: str) -> None:
-        self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
-
     def add_vision_attention_layernorm_eps(self, value: float) -> None:
         self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
 
@@ -987,6 +990,32 @@ class GGUFWriter:
     def add_vision_n_wa_pattern(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
 
+    # audio models
+
+    def add_audio_projection_dim(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
+
+    def add_audio_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
+
+    def add_audio_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
+
+    def add_audio_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
+
+    def add_audio_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
+
+    def add_audio_attention_layernorm_eps(self, value: float) -> None:
+        self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
+
+    def add_audio_num_mel_bins(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
+
+    def add_audio_stack_factor(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
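Side note: a minimal sketch of the new writer API, assuming the gguf-py version from this commit is installed; the values are illustrative (128 mel bins matches the whisper preprocessor, the embedding length and stack factor are stand-ins, not read from a real model):

```python
import gguf

# Illustrative audio metadata for an ultravox-style mmproj file; not a full
# converter (no tensors are written here).
writer = gguf.GGUFWriter("mmproj-ultravox.gguf", arch="clip")
writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
writer.add_clip_has_audio_encoder(True)
writer.add_audio_num_mel_bins(128)       # whisper-style mel spectrogram
writer.add_audio_embedding_length(1280)  # encoder d_model (stand-in)
writer.add_audio_stack_factor(8)         # frames stacked per projector row
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```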
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1110,6 +1110,68 @@ class TensorNameMap:
         MODEL_TENSOR.V_MM_PATCH_MERGER: (
             "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
         ),
+
+        # audio (mtmd)
+
+        MODEL_TENSOR.A_ENC_EMBD_POS: (
+            "audio_tower.embed_positions", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_CONV1D: (
+            "audio_tower.conv{bid}", # ultravox
+        ),
+
+        MODEL_TENSOR.A_PRE_NORM: (),
+
+        MODEL_TENSOR.A_POST_NORM: (
+            "audio_tower.layer_norm", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_ATTN_Q: (
+            "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_ATTN_K: (
+            "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_ATTN_V: (
+            "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_INPUT_NORM: (
+            "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_OUTPUT: (
+            "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
+            "audio_tower.layers.{bid}.final_layer_norm", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_FFN_UP: (
+            "audio_tower.layers.{bid}.fc1", # ultravox
+        ),
+
+        MODEL_TENSOR.A_ENC_FFN_GATE: (),
+
+        MODEL_TENSOR.A_ENC_FFN_DOWN: (
+            "audio_tower.layers.{bid}.fc2", # ultravox
+        ),
+
+        MODEL_TENSOR.A_MMPROJ: (
+            "audio.multi_modal_projector.linear_{bid}", # ultravox
+        ),
+
+        MODEL_TENSOR.A_MM_NORM_PRE: (
+            "audio.multi_modal_projector.ln_pre", # ultravox
+        ),
+
+        MODEL_TENSOR.A_MM_NORM_MID: (
+            "audio.multi_modal_projector.ln_mid", # ultravox
+        ),
     }
 
     # architecture-specific block mappings
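Side note: to make the mapping above concrete, here is how one ultravox HF tensor name resolves to its gguf-side name once `{bid}` is substituted (a trivial standalone sketch using the templates from the table):

```python
# HF-side template and its gguf-side counterpart, taken from the table above.
hf_template   = "audio_tower.layers.{bid}.self_attn.q_proj"
gguf_template = "a.blk.{bid}.attn_q"

bid = 3
print(f"{hf_template.format(bid=bid)}.weight -> {gguf_template.format(bid=bid)}.weight")
# audio_tower.layers.3.self_attn.q_proj.weight -> a.blk.3.attn_q.weight
```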
--- a/tools/mtmd/CMakeLists.txt
+++ b/tools/mtmd/CMakeLists.txt
@@ -1,5 +1,15 @@
 # mtmd
 
+# compile mtmd-audio separately to avoid long compile times with miniaudio.h
+# TODO @ngxson : move miniaudio.h and stb_image.h to mtmd-helper.cpp, then compile the helper as a separate library
+add_library(mtmd_audio STATIC mtmd-audio.cpp mtmd-audio.h)
+if (BUILD_SHARED_LIBS)
+    set_target_properties(mtmd_audio PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif()
+target_link_libraries(mtmd_audio PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(mtmd_audio PRIVATE cxx_std_17)
+target_include_directories(mtmd_audio PRIVATE .)
+
 add_library(mtmd OBJECT
             mtmd.cpp
             mtmd-helper.cpp
@@ -9,7 +19,7 @@ add_library(mtmd OBJECT
             clip-impl.h
             )
 
-target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(mtmd PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
 
 target_include_directories(mtmd PUBLIC .)
 target_include_directories(mtmd PRIVATE ../..)
@@ -22,12 +32,13 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
     add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
-    target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+    target_link_libraries(mtmd_shared PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
     install(TARGETS mtmd_shared LIBRARY)
 endif()
 
 if (NOT MSVC)
     target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
+    target_compile_options(mtmd_audio PRIVATE -Wno-cast-qual) # miniaudio.h
 endif()
 
 if(TARGET BUILD_INFO)
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -16,22 +16,26 @@
 #define KEY_FTYPE               "general.file_type"
 #define KEY_NAME                "general.name"
 #define KEY_DESCRIPTION         "general.description"
-#define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
+#define KEY_PROJ_TYPE           "clip.projector_type"
+#define KEY_HAS_AUDIO_ENC       "clip.has_audio_encoder"
+#define KEY_HAS_VISION_ENC      "clip.has_vision_encoder"
 #define KEY_USE_GELU            "clip.use_gelu"
 #define KEY_USE_SILU            "clip.use_silu"
-#define KEY_N_EMBD              "clip.vision.embedding_length"
-#define KEY_N_FF                "clip.vision.feed_forward_length"
-#define KEY_N_BLOCK             "clip.vision.block_count"
-#define KEY_N_HEAD              "clip.vision.attention.head_count"
-#define KEY_LAYER_NORM_EPS      "clip.vision.attention.layer_norm_epsilon"
-#define KEY_PROJ_DIM            "clip.vision.projection_dim"
+
+#define KEY_N_EMBD              "clip.%s.embedding_length"
+#define KEY_N_FF                "clip.%s.feed_forward_length"
+#define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
+#define KEY_N_HEAD              "clip.%s.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
+
+// vision-specific
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
 #define KEY_IMAGE_STD           "clip.vision.image_std"
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
-#define KEY_PROJ_TYPE           "clip.projector_type"
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
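Side note: the `%s` keys above are expanded with a modality prefix at load time (see the `clip_model_loader` hunk further down). The effect, shown with a one-line Python sketch:

```python
# One set of hparam keys now serves both modalities via a prefix.
KEY_N_EMBD = "clip.%s.embedding_length"

for prefix in ("vision", "audio"):
    print(KEY_N_EMBD % prefix)
# clip.vision.embedding_length
# clip.audio.embedding_length
```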
@@ -39,13 +43,18 @@
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
 #define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
 #define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION      "clip.minicpmv_version"
+
+// audio-specific
+#define KEY_A_NUM_MEL_BINS        "clip.audio.num_mel_bins"
+#define KEY_A_PROJ_STACK_FACTOR   "clip.audio.projector.stack_factor"
 
 //
 // tensor name constants
 //
 
-#define TN_POS_EMBD        "v.position_embd.weight"
+#define TN_POS_EMBD        "%s.position_embd.weight"
 #define TN_CLASS_EMBD      "v.class_embd"
 #define TN_PATCH_EMBD      "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
 #define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
@@ -95,6 +104,12 @@
 #define TN_GLM_ADAPTER_GATE     "adapter.linear.gate.%s"
 #define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
 
+// ultravox
+#define TN_CONV1D       "a.conv1d.%d.%s"
+#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
+#define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
+#define TN_MM_NORM_MID  "mm.a.norm_mid.%s"
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
@@ -110,6 +125,7 @@ enum projector_type {
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_ULTRAVOX,
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_UNKNOWN,
@@ -126,6 +142,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GEMMA3,   "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,  "pixtral"},
+    { PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
     { PROJECTOR_TYPE_INTERNVL, "internvl"},
     { PROJECTOR_TYPE_LLAMA4,   "llama4"},
 };
@@ -147,8 +164,10 @@ struct clip_image_u8 {
     std::vector<uint8_t> buf;
 };
 
-// RGB float32 image (NHWC)
+// For images, buf.size() == nx*ny*3
 // Memory layout: RGBRGBRGB...
+// For audio, only one channel is used, buf.size() == nx*ny
+// nx will be n_frames and ny will be n_mel
 struct clip_image_f32 {
     int nx;
     int ny;
@@ -242,6 +261,7 @@ struct clip_image_u8_batch {
 
 struct clip_image_f32_batch {
     std::vector<clip_image_f32_ptr> entries;
+    bool is_audio = false;
 
     // for llava-uhd style models, we need to know the grid size
     // note: entries.size() == grid_x * grid_y + 1 (one overview image)
@@ -249,7 +269,12 @@ struct clip_image_f32_batch {
     int grid_y = 0;
 
     clip_image_f32_batch clone() const {
-        clip_image_f32_batch new_batch;
+        clip_image_f32_batch new_batch{
+            /* entries  */ {},
+            /* is_audio */ is_audio,
+            /* grid_x   */ grid_x,
+            /* grid_y   */ grid_y,
+        };
         new_batch.entries.reserve(entries.size());
         for (const auto & entry : entries) {
             new_batch.entries.emplace_back(new clip_image_f32(*entry));
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
 
 enum ffn_op_type {
     FFN_GELU,
+    FFN_GELU_ERF,
     FFN_SILU,
     FFN_GELU_QUICK,
 };
@@ -165,6 +166,9 @@ enum patch_merge_type {
 };
 
 struct clip_hparams {
+    bool has_vision = false;
+    bool has_audio = false;
+
     int32_t image_size;
     int32_t patch_size;
     int32_t n_embd;
@@ -191,6 +195,10 @@ struct clip_hparams {
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
     int32_t spatial_merge_size = 0;
+
+    // audio
+    int32_t n_mel_bins = 0; // whisper preprocessor
+    int32_t proj_stack_factor = 0; // ultravox
 };
 
 struct clip_layer {
@@ -332,6 +340,14 @@ struct clip_vision_model {
     // pixtral
     ggml_tensor * token_embd_img_break = nullptr;
     ggml_tensor * mm_patch_merger_w = nullptr;
+
+    // ultravox / whisper encoder
+    ggml_tensor * conv1d_1_w = nullptr;
+    ggml_tensor * conv1d_1_b = nullptr;
+    ggml_tensor * conv1d_2_w = nullptr;
+    ggml_tensor * conv1d_2_b = nullptr;
+    ggml_tensor * mm_norm_pre_w = nullptr;
+    ggml_tensor * mm_norm_mid_w = nullptr;
 };
 
 struct clip_ctx {
@ -1408,6 +1424,104 @@ struct clip_graph {
|
|||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// whisper encoder with custom projector
|
||||||
|
ggml_cgraph * build_whisper_enc() {
|
||||||
|
const int n_frames = img.nx;
|
||||||
|
const int n_pos = n_frames / 2;
|
||||||
|
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
|
||||||
|
|
||||||
|
ggml_tensor * inp = build_inp_raw(1);
|
||||||
|
|
||||||
|
// conv1d block
|
||||||
|
{
|
||||||
|
// convolution + gelu
|
||||||
|
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
|
||||||
|
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
|
||||||
|
|
||||||
|
cur = ggml_gelu_erf(ctx0, cur);
|
||||||
|
|
||||||
|
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
|
||||||
|
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
|
||||||
|
|
||||||
|
cur = ggml_gelu_erf(ctx0, cur);
|
||||||
|
// transpose
|
||||||
|
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
||||||
|
cb(inp, "after_conv1d", -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanity check (only check one layer, but it should be the same for all)
|
||||||
|
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
|
||||||
|
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
|
||||||
|
GGML_ASSERT(model.layers[0].q_b);
|
||||||
|
GGML_ASSERT(model.layers[0].v_b);
|
||||||
|
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
|
||||||
|
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
|
||||||
|
|
||||||
|
ggml_tensor * pos_embd_selected = ggml_view_2d(
|
||||||
|
ctx0, model.position_embeddings,
|
||||||
|
model.position_embeddings->ne[0], n_pos,
|
||||||
|
model.position_embeddings->nb[1], 0
|
||||||
|
);
|
||||||
|
ggml_tensor * cur = build_vit(
|
||||||
|
inp, n_pos,
|
||||||
|
NORM_TYPE_NORMAL,
|
||||||
|
hparams.ffn_op,
|
||||||
|
pos_embd_selected,
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
cb(cur, "after_transformer", -1);
|
||||||
|
|
||||||
|
// StackAudioFrames
|
||||||
|
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
|
||||||
|
{
|
||||||
|
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||||
|
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||||
|
int64_t pad = padded_len - ggml_nelements(cur);
|
||||||
|
if (pad > 0) {
|
||||||
|
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||||
|
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||||
|
}
|
||||||
|
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||||
|
ggml_row_size(cur->type, stride), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(cur, "after_stacked", -1);
|
||||||
|
|
||||||
|
// UltravoxProjector
|
||||||
|
{
|
||||||
|
// pre-norm
|
||||||
|
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||||
|
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||||
|
|
||||||
|
// ffn in
|
||||||
|
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||||
|
|
||||||
|
// swiglu
|
||||||
|
{
|
||||||
|
int64_t split_point = cur->ne[0] / 2;
|
||||||
|
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||||
|
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||||
|
|
||||||
|
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||||
|
x1 = ggml_silu(ctx0, x1);
|
||||||
|
cur = ggml_mul(ctx0, x0, x1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// mid-norm
|
||||||
|
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||||
|
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||||
|
|
||||||
|
// ffn out
|
||||||
|
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||||
|
}
|
||||||
|
|
||||||
|
cb(cur, "projected", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//
|
//
|
||||||
// utility functions
|
// utility functions
|
||||||
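
For orientation, a shape walk-through of build_whisper_enc() on one full audio chunk; the concrete numbers assume the whisper-large-v3-style encoder used by the ultravox-v0_5 models (n_embd = 1280, proj_stack_factor = 8, 3000-frame chunks), which are not fixed by this diff itself:

    // one 30 s chunk: img.nx = 3000 mel frames, img.ny = 128 mel bins
    // conv1d_2 has stride 2, so the transformer sees n_pos = 3000 / 2 = 1500 positions
    // StackAudioFrames re-views the [1500 x 1280] output with
    //   stride     = n_embd * proj_stack_factor = 1280 * 8 = 10240
    //   padded_len = GGML_PAD(1500 * 1280, 10240) = 1925120
    // giving about 1925120 / 10240 = 188 stacked rows; the UltravoxProjector
    // then maps each row to a single embedding in the LLM space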
@@ -1562,8 +1676,8 @@ private:
        return inp;
    }

-    ggml_tensor * build_inp_raw() {
-        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
+    ggml_tensor * build_inp_raw(int channels = 3) {
+        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
        ggml_set_name(inp_raw, "inp_raw");
        ggml_set_input(inp_raw);
        return inp_raw;
@@ -1641,6 +1755,11 @@ private:
                    cur = ggml_gelu(ctx0, cur);
                    cb(cur, "ffn_gelu", il);
                } break;
+            case FFN_GELU_ERF:
+                {
+                    cur = ggml_gelu_erf(ctx0, cur);
+                    cb(cur, "ggml_gelu_erf", il);
+                } break;
            case FFN_GELU_QUICK:
                {
                    cur = ggml_gelu_quick(ctx0, cur);
@@ -1832,6 +1951,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
            {
                res = graph.build_llama4();
            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                res = graph.build_whisper_enc();
+            } break;
        default:
            {
                res = graph.build_llava();
@@ -1915,18 +2038,30 @@ struct clip_model_loader {

        // other hparams
        {
-            get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
-            get_u32(KEY_N_EMBD, hparams.n_embd);
-            get_u32(KEY_N_HEAD, hparams.n_head);
-            get_u32(KEY_N_FF, hparams.n_ff);
-            get_u32(KEY_N_BLOCK, hparams.n_layer);
-            get_u32(KEY_PROJ_DIM, hparams.projection_dim);
-            get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
-            get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
-            get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
-            get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
+            get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
+            get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
+
+            const char * prefix = hparams.has_vision ? "vision" : "audio";
+            get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
+            get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
+            get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
+            get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
+            get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
+            get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
+
+            if (hparams.has_vision) {
+                get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+                get_u32(KEY_PATCH_SIZE, hparams.patch_size);
+                get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+                get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
+                get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
+
+            } else if (hparams.has_audio) {
+                get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
+
+            } else {
+                throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
+            }
+
            // default warmup value
            hparams.warmup_image_size = hparams.image_size;
@@ -1964,7 +2099,7 @@ struct clip_model_loader {
            }
        }

-        {
+        if (hparams.has_vision) {
            int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
            int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
            GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
@@ -2050,30 +2185,43 @@ struct clip_model_loader {
                        isize, isize*3, // 336, 1008
                    };
                } break;
+            case PROJECTOR_TYPE_ULTRAVOX:
+                {
+                    get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
+                    if (hparams.n_mel_bins != 128) {
+                        throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
+                    }
+                    hparams.ffn_op = FFN_GELU_ERF;
+                    log_ffn_op = "gelu_erf"; // temporary solution for logging
+                } break;
            default:
                break;
        }

        LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
+        LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
+        LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
        LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
        LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
        LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
        LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
-        LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
-        LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
-        LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
-        LOG_INF("\n");
-        LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
-        LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
-        LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
-        LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
        LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
+        LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
+        LOG_INF("\n");
+        if (hparams.has_vision) {
+            LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
+            LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
+            LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
+            LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
+            LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
+            LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+        } else if (hparams.has_audio) {
+            LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
+            LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
+        }
+        LOG_INF("\n");
        LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
        LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
-
-        if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
-            LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
-        }
        }
    }

@@ -2082,6 +2230,9 @@ struct clip_model_loader {
        std::map<std::string, size_t> tensor_offset;
        std::vector<ggml_tensor *> tensors_to_load;

+        // TODO @ngxson : support both audio and video in the future
+        const char * prefix = hparams.has_audio ? "a" : "v";
+
        // get offsets
        for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
            const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
@@ -2119,47 +2270,47 @@ struct clip_model_loader {

        vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);

-        vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
-        vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
+        vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
+        vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);

-        vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
-        vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
+        vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
+        vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);

        vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
        vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
        vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);

-        vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
+        vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);

        // layers
        vision_model.layers.resize(hparams.n_layer);
        for (int il = 0; il < hparams.n_layer; ++il) {
            auto & layer = vision_model.layers[il];
-            layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
-            layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
-            layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
-            layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
-            layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
-            layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
-            layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
-            layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
-            layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
-            layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
+            layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
+            layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
+            layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
+            layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
+            layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
+            layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
+            layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
+            layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias

-            layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
-            layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
-            layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
-            layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
-            layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
-            layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
+            layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
+            layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
+            layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
+            layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
+            layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
+            layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);

            // ffn
-            layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
-            layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
-            layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
-            layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
-            layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
-            layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
+            layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
+            layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
+            layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
+            layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
+            layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
+            layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);

            // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
            // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
@@ -2301,6 +2452,17 @@ struct clip_model_loader {
                    vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                } break;
+            case PROJECTOR_TYPE_ULTRAVOX:
+                {
+                    vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                    vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
+                } break;
            case PROJECTOR_TYPE_INTERNVL:
                {
                    vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -2358,13 +2520,19 @@ struct clip_model_loader {
    }

    void alloc_compute_meta() {
+        const auto & hparams = ctx_clip.vision_model.hparams;
        ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());

        // create a fake batch
        clip_image_f32_batch batch;
        clip_image_f32_ptr img(clip_image_f32_init());
-        img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
-        img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
+        if (hparams.has_vision) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+        } else {
+            img->nx = 1024; // TODO @ngxson : use a better default
+            img->ny = hparams.n_mel_bins;
+        }
        img->buf.resize(img->nx * img->ny * 3);
        batch.entries.push_back(std::move(img));

@@ -3278,6 +3446,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
        n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
        n_patches /= (scale_factor * scale_factor);
+    } else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
+        const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
+        const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+        n_patches = n_len / proj_stack_factor / 2;
    }

    return n_patches;
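
A quick worked example of the PROJECTOR_TYPE_ULTRAVOX branch above, assuming proj_stack_factor = 8 (the value used by the ultravox-v0_5 conversions): a full 3000-frame chunk gives n_len = CLIP_ALIGN(3000, 8) = 3000 and n_patches = 3000 / 8 / 2 = 187, i.e. each 30 s of audio costs on the order of 190 output embeddings in the language model context.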
@@ -3435,7 +3607,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
    };

    // set input pixel values
-    {
+    if (!imgs.is_audio) {
        size_t nelem = 0;
        for (const auto & img : imgs.entries) {
            nelem += img->nx * img->ny * 3;
@@ -3472,6 +3644,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
            }
        }
        set_input_f32("inp_raw", inp_raw);
+
+    } else {
+        // audio input
+        GGML_ASSERT(imgs.entries.size() == 1);
+        const auto & mel_inp = imgs.entries[0];
+        const int n_step = mel_inp->nx;
+        const int n_mel = mel_inp->ny;
+        std::vector<float> inp_raw(n_step * n_mel);
+        std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
+        set_input_f32("inp_raw", inp_raw);
    }

    // set input per projector
@@ -3668,6 +3850,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
        case PROJECTOR_TYPE_GEMMA3:
        case PROJECTOR_TYPE_IDEFICS3:
        case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_ULTRAVOX:
            {
                // do nothing
            } break;
@@ -3766,6 +3949,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
            return ctx->vision_model.mm_input_proj_w->ne[0];
        case PROJECTOR_TYPE_IDEFICS3:
            return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_ULTRAVOX:
+            return ctx->vision_model.mm_2_w->ne[1];
        case PROJECTOR_TYPE_INTERNVL:
            return ctx->vision_model.mm_3_w->ne[1];
        case PROJECTOR_TYPE_LLAMA4:
@@ -3798,6 +3983,14 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
    return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}

+bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.has_vision;
+}
+
+bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.has_audio;
+}
+
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
    clip_image_f32 clip_img;
    clip_img.buf.resize(h * w * 3);
@@ -3818,3 +4011,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
    return ctx->proj_type;
}
+
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
+    clip_image_f32 * audio = new clip_image_f32;
+    audio->nx = n_frames;
+    audio->ny = n_mel;
+    audio->buf.resize(n_frames * n_mel);
+    std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
+
+    batch->entries.push_back(clip_image_f32_ptr(audio));
+    batch->is_audio = true;
+}

tools/mtmd/clip.h
@@ -93,3 +93,9 @@ bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);

bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+
+// used by audio input
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx);
+bool clip_has_audio_encoder(const struct clip_ctx * ctx);
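
Taken together, these declarations form the audio entry point into clip. A minimal sketch of a caller (hypothetical, not part of this commit; it assumes the batch struct is constructible as it is inside clip.cpp, and that the final parameter of clip_image_batch_encode(), truncated in the hunk headers above, is the output embedding buffer):

    bool encode_mel_chunk(clip_ctx * ctx, int n_threads,
                          int n_mel, int n_frames, float * mel, float * out_embd) {
        if (!clip_has_audio_encoder(ctx)) {
            return false; // the loaded mmproj has no audio tower
        }
        clip_image_f32_batch batch;
        clip_image_f32_batch_add_mel(&batch, n_mel, n_frames, mel); // sets batch.is_audio
        return clip_image_batch_encode(ctx, n_threads, &batch, out_embd);
    }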

tools/mtmd/miniaudio.h (new file, 93468 lines)
File diff suppressed because it is too large

tools/mtmd/mtmd-audio.cpp (new file, 855 lines)
@@ -0,0 +1,855 @@
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+#include "mtmd-audio.h"
+
+//#define MTMD_AUDIO_DEBUG
+
+#define MINIAUDIO_IMPLEMENTATION
+#ifndef MTMD_AUDIO_DEBUG
+#   define MA_NO_ENCODING
+#endif
+#define MA_NO_DEVICE_IO
+#define MA_NO_RESOURCE_MANAGER
+#define MA_NO_NODE_GRAPH
+#define MA_NO_ENGINE
+#define MA_NO_GENERATION
+#define MA_API static
+#include "miniaudio.h"
+
+#define _USE_MATH_DEFINES // for M_PI
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <thread>
+#include <vector>
+#include <fstream>
+#include <algorithm>
+
+// most of the code here is copied from whisper.cpp
+
+// align x to upper multiple of n
+#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
+
+namespace whisper_preprocessor {
+
+#define SIN_COS_N_COUNT WHISPER_N_FFT
+namespace {
+struct whisper_global_cache {
+    // In FFT, we frequently use sine and cosine operations with the same values.
+    // We can use precalculated values to speed up the process.
+    float sin_vals[SIN_COS_N_COUNT];
+    float cos_vals[SIN_COS_N_COUNT];
+
+    // Hann window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    float hann_window[WHISPER_N_FFT];
+
+    whisper_global_cache() {
+        fill_sin_cos_table();
+        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+    }
+
+    void fill_sin_cos_table() {
+        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+            sin_vals[i] = sinf(theta);
+            cos_vals[i] = cosf(theta);
+        }
+    }
+
+    void fill_hann_window(int length, bool periodic, float * output) {
+        int offset = -1;
+        if (periodic) {
+            offset = 0;
+        }
+        for (int i = 0; i < length; i++) {
+            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+        }
+    }
+} global_cache;
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const float* in, int N, float* out) {
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(float* in, int N, float* out) {
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out);
+        return;
+    }
+
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i]= in[2*i];
+    }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft);
+
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft);
+
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+    for (int k = 0; k < half_N; k++) {
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = global_cache.cos_vals[idx]; // cos(t)
+        float im = -global_cache.sin_vals[idx]; // sin(t)
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector<float> fft_in(frame_size * 2, 0.0);
+    std::vector<float> fft_out(frame_size * 2 * 2 * 2);
+
+    int n_fft = filters.n_fft;
+    int i = ith;
+
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hann window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
+        }
+
+        // FFT
+        fft(fft_in.data(), frame_size, fft_out.data());
+
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < n_fft; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fft - 3; k += 4) {
+                sum +=
+                    fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                    fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                    fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                    fft_out[k + 3] * filters.data[j * n_fft + k + 3];
+            }
+            // handle n_fft remainder
+            for (; k < n_fft; k++) {
+                sum += fft_out[k] * filters.data[j * n_fft + k];
+            }
+            sum = log10(std::max(sum, 1e-10));
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
+static bool log_mel_spectrogram(
+        const float * samples,
+        const int n_samples,
+        const int /*sample_rate*/,
+        const int frame_size,
+        const int frame_step,
+        const int n_mel,
+        const int n_threads,
+        const whisper_filters & filters,
+        const bool debug,
+        whisper_mel & mel) {
+    //const int64_t t_start_us = ggml_time_us();
+
+    // Hann window
+    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+    const float * hann = global_cache.hann_window;
+
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
+
+    // Initialize a vector and copy data from C array to it.
+    std::vector<float> samples_padded;
+    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
+
+    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
+
+    // reflective pad 200 samples at the beginning of audio
+    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
+
+    mel.n_mel = n_mel;
+    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+    // Calculate number of frames + remove the last frame
+    mel.n_len = (samples_padded.size() - frame_size) / frame_step;
+    // Calculate semi-padded sample length to ensure compatibility
+    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
+    mel.data.resize(mel.n_mel * mel.n_len);
+
+    {
+        std::vector<std::thread> workers(n_threads - 1);
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw] = std::thread(
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
+                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
+                    std::cref(filters), std::ref(mel));
+        }
+
+        // main thread
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw].join();
+        }
+    }
+
+    // clamping and normalization
+    double mmax = -1e20;
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] > mmax) {
+            mmax = mel.data[i];
+        }
+    }
+
+    mmax -= 8.0;
+
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] < mmax) {
+            mel.data[i] = mmax;
+        }
+
+        mel.data[i] = (mel.data[i] + 4.0)/4.0;
+    }
+
+    // Dump log_mel_spectrogram
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }
+
+    return true;
+}
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector<whisper_mel> & output) {
+
+    if (n_samples == 0) {
+        // empty audio
+        return false;
+    }
+
+    whisper_mel out_full;
+    bool ok = log_mel_spectrogram(
+            samples,
+            n_samples,
+            COMMON_SAMPLE_RATE,
+            WHISPER_N_FFT,
+            WHISPER_HOP_LENGTH,
+            filters.n_mel,
+            4, // n_threads
+            filters,
+            false, // debug
+            out_full);
+    if (!ok) {
+        return false;
+    }
+
+    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
+    // we always expect the mel to have 3000 silent frames at the end
+    // printf("n_len %d\n", out_full.n_len);
+    const size_t frames_per_chunk = 3000;
+    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
+    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
+        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
+        if ((size_t)n_len < frames_per_chunk) {
+            break; // the last incomplete chunk will always be a padded chunk, safe to ignore
+        }
+
+        whisper_mel out_chunk;
+        out_chunk.n_len = n_len;
+        out_chunk.n_mel = out_full.n_mel;
+        out_chunk.n_len_org = out_full.n_mel; // unused
+        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
+
+        for (int i = 0; i < out_full.n_mel; i++) {
+            auto src = out_full.data.begin() + i*out_full.n_len + off;
+            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
+        }
+
+        output.push_back(std::move(out_chunk));
+    }
+
+    return true;
+}
+
+} // namespace whisper_preprocessor
+
+
+namespace audio_helpers {
+
+bool is_audio_file(const char * buf, size_t len) {
+    if (len < 12) {
+        return false;
+    }
+
+    // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
+    // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
+    bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
+    bool is_mp3 = len >= 3 && (
+        memcmp(buf, "ID3", 3) == 0 ||
+        // Check for MPEG sync word (simplified check)
+        ((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
+    );
+    bool is_flac = memcmp(buf, "fLaC", 4) == 0;
+
+    return is_wav || is_mp3 || is_flac;
+}
+
+// returns true if the buffer is a valid audio file
+bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
+    ma_result result;
+    const int channels = 1;
+    ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
+    ma_decoder decoder;
+
+    result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
+    if (result != MA_SUCCESS) {
+        return false;
+    }
+
+    ma_uint64 frame_count;
+    ma_uint64 frames_read;
+    result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
+    if (result != MA_SUCCESS) {
+        ma_decoder_uninit(&decoder);
+        return false;
+    }
+
+    pcmf32_mono.resize(frame_count);
+    result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
+    if (result != MA_SUCCESS) {
+        ma_decoder_uninit(&decoder);
+        return false;
+    }
+
+#ifdef MTMD_AUDIO_DEBUG
+    // save audio to wav file
+    ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
+    ma_encoder encoder;
+    ma_encoder_init_file("output.wav", &config, &encoder);
+    ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
+    ma_encoder_uninit(&encoder);
+#endif
+
+    ma_decoder_uninit(&decoder);
+    return true;
+}
+
+} // namespace audio_helpers
+
+
+// precalculated mel filter banks
+// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
+//
+// generated from python code:
+//
+// from numpy import load
+// data = load('mel_filters.npz')
+// lst = data.files
+// for item in lst:
+//     print(item)
+//     print(data[item].shape)
+//     n_mel = data[item].shape[0]
+//     n_fft = data[item].shape[1]
+//     for i, row in enumerate(data[item]):
+//         for j, val in enumerate(row):
+//             val = val * 1000.0
+//             if val != 0:
+//                 print(f"data[{i*n_fft + j}] = {val:.6f};")
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins() {
+    whisper_preprocessor::whisper_filters filters;
+    filters.n_mel = 128;
+    filters.n_fft = 201;
+    std::vector data(filters.n_mel * filters.n_fft, 0.0f);
+
+    data[1] = 12.37398665;
+    data[202] = 30.39256483;
+    data[404] = 24.74797331;
+    data[605] = 18.01857911;
+    data[807] = 37.12195903;
+    data[1008] = 5.64459199;
+    data[1009] = 6.72939420;
+    data[1210] = 36.03715822;
+    data[1412] = 19.10337992;
+    data[1613] = 23.66316877;
+    data[1815] = 31.47736564;
+    data[2016] = 11.28918398;
+    data[2017] = 1.08480197;
+    data[2218] = 41.68175161;
+    data[2420] = 13.45878839;
+    data[2621] = 29.30776216;
+    data[2823] = 25.83277412;
+    data[3024] = 16.93377644;
+    data[3226] = 38.20675984;
+    data[3427] = 4.55979025;
+    data[3428] = 7.81419594;
+    data[3629] = 34.95235741;
+    data[3831] = 20.18818259;
+    data[4032] = 22.57836796;
+    data[4234] = 32.56217018;
+    data[4435] = 10.20438317;
+    data[4436] = 2.16960395;
+    data[4637] = 40.59694707;
+    data[4839] = 14.54358920;
+    data[5040] = 28.22295949;
+    data[5242] = 26.91757679;
+    data[5443] = 15.84897563;
+    data[5645] = 39.29156065;
+    data[5846] = 3.47498828;
+    data[5847] = 8.89899861;
+    data[6048] = 33.86755288;
+    data[6250] = 21.27298526;
+    data[6451] = 21.49356715;
+    data[6653] = 33.64697099;
+    data[6854] = 9.11958050;
+    data[6855] = 3.25440569;
+    data[7056] = 39.51214626;
+    data[7258] = 15.62839188;
+    data[7459] = 27.13815868;
+    data[7661] = 28.00237760;
+    data[7862] = 14.76417296;
+    data[8064] = 40.37636518;
+    data[8265] = 2.38068704;
+    data[8266] = 10.20263787;
+    data[8467] = 31.61146119;
+    data[8669] = 24.54700135;
+    data[8870] = 15.32919332;
+    data[8871] = 1.66583748;
+    data[9072] = 36.72905266;
+    data[9274] = 20.09709924;
+    data[9475] = 16.93102531;
+    data[9476] = 2.90265540;
+    data[9677] = 32.84499049;
+    data[9879] = 23.52004871;
+    data[10080] = 11.03894413;
+    data[10081] = 10.72582975;
+    data[10282] = 22.71829173;
+    data[10484] = 32.27872774;
+    data[10685] = 0.11626833;
+    data[10686] = 22.85348251;
+    data[10887] = 8.56344029;
+    data[10888] = 14.97978810;
+    data[11089] = 15.51398356;
+    data[11090] = 8.51490628;
+    data[11291] = 21.10680379;
+    data[11292] = 3.32652032;
+    data[11493] = 25.47064796;
+    data[11695] = 27.35907957;
+    data[11896] = 0.65853616;
+    data[11897] = 23.83812517;
+    data[12098] = 3.44359246;
+    data[12099] = 21.22455277;
+    data[12300] = 5.35842171;
+    data[12301] = 19.42555793;
+    data[12502] = 6.49324711;
+    data[12503] = 18.35542172;
+    data[12704] = 6.93138083;
+    data[12705] = 17.93504693;
+    data[12906] = 6.74968259;
+    data[12907] = 18.09151843;
+    data[13108] = 6.01899112;
+    data[13109] = 18.75767298;
+    data[13310] = 4.80452832;
+    data[13311] = 19.87172849;
+    data[13512] = 3.16627859;
+    data[13513] = 21.37690969;
+    data[13514] = 1.25317345;
+    data[13714] = 1.15934468;
+    data[13715] = 20.80361731;
+    data[13716] = 4.04486805;
+    data[13917] = 17.55363122;
+    data[13918] = 7.08320038;
+    data[14119] = 14.07538634;
+    data[14120] = 10.32655034;
+    data[14321] = 10.40921453;
+    data[14322] = 13.73696327;
+    data[14523] = 6.59187697;
+    data[14524] = 17.27988198;
+    data[14525] = 1.46804214;
+    data[14725] = 2.65681883;
+    data[14726] = 18.09193194;
+    data[14727] = 5.85655728;
+    data[14928] = 13.34277913;
+    data[14929] = 10.28267574;
+    data[15130] = 8.56800377;
+    data[15131] = 14.72230814;
+    data[15132] = 1.04039861;
+    data[15332] = 3.79085587;
+    data[15333] = 17.14678481;
+    data[15334] = 6.11609267;
+    data[15535] = 11.75929047;
+    data[15536] = 11.13393717;
+    data[15737] = 6.43857848;
+    data[15738] = 16.07806236;
+    data[15739] = 4.23917221;
+    data[15939] = 1.19989377;
+    data[15940] = 12.75671553;
+    data[15941] = 9.65298992;
+    data[16142] = 7.06935255;
+    data[16143] = 14.94054683;
+    data[16144] = 4.19024844;
+    data[16344] = 1.51483389;
+    data[16345] = 12.00899947;
+    data[16346] = 9.84823331;
+    data[16547] = 6.10224018;
+    data[16548] = 15.33857174;
+    data[16549] = 5.57676842;
+    data[16749] = 0.36827257;
+    data[16750] = 9.89749376;
+    data[16751] = 11.35340426;
+    data[16752] = 2.05122307;
+    data[16952] = 3.89297144;
+    data[16953] = 12.97352277;
+    data[16954] = 8.06631614;
+    data[17155] = 6.74493238;
+    data[17156] = 13.85874674;
+    data[17157] = 5.41190524;
+    data[17357] = 0.74220158;
+    data[17358] = 8.98779090;
+    data[17359] = 11.37871388;
+    data[17360] = 3.32958088;
+    data[17560] = 2.82313535;
+    data[17561] = 10.68049297;
+    data[17562] = 9.43340641;
+    data[17563] = 1.76325557;
+    data[17763] = 4.39018616;
+    data[17764] = 11.87758986;
+    data[17765] = 7.97005836;
+    data[17766] = 0.66104700;
+    data[17966] = 5.49466675;
+    data[17967] = 12.62953598;
+    data[17968] = 6.93987962;
+    data[18169] = 6.18401915;
+    data[18170] = 12.93473132;
+    data[18171] = 6.29778765;
+    data[18371] = 0.02325210;
+    data[18372] = 6.50206627;
+    data[18373] = 12.32661773;
+    data[18374] = 6.00216538;
+    data[18574] = 0.31548753;
+    data[18575] = 6.48925547;
+    data[18576] = 12.04130240;
+    data[18577] = 6.01462880;
+    data[18777] = 0.29979556;
+    data[18778] = 6.18288014;
+    data[18779] = 12.04272825;
+    data[18780] = 6.29981188;
+    data[18781] = 0.55689598;
+    data[18980] = 0.01120471;
+    data[18981] = 5.61729167;
+    data[18982] = 11.22337859;
+    data[18983] = 6.82516303;
+    data[18984] = 1.35264499;
+    data[19184] = 4.82410006;
+    data[19185] = 10.16623247;
+    data[19186] = 7.56075513;
+    data[19187] = 2.34590308;
+    data[19387] = 3.83235747;
+    data[19388] = 8.92296247;
+    data[19389] = 8.47910438;
+    data[19390] = 3.50978645;
+    data[19590] = 2.66873185;
+    data[19591] = 7.51965167;
+    data[19592] = 9.55500547;
+    data[19593] = 4.81966138;
+    data[19594] = 0.08431751;
+    data[19793] = 1.35767367;
+    data[19794] = 5.98019501;
+    data[19795] = 10.60271543;
+    data[19796] = 6.25298498;
+    data[19797] = 1.74059917;
+    data[19997] = 4.32644226;
+    data[19998] = 8.73131864;
+    data[19999] = 7.78916525;
+    data[20000] = 3.48923868;
+    data[20200] = 2.57835095;
+    data[20201] = 6.77582854;
+    data[20202] = 9.40941647;
+    data[20203] = 5.31194592;
+    data[20204] = 1.21447595;
+    data[20403] = 0.75411191;
+    data[20404] = 4.75395704;
+    data[20405] = 8.75380263;
+    data[20406] = 7.19209015;
+    data[20407] = 3.28754401;
+    data[20607] = 2.68179690;
+    data[20608] = 6.49331464;
+    data[20609] = 9.11457930;
+    data[20610] = 5.39387390;
+    data[20611] = 1.67316827;
+    data[20810] = 0.57394296;
+    data[20811] = 4.20600036;
+    data[20812] = 7.83805829;
+    data[20813] = 7.52023002;
+    data[20814] = 3.97470826;
+    data[20815] = 0.42918732;
+    data[21014] = 1.90464477;
+    data[21015] = 5.36569161;
+    data[21016] = 8.82673822;
+    data[21017] = 6.27609482;
+    data[21018] = 2.89750961;
+    data[21218] = 2.89885257;
+    data[21219] = 6.19694078;
+    data[21220] = 8.56699049;
+    data[21221] = 5.34748193;
+    data[21222] = 2.12797290;
+    data[21421] = 0.44750227;
+    data[21422] = 3.59030394;
+    data[21423] = 6.73310598;
+    data[21424] = 7.77023612;
+    data[21425] = 4.70231380;
+    data[21426] = 1.63439126;
+    data[21625] = 1.01536023;
+    data[21626] = 4.01018746;
+    data[21627] = 7.00501446;
+    data[21628] = 7.23442994;
+    data[21629] = 4.31095669;
+    data[21630] = 1.38748321;
+    data[21829] = 1.33348850;
+    data[21830] = 4.18730825;
+    data[21831] = 7.04112789;
+    data[21832] = 6.93188375;
+    data[21833] = 4.14605811;
+    data[21834] = 1.36023236;
+    data[22033] = 1.42879714;
+    data[22034] = 4.14824858;
+    data[22035] = 6.86769979;
+    data[22036] = 6.83705276;
+    data[22037] = 4.18239459;
+    data[22038] = 1.52773573;
+    data[22237] = 1.32610439;
+    data[22238] = 3.91751388;
+    data[22239] = 6.50892360;
+    data[22240] = 6.92639686;
+    data[22241] = 4.39672917;
+    data[22242] = 1.86706171;
+    data[22441] = 1.04827771;
+    data[22442] = 3.51767405;
+    data[22443] = 5.98707050;
+    data[22444] = 7.17824046;
+    data[22445] = 4.76767914;
+    data[22446] = 2.35711760;
+    data[22645] = 0.61636406;
+    data[22646] = 2.96949223;
+    data[22647] = 5.32262027;
+    data[22648] = 7.57265091;
+    data[22649] = 5.27558755;
+    data[22650] = 2.97852419;
+    data[22651] = 0.68146095;
+    data[22849] = 0.04971400;
+    data[22850] = 2.29204819;
+    data[22851] = 4.53438237;
+    data[22852] = 6.77671656;
+    data[22853] = 5.90240723;
+    data[22854] = 3.71349836;
+    data[22855] = 1.52458926;
+    data[23054] = 1.50285335;
+    data[23055] = 3.63961048;
+    data[23056] = 5.77636715;
+    data[23057] = 6.63159089;
+    data[23058] = 4.54574358;
+    data[23059] = 2.45989650;
+    data[23060] = 0.37404924;
+    data[23258] = 0.61795861;
+    data[23259] = 2.65410915;
+    data[23260] = 4.69025923;
+    data[23261] = 6.72641024;
+    data[23262] = 5.46034705;
+    data[23263] = 3.47270933;
+    data[23264] = 1.48507138;
+    data[23463] = 1.59233576;
+    data[23464] = 3.53261665;
+    data[23465] = 5.47289755;
+    data[23466] = 6.44368259;
+    data[23467] = 4.54962999;
+    data[23468] = 2.65557761;
+    data[23469] = 0.76152512;
+    data[23667] = 0.46749352;
+    data[23668] = 2.31641904;
+    data[23669] = 4.16534441;
+    data[23670] = 6.01426978;
+    data[23671] = 5.67844696;
+    data[23672] = 3.87357362;
+    data[23673] = 2.06870004;
+    data[23674] = 0.26382666;
+    data[23872] = 1.05349103;
+    data[23873] = 2.81536230;
+    data[23874] = 4.57723346;
+    data[23875] = 6.33910485;
+    data[23876] = 5.12815686;
+    data[23877] = 3.40826320;
+    data[23878] = 1.68837002;
+    data[24077] = 1.43350090;
+    data[24078] = 3.11241671;
+    data[24079] = 4.79133241;
+    data[24080] = 6.40943693;
+    data[24081] = 4.77052201;
+    data[24082] = 3.13160778;
+    data[24083] = 1.49269309;
+    data[24281] = 0.02932359;
+    data[24282] = 1.62918994;
+    data[24283] = 3.22905602;
+    data[24284] = 4.82892245;
+    data[24285] = 6.14671456;
+    data[24286] = 4.58496623;
+    data[24287] = 3.02321767;
+    data[24288] = 1.46146910;
+    data[24486] = 0.13601698;
+    data[24487] = 1.66055572;
+    data[24488] = 3.18509457;
+    data[24489] = 4.70963307;
+    data[24490] = 6.04072399;
+    data[24491] = 4.55250870;
+    data[24492] = 3.06429295;
+    data[24493] = 1.57607743;
+    data[24494] = 0.08786193;
+    data[24691] = 0.09328097;
+    data[24692] = 1.54603878;
+    data[24693] = 2.99879676;
+    data[24694] = 4.45155473;
+    data[24695] = 5.90431225;
+    data[24696] = 4.65566106;
+    data[24697] = 3.23751615;
+    data[24698] = 1.81937125;
+    data[24699] = 0.40122634;
+    data[24897] = 1.30262633;
+    data[24898] = 2.68698297;
+    data[24899] = 4.07133950;
+    data[24900] = 5.45569602;
+    data[24901] = 4.87832492;
+    data[24902] = 3.52695142;
+    data[24903] = 2.17557792;
+    data[24904] = 0.82420459;
+    data[25102] = 0.94595028;
+    data[25103] = 2.26512621;
+    data[25104] = 3.58430226;
+    data[25105] = 4.90347855;
+    data[25106] = 5.20569785;
+    data[25107] = 3.91795207;
+    data[25108] = 2.63020652;
+    data[25109] = 1.34246063;
+    data[25110] = 0.05471494;
+    data[25307] = 0.49037894;
+    data[25308] = 1.74744334;
+    data[25309] = 3.00450763;
+    data[25310] = 4.26157191;
+    data[25311] = 5.51863620;
+    data[25312] = 4.39707236;
+    data[25313] = 3.16995848;
+    data[25314] = 1.94284460;
+    data[25315] = 0.71573065;
+    data[25513] = 1.14698056;
+    data[25514] = 2.34485767;
+    data[25515] = 3.54273478;
+    data[25516] = 4.74061165;
+    data[25517] = 4.95198462;
+    data[25518] = 3.78264743;
+    data[25519] = 2.61331047;
+    data[25520] = 1.44397374;
+    data[25521] = 0.27463681;
+    data[25718] = 0.47569509;
+    data[25719] = 1.61717169;
+    data[25720] = 2.75864848;
+    data[25721] = 3.90012516;
+    data[25722] = 5.04160160;
+    data[25723] = 4.45712078;
+    data[25724] = 3.34284059;
+    data[25725] = 2.22856039;
+    data[25726] = 1.11428020;
+
+    for (auto & val : data) {
+        val /= 1000.0f;
+    }
+
+    filters.data = std::move(data);
+    return filters;
+}
+
+} // namespace whisper_precalc_filters
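
The pieces in this file compose into a single decode-and-preprocess pipeline. A condensed sketch of the intended call order, using only the helpers defined above (buf/len are assumed to hold the raw bytes of an audio file; error handling trimmed):

    std::vector<float> pcmf32; // mono f32 samples at 16 kHz
    if (!audio_helpers::is_audio_file((const char *) buf, len) ||
        !audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
        return false; // not WAV/MP3/FLAC, or miniaudio failed to decode it
    }
    whisper_preprocessor::whisper_filters filters = whisper_precalc_filters::get_128_bins();
    std::vector<whisper_preprocessor::whisper_mel> chunks; // one entry per 3000-frame chunk
    whisper_preprocessor::preprocess_audio(pcmf32.data(), pcmf32.size(), filters, chunks);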

tools/mtmd/mtmd-audio.h (new file, 62 lines)
@@ -0,0 +1,62 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#define WHISPER_ASSERT GGML_ASSERT
|
||||||
|
|
||||||
|
#define WHISPER_SAMPLE_RATE 16000
|
||||||
|
#define WHISPER_N_FFT 400
|
||||||
|
#define WHISPER_HOP_LENGTH 160
|
||||||
|
#define WHISPER_CHUNK_SIZE 30
|
||||||
|
|
||||||
|
#define COMMON_SAMPLE_RATE 16000
|
||||||
|
|
||||||
|
namespace whisper_preprocessor {
|
||||||
|
|
||||||
|
struct whisper_mel {
|
||||||
|
int n_len;
|
||||||
|
int n_len_org;
|
||||||
|
int n_mel;
|
||||||
|
|
||||||
|
std::vector<float> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct whisper_filters {
|
||||||
|
int32_t n_mel;
|
||||||
|
int32_t n_fft;
|
||||||
|
|
||||||
|
std::vector<float> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
extern bool preprocess_audio(
|
||||||
|
const float * samples,
|
||||||
|
size_t n_samples,
|
||||||
|
const whisper_filters & filters,
|
||||||
|
std::vector<whisper_mel> & output);
|
||||||
|
|
||||||
|
} // namespace whisper_preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
// TODO @ngxson : move this helper to mtmd-helpers.cpp
|
||||||
|
namespace audio_helpers {
|
||||||
|
|
||||||
|
extern bool is_audio_file(const char * buf, size_t len);
|
||||||
|
|
||||||
|
extern bool decode_audio_from_buf(
|
||||||
|
const unsigned char * buf_in,
|
||||||
|
size_t len,
|
||||||
|
int target_sampler_rate,
|
||||||
|
std::vector<float> & pcmf32_mono);
|
||||||
|
|
||||||
|
} // namespace audio_helpers
|
||||||
|
|
||||||
|
|
||||||
|
namespace whisper_precalc_filters {
|
||||||
|
|
||||||
|
extern whisper_preprocessor::whisper_filters get_128_bins();
|
||||||
|
|
||||||
|
} // namespace whisper_precalc_filters
|
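The header above is the entire public surface of the audio front end: raw 16 kHz mono PCM goes in, Whisper-style mel spectrogram chunks come out. A minimal sketch of how a caller might drive it, assuming only the declarations above (the silent buffer and the print loop are illustrative, not part of the commit):

```cpp
#include "mtmd-audio.h"

#include <cstdio>
#include <vector>

int main() {
    // two seconds of silence at the fixed 16 kHz mono rate the preprocessor expects
    std::vector<float> pcmf32(2 * WHISPER_SAMPLE_RATE, 0.0f);

    // precalculated 128-bin mel filter bank (what the ultravox projector uses)
    whisper_preprocessor::whisper_filters filters = whisper_precalc_filters::get_128_bins();

    std::vector<whisper_preprocessor::whisper_mel> mel_chunks;
    if (!whisper_preprocessor::preprocess_audio(pcmf32.data(), pcmf32.size(), filters, mel_chunks)) {
        fprintf(stderr, "audio preprocessing failed\n");
        return 1;
    }

    // each mel chunk later becomes one audio input chunk for the encoder
    for (const auto & mel : mel_chunks) {
        printf("mel chunk: n_len = %d, n_mel = %d\n", mel.n_len, mel.n_mel);
    }
    return 0;
}
```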
@@ -37,10 +37,10 @@ static volatile bool g_is_interrupted = false;
 static void show_additional_info(int /*argc*/, char ** argv) {
     LOG(
         "Experimental CLI for multimodal\n\n"
-        "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
+        "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
         "  -m and --mmproj are required\n"
         "  -hf user/repo can replace both -m and --mmproj in most cases\n"
-        "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
+        "  --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
         "  to disable using GPU for mmproj model, add --no-mmproj-offload\n",
         argv[0]
     );
@@ -142,7 +142,7 @@ struct mtmd_cli_context {
         );
     }

-    bool load_image(const std::string & fname) {
+    bool load_media(const std::string & fname) {
         mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
         if (!bmp.ptr) {
             return false;
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
     common_params params;
     params.sampling.temp = 0.2; // lower temp by default for better quality

-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
         return 1;
     }

@@ -283,14 +283,14 @@ int main(int argc, char ** argv) {

     if (is_single_turn) {
         g_is_generating = true;
-        if (params.prompt.find("<__image__>") == std::string::npos) {
-            params.prompt += " <__image__>";
+        if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
+            params.prompt += mtmd_default_marker();
         }
         common_chat_msg msg;
         msg.role    = "user";
         msg.content = params.prompt;
         for (const auto & image : params.image) {
-            if (!ctx.load_image(image)) {
+            if (!ctx.load_media(image)) {
                 return 1; // error is already printed by libmtmd
             }
         }
@@ -303,7 +303,12 @@ int main(int argc, char ** argv) {

     } else {
         LOG("\n Running in chat mode, available commands:");
-        LOG("\n   /image <path>    load an image");
+        if (mtmd_support_vision(ctx.ctx_vision.get())) {
+            LOG("\n   /image <path>    load an image");
+        }
+        if (mtmd_support_audio(ctx.ctx_vision.get())) {
+            LOG("\n   /audio <path>    load an audio");
+        }
         LOG("\n   /clear           clear the chat history");
         LOG("\n   /quit or /exit   exit the program");
         LOG("\n");
@@ -333,15 +338,17 @@ int main(int argc, char ** argv) {
                 continue;
             }
             g_is_generating = true;
-            if (line == "/image" || line.find("/image ") == 0) {
+            bool is_image = line == "/image" || line.find("/image ") == 0;
+            bool is_audio = line == "/audio" || line.find("/audio ") == 0;
+            if (is_image || is_audio) {
                 if (line.size() < 8) {
-                    LOG_ERR("ERR: Missing image filename\n");
+                    LOG_ERR("ERR: Missing media filename\n");
                     continue;
                 }
-                std::string image = line.substr(7);
-                if (ctx.load_image(image)) {
-                    LOG("Image %s loaded\n", image.c_str());
-                    content += "<__image__>";
+                std::string media_path = line.substr(7);
+                if (ctx.load_media(media_path)) {
+                    LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
+                    content += mtmd_default_marker();
                 }
                 // else, error is already printed by libmtmd
                 continue;
@@ -149,13 +149,10 @@ int32_t mtmd_helper_decode_image_chunk(
         llama_seq_id seq_id,
         int32_t n_batch,
         llama_pos * new_n_past) {
-    if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-        LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
-        return -1;
-    }
-    const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
-    if (!image_tokens) {
-        LOG_ERR("failed to decode image chunk: image tokens are null\n");
+    auto chunk_type = mtmd_input_chunk_get_type(chunk);
+    const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
+    if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
         return -1;
     }
@@ -163,15 +160,23 @@ int32_t mtmd_helper_decode_image_chunk(
     int n_mmproj_embd = llama_model_n_embd(model);
     int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;

-    int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+    int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
     int32_t i_batch = 0;
     int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
     decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);

-    const int nx = mtmd_image_tokens_get_nx(image_tokens);
-    const int ny = mtmd_image_tokens_get_ny(image_tokens);
-
     if (mtmd_decode_use_mrope(ctx)) {
+        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+        if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
+            return -1;
+        }
+        if (!image_tokens) {
+            LOG_ERR("failed to decode chunk: image tokens are null\n");
+            return -1;
+        }
+        const int nx = mtmd_image_tokens_get_nx(image_tokens);
+        const int ny = mtmd_image_tokens_get_ny(image_tokens);
         batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
     } else {
         batch_embd.set_position_normal(n_past, seq_id);
@@ -187,22 +192,22 @@ int32_t mtmd_helper_decode_image_chunk(
         int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
         llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);

-        LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+        LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);

         int64_t t1 = ggml_time_ms();
         int32_t ret = llama_decode(lctx, batch_embd_view);
         if (ret != 0) {
-            LOG_ERR("failed to decode image\n");
+            LOG_ERR("failed to decode %s\n", name);
             llama_set_causal_attn(lctx, true); // restore causal attn
             return ret;
         }

-        LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+        LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);

         i_batch++;
     }

-    n_past += mtmd_image_tokens_get_n_pos(image_tokens);
+    n_past += mtmd_input_chunk_get_n_pos(chunk);
     *new_n_past = n_past;

     if (mtmd_decode_use_non_causal(ctx)) {
@@ -253,25 +258,25 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
             *new_n_past += text_batch.n_tokens;
         }

-    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-        const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+    } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
         int64_t t0 = ggml_time_ms();

-        LOG_INF("encoding image or slice...\n");
+        LOG_INF("encoding %s slice...\n", name);

-        ret = mtmd_encode(ctx, image_tokens);
+        ret = mtmd_encode_chunk(ctx, chunk);
         if (ret != 0) {
-            LOG_ERR("failed to encode image\n");
+            LOG_ERR("failed to encode %s slice\n", name);
             llama_batch_free(text_batch);
             return ret;
         }

-        LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+        LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);

         float * embd = mtmd_get_output_embd(ctx);
         ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
         if (ret != 0) {
-            LOG_ERR("failed to decode image\n");
+            LOG_ERR("failed to decode %s\n", name);
             llama_batch_free(text_batch);
             return ret;
         }
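A small aside on the batching arithmetic above: `GGML_PAD(n_tokens, n_batch) / n_batch` rounds the token count up to a whole number of batches, so the trailing partial batch of a long image or audio chunk still gets decoded. A self-contained sketch of the same idea (the division-based `PAD` below is a local stand-in for ggml's macro, which uses a power-of-two mask; the numbers are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// division-based round-up to a multiple of n; same result as GGML_PAD for these sizes
#define PAD(x, n) ((((x) + (n) - 1) / (n)) * (n))

int main() {
    int32_t n_tokens  = 750; // e.g. tokens of one long audio chunk
    int32_t n_batch   = 512;
    int32_t n_batches = PAD(n_tokens, n_batch) / n_batch;
    printf("%d tokens -> %d batches of up to %d tokens\n", n_tokens, n_batches, n_batch);
    // prints: 750 tokens -> 2 batches of up to 512 tokens
    return 0;
}
```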
@@ -1,6 +1,7 @@
 #include "clip.h"
 #include "clip-impl.h"
 #include "mtmd.h"
+#include "mtmd-audio.h"

 #include "llama.h"

@@ -19,17 +20,49 @@ struct mtmd_bitmap {
     uint32_t ny;
     std::vector<unsigned char> data;
     std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
+    bool is_audio = false; // true if the bitmap is audio
 };

-struct mtmd_image_tokens_deleter {
-    void operator()(mtmd_image_tokens * val); // forward declaration
+struct mtmd_image_tokens {
+    uint32_t nx; // number of tokens in x direction
+    uint32_t ny; // number of tokens in y direction
+    bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
+    uint32_t n_tokens() const { return nx * ny; }
+    clip_image_f32_batch batch_f32; // preprocessed image patches
+    std::string id; // optional user-defined ID, useful for KV cache tracking
+
+    mtmd_image_tokens clone() {
+        return mtmd_image_tokens{
+            nx,
+            ny,
+            use_mrope_pos,
+            batch_f32.clone(),
+            id
+        };
+    }
 };
-using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens>;
+
+struct mtmd_audio_tokens {
+    uint32_t n_tokens; // number of tokens
+    clip_image_f32_batch batch_f32; // preprocessed image patches
+    std::string id; // optional user-defined ID, useful for KV cache tracking
+
+    mtmd_audio_tokens clone() {
+        return mtmd_audio_tokens{
+            n_tokens,
+            batch_f32.clone(),
+            id
+        };
+    }
+};
+using mtmd_audio_tokens_ptr = std::unique_ptr<mtmd_audio_tokens>;

 struct mtmd_input_chunk {
     mtmd_input_chunk_type type;
     std::vector<llama_token> tokens_text;
     mtmd_image_tokens_ptr tokens_image;
+    mtmd_audio_tokens_ptr tokens_audio;
 };

 struct mtmd_input_chunks {
@@ -46,6 +79,10 @@ enum mtmd_slice_tmpl {
     // TODO @ngxson : add support for idefics (SmolVLM)
 };

+const char * mtmd_default_marker() {
+    return "<__media__>";
+}
+
 mtmd_context_params mtmd_context_params_default() {
     mtmd_context_params params;
     params.use_gpu = true;
@@ -53,6 +90,7 @@ mtmd_context_params mtmd_context_params_default() {
     params.n_threads = 4;
     params.verbosity = GGML_LOG_LEVEL_INFO;
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
+    params.media_marker = mtmd_default_marker();
     return params;
 }

@@ -63,7 +101,9 @@ struct mtmd_context {

     bool print_timings;
     int n_threads;
-    std::string image_marker;
+    std::string media_marker;
+    bool has_vision;
+    bool has_audio;

     // for llava-uhd style models, we need special tokens in-between slices
     // minicpmv calls them "slices", llama 4 calls them "tiles"
@@ -81,6 +121,9 @@ struct mtmd_context {

     bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE

+    // for whisper, we pre-calculate the mel filter bank
+    whisper_preprocessor::whisper_filters w_filters;
+
     // TODO @ngxson : add timings

     mtmd_context(const char * mmproj_fname,
@@ -89,8 +132,12 @@ struct mtmd_context {
         text_model   (text_model),
         print_timings(ctx_params.print_timings),
         n_threads    (ctx_params.n_threads),
-        image_marker (ctx_params.image_marker)
+        media_marker (ctx_params.media_marker)
     {
+        if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
+            throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
+        }
+
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu   = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -99,7 +146,9 @@ struct mtmd_context {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }

-        use_mrope = clip_is_qwen2vl(ctx_clip);
+        has_vision = clip_has_vision_encoder(ctx_clip);
+        has_audio  = clip_has_audio_encoder(ctx_clip);
+        use_mrope  = clip_is_qwen2vl(ctx_clip);

         projector_type proj = clip_get_projector_type(ctx_clip);
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
@@ -146,6 +195,21 @@ struct mtmd_context {
             tok_row_end_trail = true; // add trailing end-of-row token
             ov_img_first      = false; // overview image is last
         }
+
+        if (proj == PROJECTOR_TYPE_ULTRAVOX) {
+            // TODO @ngxson : check if model n_mel is 128 or 80
+            w_filters = whisper_precalc_filters::get_128_bins();
+        }
+
+        // warning messages
+        if (proj == PROJECTOR_TYPE_LLAMA4) {
+            LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
+                    "    https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+        }
+        if (has_audio) {
+            LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
+                    "    https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
+        }
     }

     ~mtmd_context() {
@@ -179,29 +243,6 @@ private:
     }
 };

-struct mtmd_image_tokens_data {
-    clip_image_f32_batch batch_f32; // preprocessed image patches
-};
-
-struct mtmd_image_tokens {
-    uint32_t nx; // number of tokens in x direction
-    uint32_t ny; // number of tokens in y direction
-    bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
-    uint32_t n_tokens() const { return nx * ny; }
-    clip_image_f32_batch batch_f32; // preprocessed image patches
-    std::string id; // optional user-defined ID, useful for KV cache tracking
-
-    mtmd_image_tokens clone() {
-        return mtmd_image_tokens{
-            nx,
-            ny,
-            use_mrope_pos,
-            batch_f32.clone(),
-            id
-        };
-    }
-};
-
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
         const struct llama_model * text_model,
         const struct mtmd_context_params ctx_params) {
@@ -247,59 +288,63 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
     auto vocab = llama_model_get_vocab(ctx->text_model);

     std::string prompt_modified(text->text);
-    std::string marker_modified(ctx->image_marker);
+    std::string marker_modified(ctx->media_marker);
     projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);

+    // for compatibility, we convert image marker to media marker
+    string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
+
     // a bit hacky here, but works for now
     // for some models, we need to add prefix and suffix to the image embeddings
     if (clip_is_gemma3(ctx->ctx_clip)) {
         // gemma 3
         // <start_of_image> ... (image embeddings) ... <end_of_image>
-        marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
         // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
-        marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     } else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
         // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
-        marker_modified = ctx->image_marker + "[IMG_END]";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = ctx->media_marker + "[IMG_END]";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
         // <|vision_start|> ... (image embeddings) ... <|vision_end|>
-        marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
         // (more details in mtmd_context constructor)
-        marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
         // <img> ... (image embeddings) ... </img>
-        marker_modified = "<img>" + ctx->image_marker + "</img>";
-        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+        marker_modified = "<img>" + ctx->media_marker + "</img>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);

     }

     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
     // for glm-edge, BOI and EOI token's embeddings are not present in the text model

-    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
+    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
     output->entries.clear();
     output->entries.reserve(parts.size());

-    size_t i_img = 0;
+    size_t i_bm = 0;

     // utility for adding raw tokens
     auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
         mtmd_input_chunk chunk{
             MTMD_INPUT_CHUNK_TYPE_TEXT,
             std::move(tokens),
-            {},
+            nullptr, // image tokens
+            nullptr, // audio tokens
         };
         output->entries.emplace_back(std::move(chunk));
     };
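The effect of the per-projector branches above is easiest to see on a concrete prompt. A stand-alone sketch (the `replace_all` helper is a local stand-in for `string_replace_all` from common.h, shown here only for illustration):

```cpp
#include <iostream>
#include <string>

// local stand-in for common.h's string_replace_all
static void replace_all(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

int main() {
    std::string prompt = "describe this: <__media__>";

    // Qwen2-VL branch: wrap the marker in vision delimiters before splitting
    replace_all(prompt, "<__media__>", "<|vision_start|><__media__><|vision_end|>");
    std::cout << prompt << "\n";
    // prints: describe this: <|vision_start|><__media__><|vision_end|>
    // splitting on "<__media__>" then yields the text prefix, the media slot, and the suffix
    return 0;
}
```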
@@ -317,8 +362,9 @@ int32_t mtmd_tokenize(mtmd_context * ctx,

             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_IMAGE,
-                {},
+                {}, // text tokens
                 std::move(image_tokens),
+                nullptr, // audio tokens
             };
             chunks.emplace_back(std::move(chunk));
         }
@@ -336,24 +382,36 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         mtmd_input_chunk chunk{
             MTMD_INPUT_CHUNK_TYPE_TEXT,
             std::move(tokens),
-            {},
+            nullptr, // image tokens
+            nullptr, // audio tokens
         };
         output->entries.emplace_back(std::move(chunk));

-        if (&parts.back() != &part) {
-            // add image token to middle of 2 parts
+        // only add image/audio tokens to middle of 2 parts
+        // therefore, we skip handling image/audio if this is the last part
+        if (&parts.back() == &part) {
+            continue;
+        }

-            if (i_img >= n_bitmaps) {
+        if (!bitmaps[i_bm]->is_audio) {
+            // handle image
+
+            if (i_bm >= n_bitmaps) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
                 return 1;
             }

+            if (!ctx->has_vision) {
+                LOG_ERR("%s: error: model does not support vision input\n", __func__);
+                return 2;
+            }
+
             // convert mtmd_bitmap to clip_image_u8
             clip_image_u8_ptr img_u8(clip_image_u8_init());
-            img_u8->nx = bitmaps[i_img]->nx;
-            img_u8->ny = bitmaps[i_img]->ny;
-            img_u8->buf.resize(bitmaps[i_img]->data.size());
-            std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
+            img_u8->nx = bitmaps[i_bm]->nx;
+            img_u8->ny = bitmaps[i_bm]->ny;
+            img_u8->buf.resize(bitmaps[i_bm]->data.size());
+            std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);

             // preprocess image
             clip_image_f32_batch batch_f32;
|
|||||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
|
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
|
||||||
) {
|
) {
|
||||||
// split batch into chunks of single images
|
// split batch into chunks of single images
|
||||||
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
|
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
|
||||||
GGML_ASSERT(chunks.size() > 0);
|
GGML_ASSERT(chunks.size() > 0);
|
||||||
|
|
||||||
auto ov_chunk = std::move(chunks.front());
|
auto ov_chunk = std::move(chunks.front());
|
||||||
@@ -446,7 +504,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 image_tokens->ny = 1;
             }
             image_tokens->batch_f32 = std::move(batch_f32);
-            image_tokens->id = bitmaps[i_img]->id; // optional
+            image_tokens->id = bitmaps[i_bm]->id; // optional

             LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
             LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
@@ -454,23 +512,101 @@ int32_t mtmd_tokenize(mtmd_context * ctx,

             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_IMAGE,
-                {},
+                {}, // text tokens
                 std::move(image_tokens),
+                nullptr, // audio tokens
             };
             output->entries.emplace_back(std::move(chunk));
         }

-        i_img++; // move to next image
+            i_bm++; // move to next image
+            continue;
+
+        } else {
+            // handle audio
+
+            if (i_bm >= n_bitmaps) {
+                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
+                return 1;
+            }
+
+            if (!ctx->has_audio) {
+                LOG_ERR("%s: error: model does not support audio input\n", __func__);
+                return 2;
+            }
+
+            if (bitmaps[i_bm]->data.size() == 0) {
+                LOG_ERR("%s: error: empty audio data\n", __func__);
+                return 2;
+            }
+
+            // preprocess audio
+            GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
+            std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
+            const float * samples = (const float *)bitmaps[i_bm]->data.data();
+            size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
+            bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
+            if (!ok) {
+                LOG_ERR("Unable to preprocess audio\n");
+                return 2;
+            }
+
+            // consider each mel_spec as a separate audio chunk
+            // TODO: maybe support batching, but this may come with memory cost
+            for (auto & mel_spec : mel_spec_chunks) {
+                clip_image_f32_ptr mel_f32(clip_image_f32_init());
+                mel_f32->nx = mel_spec.n_len;
+                mel_f32->ny = mel_spec.n_mel;
+                mel_f32->buf = std::move(mel_spec.data);
+                size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
+
+                clip_image_f32_batch batch_f32;
+                batch_f32.is_audio = true;
+                batch_f32.entries.push_back(std::move(mel_f32));
+
+                mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
+                audio_tokens->n_tokens = n_tokens;
+                audio_tokens->batch_f32 = std::move(batch_f32);
+                audio_tokens->id = bitmaps[i_bm]->id; // optional
+
+                LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
+
+                mtmd_input_chunk chunk{
+                    MTMD_INPUT_CHUNK_TYPE_AUDIO,
+                    {}, // text tokens
+                    nullptr, // image tokens
+                    std::move(audio_tokens),
+                };
+                output->entries.emplace_back(std::move(chunk));
+            }
+
+            i_bm++;
+            continue;
+        }
     }

     return 0;
 }

-static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
-    if (image_tokens) {
-        delete image_tokens;
+int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
+        return 0;
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        return mtmd_encode(ctx, chunk->tokens_image.get());
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
+        ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
+        bool ok = clip_image_batch_encode(
+            ctx->ctx_clip,
+            ctx->n_threads,
+            &chunk->tokens_audio->batch_f32,
+            ctx->image_embd_v.data());
+        return ok ? 0 : 1;
     }
+
+    LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
+    return 1;
 }

 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
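`mtmd_encode_chunk` folds the old per-type branching into libmtmd itself, so callers can treat image and audio chunks uniformly. A hedged sketch of the caller side (assumes `ctx` and a tokenized `chunks` list already exist; error handling trimmed):

```cpp
// encode every non-text chunk and fetch its embeddings
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
    const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
    if (mtmd_input_chunk_get_type(chunk) == MTMD_INPUT_CHUNK_TYPE_TEXT) {
        continue; // text chunks go straight to llama_decode as tokens
    }
    if (mtmd_encode_chunk(ctx, chunk) != 0) {
        break; // encode failed; error already logged by libmtmd
    }
    // n_tokens * n_embd floats, valid until the next encode pass
    float * embd = mtmd_get_output_embd(ctx);
    (void) embd; // feed into an embedding batch for llama_decode
}
```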
@@ -516,8 +652,12 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
     return ctx->use_mrope;
 }

-void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
-    mtmd_image_tokens_free(val);
+bool mtmd_support_vision(mtmd_context * ctx) {
+    return ctx->has_vision;
+}
+
+bool mtmd_support_audio(mtmd_context * ctx) {
+    return ctx->has_audio;
 }

 // these 2 helpers below use internal clip_image_u8_ptr,
@@ -526,6 +666,15 @@ void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
 // whichever library they want, and then use mtmd_bitmap_init() to create bitmap

 mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
+    if (audio_helpers::is_audio_file((const char *)buf, len)) {
+        std::vector<float> pcmf32;
+        if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
+            LOG_ERR("Unable to read WAV audio file from buffer\n");
+            return nullptr;
+        }
+        return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
+    }
+
     clip_image_u8_ptr img_u8(clip_image_u8_init());
     bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
     if (!ok) {
@@ -538,15 +687,26 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t
 }

 mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
-    clip_image_u8_ptr img_u8(clip_image_u8_init());
-    bool ok = clip_image_load_from_file(fname, img_u8.get());
-    if (!ok) {
-        LOG_ERR("Unable to load image %s\n", fname);
+    std::vector<unsigned char> buf;
+    FILE * f = fopen(fname, "rb");
+    if (!f) {
+        LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
         return nullptr;
     }
-    uint32_t nx, ny;
-    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
-    return mtmd_bitmap_init(nx, ny, data);
+
+    fseek(f, 0, SEEK_END);
+    long file_size = ftell(f);
+    fseek(f, 0, SEEK_SET);
+    buf.resize(file_size);
+
+    size_t n_read = fread(buf.data(), 1, file_size, f);
+    fclose(f);
+    if (n_read != (size_t)file_size) {
+        LOG_ERR("Failed to read entire file %s", fname);
+        return nullptr;
+    }
+
+    return mtmd_helper_bitmap_init_from_buf(buf.data(), buf.size());
 }

 //
@@ -567,6 +727,18 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
     return bitmap;
 }

+mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
+                                          const float * data) {
+    mtmd_bitmap * bitmap = new mtmd_bitmap;
+    bitmap->nx = n_samples;
+    bitmap->ny = 1;
+    bitmap->is_audio = true;
+    size_t data_size = n_samples * sizeof(float);
+    bitmap->data.resize(data_size);
+    std::memcpy(bitmap->data.data(), data, data_size);
+    return bitmap;
+}
+
 uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
     return bitmap->nx;
 }
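Audio reuses the bitmap container, with `nx` holding the sample count and `ny` fixed at 1. A minimal sketch of building one from PCM you decoded yourself (the silent buffer and the ID string are illustrative):

```cpp
#include "mtmd.h"

#include <vector>

int main() {
    // one second of 16 kHz mono PCM F32; silence here for brevity
    std::vector<float> pcmf32(16000, 0.0f);

    mtmd_bitmap * bmp = mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
    if (!bmp) {
        return 1;
    }

    bool is_audio = mtmd_bitmap_is_audio(bmp); // true for this bitmap
    mtmd_bitmap_set_id(bmp, "clip_1");         // optional, useful for KV cache tracking
    mtmd_bitmap_free(bmp);
    return is_audio ? 0 : 1;
}
```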
@@ -579,6 +751,10 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
     return bitmap->data.data();
 }

+bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
+    return bitmap->is_audio;
+}
+
 const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
     return bitmap->id.c_str();
 }
@@ -642,17 +818,56 @@ const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chu
     return nullptr;
 }

+size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        return chunk->tokens_text.size();
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        return mtmd_image_tokens_get_n_tokens(chunk->tokens_image.get());
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        return chunk->tokens_audio->n_tokens;
+    } else {
+        GGML_ABORT("invalid chunk type");
+    }
+}
+
+llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+        return chunk->tokens_text.size();
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        return mtmd_image_tokens_get_n_pos(chunk->tokens_image.get());
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        return chunk->tokens_audio->n_tokens;
+    } else {
+        GGML_ABORT("invalid chunk type");
+    }
+}
+
+const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk) {
+    if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        return chunk->tokens_image->id.c_str();
+    } else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
+        return chunk->tokens_audio->id.c_str();
+    }
+    return nullptr;
+}
+
 mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
     mtmd_input_chunk * copy = new mtmd_input_chunk{
         chunk->type,
         chunk->tokens_text,
-        mtmd_image_tokens_ptr(),
+        nullptr,
+        nullptr,
     };
     if (chunk->tokens_image) {
         // copy the image tokens
         copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
         *copy->tokens_image = chunk->tokens_image->clone();
     }
+    if (chunk->tokens_audio) {
+        // copy the audio tokens
+        copy->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
+        *copy->tokens_audio = chunk->tokens_audio->clone();
+    }
     return copy;
 }
@@ -700,7 +915,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
     mtmd_input_chunk chunk_text{
         MTMD_INPUT_CHUNK_TYPE_TEXT,
         std::move(tokens_text),
-        {},
+        nullptr, // image tokens
+        nullptr, // audio tokens
     };
     chunks->entries.emplace_back(std::move(chunk_text));

@@ -712,8 +928,9 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
     image_tokens->id = "image_1";
     mtmd_input_chunk chunk_image{
         MTMD_INPUT_CHUNK_TYPE_IMAGE,
-        {},
+        {}, // text tokens
         std::move(image_tokens),
+        nullptr, // audio tokens
     };
     chunks->entries.emplace_back(std::move(chunk_image));

@@ -39,6 +39,7 @@
 #    define MTMD_API
 #endif

+// deprecated marker, use mtmd_default_marker() instead
 #define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"

 #ifdef __cplusplus
|
|||||||
enum mtmd_input_chunk_type {
|
enum mtmd_input_chunk_type {
|
||||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||||
|
MTMD_INPUT_CHUNK_TYPE_AUDIO,
|
||||||
};
|
};
|
||||||
|
|
||||||
// opaque types
|
// opaque types
|
||||||
@@ -79,9 +81,12 @@ struct mtmd_context_params {
     bool print_timings;
     int n_threads;
     enum ggml_log_level verbosity;
-    const char * image_marker;
+    const char * image_marker; // deprecated, use media_marker instead
+    const char * media_marker;
 };

+MTMD_API const char * mtmd_default_marker(void);
+
 MTMD_API struct mtmd_context_params mtmd_context_params_default(void);

 // initialize the mtmd context
@@ -98,17 +103,26 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
 // whether the current model use M-RoPE for llama_decode
 MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);

+// whether the current model supports vision input
+MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
+
+// whether the current model supports audio input
+MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
+
 // mtmd_bitmap
 //
-// length of data must be nx * ny * 3
-// the data is in RGBRGBRGB... format
-MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx,
-                                         uint32_t ny,
-                                         const unsigned char * data);
+// if bitmap is image:
+//     length of data must be nx * ny * 3
+//     the data is in RGBRGBRGB... format
+// if bitmap is audio:
+//     length of data must be n_samples * sizeof(float)
+//     the data is in float format (PCM F32)
+MTMD_API mtmd_bitmap *           mtmd_bitmap_init           (uint32_t nx, uint32_t ny, const unsigned char * data);
+MTMD_API mtmd_bitmap *           mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
 MTMD_API uint32_t                mtmd_bitmap_get_nx  (const mtmd_bitmap * bitmap);
 MTMD_API uint32_t                mtmd_bitmap_get_ny  (const mtmd_bitmap * bitmap);
 MTMD_API const unsigned char *   mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
+MTMD_API bool                    mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap);
 MTMD_API void                    mtmd_bitmap_free    (mtmd_bitmap * bitmap);
 // bitmap ID is optional, but useful for KV cache tracking
 // these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
@@ -132,6 +146,11 @@ MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chu
 MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type        (const mtmd_input_chunk * chunk);
 MTMD_API const llama_token *        mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
 MTMD_API const mtmd_image_tokens *  mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
+MTMD_API size_t                     mtmd_input_chunk_get_n_tokens    (const mtmd_input_chunk * chunk);
+// returns nullptr for ID on text chunk
+MTMD_API const char *               mtmd_input_chunk_get_id          (const mtmd_input_chunk * chunk);
+// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+MTMD_API llama_pos                  mtmd_input_chunk_get_n_pos       (const mtmd_input_chunk * chunk);

 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
 // you can move the chunk ownership to your own code by copying it
@@ -144,27 +163,28 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
 //
 // the instance will be constructed via mtmd_tokenize()
 // it will be freed along with mtmd_input_chunk
-MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
 MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
-MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens);
+MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
-MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens);
+MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens); // TODO: deprecate

-// tokenize an input text prompt and an image
-// the prompt must have the input image marker (default: "<__image__>") in it
-// the marker will be replaced with the image tokens
+// tokenize an input text prompt and a list of bitmaps (images/audio)
+// the prompt must have the input image marker (default: "<__media__>") in it
+// the default marker is defined by mtmd_default_marker()
+// the marker will be replaced with the image/audio chunk
 // for example:
-//   "here is an image: <__image__>\ndescribe it in detail."
+//   "here is an image: <__media__>\ndescribe it in detail."
 //   this will gives 3 chunks:
 //   1. "here is an image: <start_of_image>"
-//   2. (image tokens)
+//   2. (image/audio tokens)
 //   3. "<end_of_image>\ndescribe it in detail."
-// number of bitmaps must be equal to the number of image markers in the prompt
+// number of bitmaps must be equal to the number of markers in the prompt
 // this function is thread-safe (shared ctx)
 // return values:
 //   0 on success
-//   1 on number of images not matching the number of markers
+//   1 on number of bitmaps not matching the number of markers
 //   2 on image preprocessing error
 MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
                                mtmd_input_chunks * output,
|
|||||||
size_t n_bitmaps);
|
size_t n_bitmaps);
|
||||||
|
|
||||||
// returns 0 on success
|
// returns 0 on success
|
||||||
|
// TODO: deprecate
|
||||||
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
||||||
const mtmd_image_tokens * image_tokens);
|
const mtmd_image_tokens * image_tokens);
|
||||||
|
|
||||||
|
// returns 0 on success
|
||||||
|
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||||
|
const mtmd_input_chunk * chunk);
|
||||||
|
|
||||||
// get output embeddings from the last encode pass
|
// get output embeddings from the last encode pass
|
||||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||||
|
|
||||||
@ -189,12 +214,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
|||||||
//
|
//
|
||||||
|
|
||||||
// helper function to construct a mtmd_bitmap from a file
|
// helper function to construct a mtmd_bitmap from a file
|
||||||
|
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||||
// returns nullptr on failure
|
// returns nullptr on failure
|
||||||
// this function is thread-safe
|
// this function is thread-safe
|
||||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
|
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
|
||||||
|
|
||||||
// helper function to construct a mtmd_bitmap from a buffer containing a file
|
// helper function to construct a mtmd_bitmap from a buffer containing a file
|
||||||
// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
|
// supported formats:
|
||||||
|
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
|
||||||
|
// audio: formats supported by miniaudio: wav, mp3, flac
|
||||||
|
// note: audio files will be auto-detected based on magic bytes
|
||||||
// returns nullptr on failure
|
// returns nullptr on failure
|
||||||
// this function is thread-safe
|
// this function is thread-safe
|
||||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
|
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
|
||||||
|
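The two helpers compose: `mtmd_helper_bitmap_init_from_file()` reads the raw bytes and defers to `mtmd_helper_bitmap_init_from_buf()`, which sniffs the magic bytes to route between stb_image and miniaudio. A short usage sketch (the file name is illustrative):

```cpp
mtmd_bitmap * bmp = mtmd_helper_bitmap_init_from_file("sample.mp3");
if (bmp == nullptr) {
    // unreadable file or unsupported format; error already logged
    return 1;
}
if (mtmd_bitmap_is_audio(bmp)) {
    // for audio, nx holds the decoded sample count (16 kHz mono PCM F32)
    printf("decoded %u samples\n", mtmd_bitmap_get_nx(bmp));
}
mtmd_bitmap_free(bmp);
```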
@@ -710,7 +710,7 @@ static json oaicompat_completion_params_parse(

                 // replace this chunk with a marker
                 p["type"] = "text";
-                p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
+                p["text"] = mtmd_default_marker();
                 p.erase("image_url");
             }
         }