Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-27 03:55:20 +00:00)

mtmd : add support for Qwen2-Audio and SeaLLM-Audio (#13760)

* mtmd : add Qwen2-Audio support
* small clean up
* update discussion link
* clarify mtmd_get_output_embd
* clarification in multimodal.md
* fix ultravox bug
* ggml_cont

convert_hf_to_gguf.py:

```diff
@@ -2643,7 +2643,7 @@ class QwenModel(TextModel):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
@@ -2667,8 +2667,9 @@ class Qwen2Model(TextModel):
             name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
         if "language_model." in name:
             name = name.replace("language_model.", "")  # for InternVL
-        if name.startswith("mlp") or name.startswith("vision_model"):
-            # skip visual tensors
+        if name.startswith("mlp") or name.startswith("multi_modal_projector") \
+                or name.startswith("vision_model") or name.startswith("audio_tower"):
+            # skip vision and audio tensors
             return []
         yield from super().modify_tensors(data_torch, name, bid)
 
@@ -5993,11 +5994,11 @@ class UltravoxModel(TextModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
+        raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
 
 
-@ModelBase.register("UltravoxModel")
-class UltravoxAudioModel(MmprojModel):
+@ModelBase.register("Qwen2AudioForConditionalGeneration")
+class WhisperEncoderModel(MmprojModel):
     has_vision_encoder = False  # no vision encoder
     has_audio_encoder = True
 
@@ -6009,10 +6010,9 @@ class UltravoxAudioModel(MmprojModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A)
         self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
         self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
-        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims  # unused
@@ -6023,6 +6023,10 @@ class UltravoxAudioModel(MmprojModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
+        if name.startswith("language_model."):
+            # skip language model tensors
+            return []
+
         # prevent clash naming with vision tensors
         if name.startswith("multi_modal_projector"):
             name = "audio." + name
@@ -6033,6 +6037,16 @@ class UltravoxAudioModel(MmprojModel):
 
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("UltravoxModel")
+class UltravoxWhisperEncoderModel(WhisperEncoderModel):
+    has_vision_encoder = False  # no vision encoder
+    has_audio_encoder = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
+
+
 ###### CONVERSION LOGIC ######
 
```

docs/multimodal.md:

````diff
@@ -93,4 +93,8 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # Ultravox 0.5
 (tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
 (tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
+
+# Qwen2-Audio and SeaLLM-Audio
+# note: no pre-quantized GGUF for this model, as quantized versions give very poor results
+# ref: https://github.com/ggml-org/llama.cpp/pull/13760
 ```
````
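There is no pre-quantized GGUF published for Qwen2-Audio or SeaLLM-Audio, so they have to be converted locally before the commands above can be adapted to them. A hedged sketch of that flow, assuming the usual `convert_hf_to_gguf.py` options (check `--help` for the exact flags): run the script on the downloaded checkpoint to produce the text-model GGUF, run it again with `--mmproj` to produce the audio-encoder GGUF, then pass the two files to the tool with `-m` and `--mmproj` instead of `-hf`.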

gguf-py/gguf/constants.py:

```diff
@@ -546,6 +546,7 @@ class MODEL_TENSOR(IntEnum):
     A_ENC_FFN_GATE = auto()
     A_ENC_FFN_DOWN = auto()
     A_MMPROJ = auto()
+    A_MMPROJ_FC = auto()
     A_MM_NORM_PRE = auto()
     A_MM_NORM_MID = auto()
 
@@ -825,6 +826,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
     MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
     MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
+    MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
     MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
 }
@@ -885,6 +887,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.A_ENC_FFN_GATE,
         MODEL_TENSOR.A_ENC_FFN_DOWN,
         MODEL_TENSOR.A_MMPROJ,
+        MODEL_TENSOR.A_MMPROJ_FC,
         MODEL_TENSOR.A_MM_NORM_PRE,
         MODEL_TENSOR.A_MM_NORM_MID,
     ],
@@ -2256,6 +2259,7 @@ class VisionProjectorType:
     QWEN25VL = "qwen2.5vl_merger"
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
+    QWEN2A = "qwen2a"  # audio
 
 
 # Items here are (block size, type size)
```

gguf-py/gguf/tensor_mapping.py:

```diff
@@ -1165,6 +1165,10 @@ class TensorNameMap:
             "audio.multi_modal_projector.linear_{bid}",  # ultravox
         ),
 
+        MODEL_TENSOR.A_MMPROJ_FC: (
+            "audio.multi_modal_projector.linear",  # qwen2audio
+        ),
+
         MODEL_TENSOR.A_MM_NORM_PRE: (
             "audio.multi_modal_projector.ln_pre",  # ultravox
         ),
```

tools/mtmd/clip-impl.h:

```diff
@@ -107,6 +107,7 @@
 // ultravox
 #define TN_CONV1D        "a.conv1d.%d.%s"
 #define TN_MM_AUDIO_MLP  "mm.a.mlp.%d.%s"
+#define TN_MM_AUDIO_FC   "mm.a.fc.%s" // fully connected layer
 #define TN_MM_NORM_PRE   "mm.a.norm_pre.%s"
 #define TN_MM_NORM_MID   "mm.a.norm_mid.%s"
 
@@ -128,6 +129,7 @@ enum projector_type {
     PROJECTOR_TYPE_ULTRAVOX,
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
+    PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -145,6 +147,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
+    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
```

tools/mtmd/clip.cpp:

```diff
@@ -254,7 +254,9 @@ struct clip_vision_model {
     ggml_tensor * post_ln_w;
     ggml_tensor * post_ln_b;
 
-    ggml_tensor * projection;
+    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * mm_fc_w;
+    ggml_tensor * mm_fc_b;
 
     // LLaVA projection
     ggml_tensor * mm_input_norm_w = nullptr;
@@ -1471,48 +1473,58 @@ struct clip_graph {
 
         cb(cur, "after_transformer", -1);
 
-        // StackAudioFrames
-        // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
-        {
-            int64_t stride = n_embd * hparams.proj_stack_factor;
-            int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
-            int64_t pad = padded_len - ggml_nelements(cur);
-            if (pad > 0) {
-                cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
-                cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
-            }
-            cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
-                               ggml_row_size(cur->type, stride), 0);
-        }
-
-        cb(cur, "after_stacked", -1);
-
-        // UltravoxProjector
-        {
-            // pre-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
-
-            // ffn in
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-
-            // swiglu
-            {
-                int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-                x1 = ggml_silu(ctx0, x1);
-                cur = ggml_mul(ctx0, x0, x1);
-            }
-
-            // mid-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
-
-            // ffn out
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+        if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
+            // StackAudioFrames
+            // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
+            {
+                int64_t stride = n_embd * hparams.proj_stack_factor;
+                int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
+                int64_t pad = padded_len - ggml_nelements(cur);
+                if (pad > 0) {
+                    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
+                    cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+                }
+                cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                                   ggml_row_size(cur->type, stride), 0);
+            }
+
+            cb(cur, "after_stacked", -1);
+
+            // UltravoxProjector
+            {
+                // pre-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+
+                // ffn in
+                cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+
+                // swiglu
+                {
+                    int64_t split_point = cur->ne[0] / 2;
+                    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+                    x1 = ggml_silu(ctx0, x1);
+                    cur = ggml_mul(ctx0, x0, x1);
+                }
+
+                // mid-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+
+                // ffn out
+                cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            }
+
+        } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+            // projector
+            cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_fc_b);
+
+        } else {
+            GGML_ABORT("%s: unknown projector type", __func__);
         }
 
         cb(cur, "projected", -1);
```
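The StackAudioFrames block above concatenates groups of `proj_stack_factor` whisper-encoder frames into single rows, zero-padding so the element count divides evenly. A minimal sketch of that shape arithmetic, independent of ggml (the concrete numbers in `main` are assumptions, not values read from any model):

```cpp
#include <cstdint>
#include <cstdio>

// Rows fed to the Ultravox projector after stacking n_frames encoder outputs of
// width n_embd with the given stack factor (mirrors the GGML_PAD + ggml_view_2d above).
int64_t stacked_rows(int64_t n_frames, int64_t n_embd, int64_t stack_factor) {
    const int64_t stride     = n_embd * stack_factor;                      // elements per stacked row
    const int64_t n_elems    = n_frames * n_embd;                          // elements before padding
    const int64_t padded_len = ((n_elems + stride - 1) / stride) * stride; // round up, like GGML_PAD
    return padded_len / stride;
}

int main() {
    // e.g. 1500 frames, n_embd = 1280, stack_factor = 8 -> 188 rows (illustrative values)
    printf("%lld\n", (long long) stacked_rows(1500, 1280, 8));
    return 0;
}
```

The Qwen2-Audio branch skips the stacking entirely and only applies a fully connected projector, which is why `proj_stack_factor` becomes optional for it in the loader hunk further down.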

tools/mtmd/clip.cpp (continued):

```diff
@@ -1655,6 +1667,17 @@ private:
             inpL = cur;
         }
 
+        // TODO @ngxson : find a way to move this outside
+        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+            ggml_tensor * cur = inpL;
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            inpL = cur;
+        }
+
         // post-layernorm
         if (model.post_ln_w) {
             inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
@@ -1952,6 +1975,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_llama4();
             } break;
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_QWEN2A:
             {
                 res = graph.build_whisper_enc();
             } break;
@@ -2186,8 +2210,10 @@ struct clip_model_loader {
                     };
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
+            case PROJECTOR_TYPE_QWEN2A:
                 {
-                    get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
+                    bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                    get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
                     if (hparams.n_mel_bins != 128) {
                         throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
                     }
@@ -2266,7 +2292,7 @@ struct clip_model_loader {
             return cur;
         };
 
-        auto & vision_model = ctx_clip.vision_model;
+        auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
 
         vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
 
@@ -2463,6 +2489,15 @@ struct clip_model_loader {
                     vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
                     vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
                 } break;
+            case PROJECTOR_TYPE_QWEN2A:
+                {
+                    vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+                    vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+                } break;
             case PROJECTOR_TYPE_INTERNVL:
                 {
                     vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -3450,6 +3485,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
         const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
         n_patches = n_len / proj_stack_factor / 2;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+        // divide by 2 because of whisper
+        // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+        n_patches = img->nx / 4;
     }
 
     return n_patches;
```
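The comments in `clip_n_output_tokens` compress two different downsampling paths into one-liners. A small sketch of the same arithmetic written out (the frame count in `main` is an illustrative assumption):

```cpp
#include <cstdio>

// Ultravox: align the frame count up to proj_stack_factor, then divide by the stack
// factor and by 2 (whisper conv stride), mirroring CLIP_ALIGN in the hunk above.
int audio_tokens_ultravox(int n_frames, int stack_factor) {
    const int aligned = ((n_frames + stack_factor - 1) / stack_factor) * stack_factor;
    return aligned / stack_factor / 2;
}

// Qwen2-Audio: divide by 2 for whisper, and by 2 again for nn.AvgPool1d(2, stride=2).
int audio_tokens_qwen2a(int n_frames) {
    return n_frames / 4;
}

int main() {
    printf("ultravox: %d, qwen2a: %d\n", audio_tokens_ultravox(3000, 8), audio_tokens_qwen2a(3000));
    return 0;
}
```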

tools/mtmd/clip.cpp (continued):

```diff
@@ -3850,6 +3889,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_ULTRAVOX:
             {
                 // do nothing
@@ -3910,7 +3950,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int n_tokens_out = embeddings->ne[1];
     const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
     if (n_tokens_out != expected_n_tokens_out) {
-        LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
         GGML_ABORT("Invalid number of output tokens");
     }
 
@@ -3955,6 +3995,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_3_w->ne[1];
         case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.mm_model_proj->ne[1];
+        case PROJECTOR_TYPE_QWEN2A:
+            return ctx->vision_model.mm_fc_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
@@ -3991,6 +4033,10 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.has_audio;
 }
 
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
+    return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
+}
+
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);
```

tools/mtmd/clip.h:

```diff
@@ -4,6 +4,8 @@
 #include <stddef.h>
 #include <stdint.h>
 
+// !!! Internal header, to be used by mtmd only !!!
+
 struct clip_ctx;
 
 struct clip_image_size {
@@ -99,3 +101,4 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
 
 bool clip_has_vision_encoder(const struct clip_ctx * ctx);
 bool clip_has_audio_encoder(const struct clip_ctx * ctx);
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
```

tools/mtmd/mtmd.cpp:

```diff
@@ -146,6 +146,13 @@ struct mtmd_context {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
 
+        if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
+            throw std::runtime_error(string_format(
+                "mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
+                "hint: you may be using wrong mmproj\n",
+                llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
+        }
+
         has_vision = clip_has_vision_encoder(ctx_clip);
         has_audio  = clip_has_audio_encoder(ctx_clip);
         use_mrope  = clip_is_qwen2vl(ctx_clip);
@@ -196,7 +203,7 @@ struct mtmd_context {
             ov_img_first = false; // overview image is last
         }
 
-        if (proj == PROJECTOR_TYPE_ULTRAVOX) {
+        if (clip_has_whisper_encoder(ctx_clip)) {
             // TODO @ngxson : check if model n_mel is 128 or 80
             w_filters = whisper_precalc_filters::get_128_bins();
         }
@@ -208,7 +215,7 @@ struct mtmd_context {
         }
         if (has_audio) {
             LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
-                    "    https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
+                    "    https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
         }
     }
 
@@ -327,6 +334,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             marker_modified = "<img>" + ctx->media_marker + "</img>";
             string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
 
+        } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
+            // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
+            marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
+            string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
+
         }
 
         // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
```
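The Qwen2-Audio branch added to `mtmd_tokenize` only rewrites the media marker so that the audio embeddings end up wrapped in `<|audio_bos|>` / `<|audio_eos|>` tokens. A minimal sketch of that rewrite (the marker value is a hypothetical stand-in for `ctx->media_marker`):

```cpp
#include <cassert>
#include <string>

int main() {
    const std::string marker   = "<__media__>"; // assumed stand-in for ctx->media_marker
    std::string prompt         = "What is being said? " + marker;
    const std::string expanded = "<|audio_bos|>" + marker + "<|audio_eos|>";
    // mtmd_tokenize applies this to every occurrence of the marker; one is enough here
    prompt.replace(prompt.find(marker), marker.size(), expanded);
    assert(prompt == "What is being said? <|audio_bos|><__media__><|audio_eos|>");
    return 0;
}
```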

tools/mtmd/mtmd.h:

```diff
@@ -203,6 +203,8 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
                                    const mtmd_input_chunk * chunk);
 
 // get output embeddings from the last encode pass
+// the reading size (in bytes) is equal to:
+// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
 /////////////////////////////////////////
```
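The new comment pins down the size of the buffer returned by `mtmd_get_output_embd`. A minimal sketch of a caller relying on it, assuming the `mtmd.h` and `llama.h` headers are on the include path (error handling omitted):

```cpp
#include <cstring>
#include <vector>
#include "llama.h"
#include "mtmd.h"

// Copy the embeddings produced by the last mtmd_encode_chunk() call into a vector,
// sized exactly as documented above.
std::vector<float> copy_output_embd(mtmd_context * ctx,
                                    const llama_model * model,
                                    const mtmd_input_chunk * chunk) {
    const size_t n_floats = (size_t) llama_model_n_embd(model)
                          * mtmd_input_chunk_get_n_tokens(chunk);
    std::vector<float> out(n_floats);
    std::memcpy(out.data(), mtmd_get_output_embd(ctx), n_floats * sizeof(float));
    return out;
}
```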