mtmd : add support for Voxtral (#14862)

* mtmd : add support for Voxtral

* clean up

* fix python requirements

* add [BEGIN_AUDIO] token

* also support Devstral conversion

* add docs and tests

* fix regression for ultravox

* minor coding style improvement

* correct projector activation fn

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Author:       Xuan-Son Nguyen
Date:         2025-07-28 15:01:48 +02:00
Committed by: GitHub
Parent:       946b1f6859
Commit:       00fa15fedc

13 changed files with 546 additions and 46 deletions


@@ -131,6 +131,7 @@ enum projector_type {
     PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
+    PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_UNKNOWN,
 };
@@ -150,6 +151,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
     { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
+    { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {


@@ -354,6 +354,16 @@ struct clip_model {
     ggml_tensor * conv1d_2_b = nullptr;
     ggml_tensor * mm_norm_pre_w = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;
+
+    bool audio_has_avgpool() const {
+        return proj_type == PROJECTOR_TYPE_QWEN2A
+            || proj_type == PROJECTOR_TYPE_VOXTRAL;
+    }
+
+    bool audio_has_stack_frames() const {
+        return proj_type == PROJECTOR_TYPE_ULTRAVOX
+            || proj_type == PROJECTOR_TYPE_VOXTRAL;
+    }
 };
 
 struct clip_ctx {
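Note: these two predicates describe post-encoder stages as capabilities instead of hard-coding projector types at every call site; Voxtral is the first projector that needs both stages. A compilable toy sketch of the pattern (trimmed enum and struct, not the real clip.cpp definitions):

    #include <cstdio>

    // Hypothetical, trimmed-down mirror of the clip_model predicates: each audio
    // projector advertises which post-encoder stages it needs, so graph-building
    // code can branch on capabilities instead of enumerating projector types.
    enum projector_type { PROJ_ULTRAVOX, PROJ_QWEN2A, PROJ_VOXTRAL };

    struct audio_model {
        projector_type proj_type;

        bool audio_has_avgpool() const {
            return proj_type == PROJ_QWEN2A || proj_type == PROJ_VOXTRAL;
        }
        bool audio_has_stack_frames() const {
            return proj_type == PROJ_ULTRAVOX || proj_type == PROJ_VOXTRAL;
        }
    };

    int main() {
        audio_model m = { PROJ_VOXTRAL };
        // Voxtral combines Ultravox-style frame stacking with Qwen2A-style pooling.
        std::printf("stack=%d avgpool=%d\n", m.audio_has_stack_frames(), m.audio_has_avgpool());
    }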
@@ -1483,49 +1493,52 @@
         cb(cur, "after_transformer", -1);
 
-        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
+        if (model.audio_has_stack_frames()) {
             // StackAudioFrames
             // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
-            {
-                int64_t stride = n_embd * hparams.proj_stack_factor;
-                int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
-                int64_t pad = padded_len - ggml_nelements(cur);
-                if (pad > 0) {
-                    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
-                    cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
-                }
-                cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
-                                   ggml_row_size(cur->type, stride), 0);
-            }
+            int64_t stride = n_embd * hparams.proj_stack_factor;
+            int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
+            int64_t pad = padded_len - ggml_nelements(cur);
+            if (pad > 0) {
+                cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
+                cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+            }
+            cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                               ggml_row_size(cur->type, stride), 0);
+            cb(cur, "after_stacked", -1);
+        }
 
-            cb(cur, "after_stacked", -1);
-
+        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
             // UltravoxProjector
-            {
-                // pre-norm
-                cur = ggml_rms_norm(ctx0, cur, 1e-6);
-                cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+            // pre-norm
+            cur = ggml_rms_norm(ctx0, cur, 1e-6);
+            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
 
-                // ffn in
-                cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            // ffn in
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
 
-                // swiglu
-                // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-                cur = ggml_swiglu_swapped(ctx0, cur);
+            // swiglu
+            // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+            cur = ggml_swiglu_swapped(ctx0, cur);
 
-                // mid-norm
-                cur = ggml_rms_norm(ctx0, cur, 1e-6);
-                cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+            // mid-norm
+            cur = ggml_rms_norm(ctx0, cur, 1e-6);
+            cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
 
-                // ffn out
-                cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
-            }
+            // ffn out
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
 
         } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
             // projector
             cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
             cur = ggml_add(ctx0, cur, model.mm_fc_b);
 
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
+            // projector
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_gelu_erf(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+
         } else {
             GGML_ABORT("%s: unknown projector type", __func__);
         }
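Note: the three projector heads above differ only in their small adapter networks. A scalar reference sketch of the two nonstandard pieces (hypothetical helper names and plain vectors, not the ggml tensor layout): ggml_swiglu_swapped, where the second half of the features is passed through silu and gates the first half, and the Voxtral mm_1 -> gelu_erf -> mm_2 MLP:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // erf-based GELU, matching ggml_gelu_erf rather than the tanh approximation.
    static float gelu_erf(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // swiglu_swapped: the SECOND half is gated through silu, the first half
    // passes through (the opposite of the usual swiglu split).
    static std::vector<float> swiglu_swapped(const std::vector<float> & x) {
        size_t d = x.size() / 2;
        std::vector<float> out(d);
        for (size_t i = 0; i < d; ++i) out[i] = x[i] * silu(x[d + i]);
        return out;
    }

    // Voxtral head: a plain two-matmul MLP with gelu_erf in between
    // (no RMS norms, no gating, no bias).
    static std::vector<float> voxtral_proj(const std::vector<std::vector<float>> & w1,
                                           const std::vector<std::vector<float>> & w2,
                                           const std::vector<float> & x) {
        std::vector<float> h(w1.size(), 0.0f);
        for (size_t r = 0; r < w1.size(); ++r) {
            for (size_t c = 0; c < x.size(); ++c) h[r] += w1[r][c] * x[c];
            h[r] = gelu_erf(h[r]);
        }
        std::vector<float> y(w2.size(), 0.0f);
        for (size_t r = 0; r < w2.size(); ++r)
            for (size_t c = 0; c < h.size(); ++c) y[r] += w2[r][c] * h[c];
        return y;
    }

    int main() {
        std::vector<float> x = {1.0f, -2.0f, 0.5f, 3.0f}; // two features, two gates
        std::vector<float> g = swiglu_swapped(x);
        std::printf("swiglu_swapped: %f %f\n", g[0], g[1]);
    }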
@@ -1670,8 +1683,7 @@ private:
             inpL = cur;
         }
 
-        // TODO @ngxson : find a way to move this outside
-        if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
+        if (ctx->model.audio_has_avgpool()) {
             ggml_tensor * cur = inpL;
             cur = ggml_transpose(ctx0, cur);
             cur = ggml_cont(ctx0, cur);
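Note: audio_has_avgpool() gates the block above, which transposes so the token (time) axis becomes contiguous, pools, and transposes back. A minimal sketch (hypothetical function, plain vectors instead of ggml tensors) of what nn.AvgPool1d(2, stride=2) does to the token sequence:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Average adjacent frames along the time axis, halving the sequence length;
    // ggml achieves the same via transpose + ggml_pool_1d + transpose.
    static std::vector<std::vector<float>> avg_pool_tokens(const std::vector<std::vector<float>> & seq) {
        std::vector<std::vector<float>> out;
        for (size_t t = 0; t + 1 < seq.size(); t += 2) {
            std::vector<float> frame(seq[t].size());
            for (size_t i = 0; i < frame.size(); ++i)
                frame[i] = 0.5f * (seq[t][i] + seq[t + 1][i]);
            out.push_back(frame);
        }
        return out; // seq.size()/2 frames, same embedding width
    }

    int main() {
        std::vector<std::vector<float>> seq(6, std::vector<float>(4, 1.0f));
        std::printf("%zu -> %zu frames\n", seq.size(), avg_pool_tokens(seq).size());
    }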
@@ -1985,6 +1997,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_llama4();
             } break;
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
             {
                 res = graph.build_whisper_enc();
@@ -2259,8 +2272,10 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_ULTRAVOX:
             case PROJECTOR_TYPE_QWEN2A:
+            case PROJECTOR_TYPE_VOXTRAL:
                 {
-                    bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                    bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
+                                         model.proj_type == PROJECTOR_TYPE_VOXTRAL;
                     get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
                     if (hparams.n_mel_bins != 128) {
                         throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
@@ -2544,6 +2559,15 @@ struct clip_model_loader {
                     model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
                     model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
                 } break;
+            case PROJECTOR_TYPE_VOXTRAL:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                } break;
             case PROJECTOR_TYPE_INTERNVL:
                 {
                     model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -3570,17 +3594,26 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 int scale_factor = ctx->model.hparams.proj_scale_factor;
                 n_patches_sq /= (scale_factor * scale_factor);
             } break;
+        case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
-            {
-                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
-                const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
-                n_patches_sq = n_len / proj_stack_factor / 2;
-            } break;
         case PROJECTOR_TYPE_QWEN2A:
             {
-                // divide by 2 because of whisper
-                // another divide by 2 because of nn.AvgPool1d(2, stride=2)
-                n_patches_sq = img->nx / 4;
+                n_patches_sq = img->nx;
+
+                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
+                if (ctx->model.audio_has_stack_frames()) {
+                    GGML_ASSERT(proj_stack_factor > 0);
+                    const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor);
+                    n_patches_sq = n_len / proj_stack_factor;
+                }
+
+                // whisper downscales input token by half after conv1d
+                n_patches_sq /= 2;
+
+                if (ctx->model.audio_has_avgpool()) {
+                    // divide by 2 because of nn.AvgPool1d(2, stride=2)
+                    n_patches_sq /= 2;
+                }
             } break;
         default:
            GGML_ABORT("unsupported projector type");
@@ -3986,6 +4019,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_VOXTRAL:
             {
                 // do nothing
             } break;
@@ -4086,6 +4120,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->model.projection->ne[1];
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_VOXTRAL:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_INTERNVL:
             return ctx->model.mm_3_w->ne[1];
@@ -4132,7 +4167,8 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
 
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
+        || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
 }
 
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {


@@ -289,6 +289,10 @@ struct mtmd_context {
             // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
             aud_beg = "<|audio_bos|>";
             aud_end = "<|audio_eos|>";
 
+        } else if (proj == PROJECTOR_TYPE_VOXTRAL) {
+            // [BEGIN_AUDIO] ... (embeddings) ...
+            aud_beg = "[BEGIN_AUDIO]";
+
         }
     }
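Note: only markers the model was trained with get added; Voxtral uses a begin marker ([BEGIN_AUDIO]) with no end marker, while Qwen2A brackets the clip on both sides. A toy sketch of the resulting layout (hypothetical helper, not the real mtmd_tokenize() plumbing, which splices token IDs around the audio embedding chunk):

    #include <cstdio>
    #include <string>

    // Illustrative only: show where aud_beg/aud_end land relative to the
    // audio embeddings in the expanded prompt.
    static std::string audio_prompt(const std::string & aud_beg, const std::string & aud_end) {
        return aud_beg + "<audio embeddings go here>" + aud_end;
    }

    int main() {
        std::printf("voxtral: %s\n", audio_prompt("[BEGIN_AUDIO]", "").c_str());
        std::printf("qwen2a:  %s\n", audio_prompt("<|audio_bos|>", "<|audio_eos|>").c_str());
    }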


@@ -1,5 +1,5 @@
 -r ../../requirements/requirements-convert_legacy_llama.txt
 --extra-index-url https://download.pytorch.org/whl/cpu
-pillow~=10.2.0
+pillow~=11.3.0
 torch~=2.2.1
 torchvision~=0.17.1


@@ -71,6 +71,7 @@ add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 
 add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
 add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
+add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 if [ "$RUN_BIG_TESTS" = true ]; then