mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 20:45:04 +00:00
clip : fix confused naming ffn_up and ffn_down (#13290)
* clip : fix confused naming ffn_up and ffn_down * rm ffn_i/o/g naming * rename n_embd, n_ff * small fix * no check n_ff
This commit is contained in:
@ -1778,6 +1778,12 @@ class LlamaModel(TextModel):
|
|||||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||||
undo_permute = True
|
undo_permute = True
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# fix for SmolVLM2, missing `num_attention_heads` in config.json
|
||||||
|
if self.hf_arch == "VLlama3ForCausalLM":
|
||||||
|
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
try:
|
try:
|
||||||
self._set_vocab_sentencepiece()
|
self._set_vocab_sentencepiece()
|
||||||
|
@ -977,15 +977,12 @@ class TensorNameMap:
|
|||||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||||
),
|
),
|
||||||
|
|
||||||
# some namings are messed up because the original llava code swapped fc1 and fc2
|
|
||||||
# we have no better way to fix it, just be careful
|
|
||||||
# new models like pixtral use the correct naming
|
|
||||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
|
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||||
"vpm.encoder.layers.{bid}.mlp.fc1",
|
"vpm.encoder.layers.{bid}.mlp.fc1",
|
||||||
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
|
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
|
||||||
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
|
||||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -997,9 +994,9 @@ class TensorNameMap:
|
|||||||
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
||||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
|
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||||
"vpm.encoder.layers.{bid}.mlp.fc2",
|
"vpm.encoder.layers.{bid}.mlp.fc2",
|
||||||
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
|
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
|
||||||
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
|
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
|
||||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -155,8 +155,8 @@ enum patch_merge_type {
|
|||||||
struct clip_hparams {
|
struct clip_hparams {
|
||||||
int32_t image_size;
|
int32_t image_size;
|
||||||
int32_t patch_size;
|
int32_t patch_size;
|
||||||
int32_t hidden_size;
|
int32_t n_embd;
|
||||||
int32_t n_intermediate;
|
int32_t n_ff;
|
||||||
int32_t projection_dim;
|
int32_t projection_dim;
|
||||||
int32_t n_head;
|
int32_t n_head;
|
||||||
int32_t n_layer;
|
int32_t n_layer;
|
||||||
@ -191,12 +191,6 @@ struct clip_layer {
|
|||||||
struct ggml_tensor * ln_1_w = nullptr;
|
struct ggml_tensor * ln_1_w = nullptr;
|
||||||
struct ggml_tensor * ln_1_b = nullptr;
|
struct ggml_tensor * ln_1_b = nullptr;
|
||||||
|
|
||||||
// ff
|
|
||||||
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
|
|
||||||
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
|
|
||||||
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
|
|
||||||
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
|
|
||||||
|
|
||||||
struct ggml_tensor * ff_up_w = nullptr;
|
struct ggml_tensor * ff_up_w = nullptr;
|
||||||
struct ggml_tensor * ff_up_b = nullptr;
|
struct ggml_tensor * ff_up_b = nullptr;
|
||||||
struct ggml_tensor * ff_gate_w = nullptr;
|
struct ggml_tensor * ff_gate_w = nullptr;
|
||||||
@ -204,9 +198,6 @@ struct clip_layer {
|
|||||||
struct ggml_tensor * ff_down_w = nullptr;
|
struct ggml_tensor * ff_down_w = nullptr;
|
||||||
struct ggml_tensor * ff_down_b = nullptr;
|
struct ggml_tensor * ff_down_b = nullptr;
|
||||||
|
|
||||||
struct ggml_tensor * ff_g_w = NULL;
|
|
||||||
struct ggml_tensor * ff_g_b = NULL;
|
|
||||||
|
|
||||||
// layernorm 2
|
// layernorm 2
|
||||||
struct ggml_tensor * ln_2_w = nullptr;
|
struct ggml_tensor * ln_2_w = nullptr;
|
||||||
struct ggml_tensor * ln_2_b = nullptr;
|
struct ggml_tensor * ln_2_b = nullptr;
|
||||||
@ -388,9 +379,9 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||||||
|
|
||||||
const int patch_size = hparams.patch_size;
|
const int patch_size = hparams.patch_size;
|
||||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int n_embd = hparams.n_embd;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = n_embd / n_head;
|
||||||
const int n_layer = hparams.n_layer;
|
const int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
|
|
||||||
@ -411,7 +402,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||||||
ggml_set_input(inp_raw);
|
ggml_set_input(inp_raw);
|
||||||
|
|
||||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||||
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
|
inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
|
||||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||||
|
|
||||||
@ -456,7 +447,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||||||
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
||||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
|
cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
|
||||||
}
|
}
|
||||||
|
|
||||||
// attention output
|
// attention output
|
||||||
@ -473,14 +464,14 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
|
||||||
|
|
||||||
// siglip uses gelu
|
// siglip uses gelu
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
|
||||||
|
|
||||||
// residual 2
|
// residual 2
|
||||||
cur = ggml_add(ctx0, embeddings, cur);
|
cur = ggml_add(ctx0, embeddings, cur);
|
||||||
@ -504,11 +495,11 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
|||||||
const int kernel_size = patches_per_image / tokens_per_side;
|
const int kernel_size = patches_per_image / tokens_per_side;
|
||||||
|
|
||||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||||
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
|
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size);
|
||||||
|
|
||||||
// doing a pool2d to reduce the number of output tokens to 256
|
// doing a pool2d to reduce the number of output tokens to 256
|
||||||
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
|
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
|
||||||
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
|
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size);
|
||||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||||
|
|
||||||
// apply norm before projection
|
// apply norm before projection
|
||||||
@ -637,9 +628,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
const int n_patches_x = image_size_width / patch_size;
|
const int n_patches_x = image_size_width / patch_size;
|
||||||
const int n_patches_y = image_size_height / patch_size;
|
const int n_patches_y = image_size_height / patch_size;
|
||||||
const int num_patches = n_patches_x * n_patches_y;
|
const int num_patches = n_patches_x * n_patches_y;
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int n_embd = hparams.n_embd;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = n_embd / n_head;
|
||||||
const int n_layer = hparams.n_layer;
|
const int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
const int n_merge = hparams.spatial_merge_size;
|
const int n_merge = hparams.spatial_merge_size;
|
||||||
@ -669,7 +660,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
ggml_set_input(pos_w);
|
ggml_set_input(pos_w);
|
||||||
|
|
||||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||||
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
|
inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
|
||||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||||
|
|
||||||
struct ggml_tensor * embeddings = inp;
|
struct ggml_tensor * embeddings = inp;
|
||||||
@ -710,7 +701,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
||||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
|
cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
|
||||||
}
|
}
|
||||||
@ -753,8 +744,8 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
|
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
|
||||||
|
|
||||||
// reshape image tokens to 2D grid
|
// reshape image tokens to 2D grid
|
||||||
cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
|
cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
|
||||||
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
|
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
|
||||||
cur = ggml_cont(ctx0, cur);
|
cur = ggml_cont(ctx0, cur);
|
||||||
|
|
||||||
// torch.nn.functional.unfold is just an im2col under the hood
|
// torch.nn.functional.unfold is just an im2col under the hood
|
||||||
@ -762,7 +753,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
|
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
|
||||||
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
|
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
|
||||||
|
|
||||||
// project to hidden_size
|
// project to n_embd
|
||||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
|
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
|
||||||
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
|
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
|
||||||
embeddings = cur;
|
embeddings = cur;
|
||||||
@ -785,9 +776,9 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
|
|||||||
// arrangement of the [IMG_BREAK] token
|
// arrangement of the [IMG_BREAK] token
|
||||||
{
|
{
|
||||||
// not efficient, but works
|
// not efficient, but works
|
||||||
// the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
|
// the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
|
||||||
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
|
||||||
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
|
// after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
|
||||||
|
|
||||||
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
|
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
|
||||||
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
|
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
|
||||||
@ -827,9 +818,9 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
const int patches_h = image_size_height / patch_size;
|
const int patches_h = image_size_height / patch_size;
|
||||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position
|
const int num_position_ids = num_positions * 4; // m-rope requires 4 dim per position
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int n_embd = hparams.n_embd;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = n_embd / n_head;
|
||||||
const int n_layer = hparams.n_layer;
|
const int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
|
|
||||||
@ -864,14 +855,14 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||||
inp = ggml_reshape_4d(
|
inp = ggml_reshape_4d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
n_embd * 2, patches_w / 2, patches_h, batch_size);
|
||||||
inp = ggml_reshape_4d(
|
inp = ggml_reshape_4d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||||
inp = ggml_reshape_3d(
|
inp = ggml_reshape_3d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size, patches_w * patches_h, batch_size);
|
n_embd, patches_w * patches_h, batch_size);
|
||||||
|
|
||||||
if (model.patch_bias) {
|
if (model.patch_bias) {
|
||||||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||||
@ -904,11 +895,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
ggml_set_name(window_mask, "window_mask");
|
ggml_set_name(window_mask, "window_mask");
|
||||||
ggml_set_input(window_mask);
|
ggml_set_input(window_mask);
|
||||||
|
|
||||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
// embeddings shape: [n_embd, patches_w * patches_h, batch_size]
|
||||||
GGML_ASSERT(batch_size == 1);
|
GGML_ASSERT(batch_size == 1);
|
||||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
|
embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4);
|
||||||
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
|
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
|
||||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
|
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over layers
|
// loop over layers
|
||||||
@ -961,7 +952,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
||||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
|
cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// attention output
|
// attention output
|
||||||
@ -978,11 +969,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
|
|
||||||
// mlp
|
// mlp
|
||||||
// ffn_up
|
// ffn_up
|
||||||
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
|
||||||
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
|
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b);
|
||||||
|
|
||||||
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
|
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
|
||||||
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
|
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b);
|
||||||
// TODO : only 2 of these 3 are actually used, should we remove one of them?
|
// TODO : only 2 of these 3 are actually used, should we remove one of them?
|
||||||
if (ctx->use_gelu) {
|
if (ctx->use_gelu) {
|
||||||
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
|
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
|
||||||
@ -994,8 +985,8 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
cur = ggml_mul(ctx0, cur_gate, cur_up);
|
cur = ggml_mul(ctx0, cur_gate, cur_up);
|
||||||
|
|
||||||
// ffn_down
|
// ffn_down
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
|
||||||
|
|
||||||
// residual 2
|
// residual 2
|
||||||
cur = ggml_add(ctx0, embeddings, cur);
|
cur = ggml_add(ctx0, embeddings, cur);
|
||||||
@ -1011,7 +1002,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
|
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
|
||||||
|
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
@ -1028,7 +1019,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
|
|||||||
ggml_set_name(window_idx, "window_idx");
|
ggml_set_name(window_idx, "window_idx");
|
||||||
ggml_set_input(window_idx);
|
ggml_set_input(window_idx);
|
||||||
|
|
||||||
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
|
// embeddings shape: [n_embd, patches_w * patches_h, batch_size]
|
||||||
GGML_ASSERT(batch_size == 1);
|
GGML_ASSERT(batch_size == 1);
|
||||||
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
|
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
|
||||||
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
|
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
|
||||||
@ -1074,9 +1065,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
const int patches_h = image_size_height / patch_size;
|
const int patches_h = image_size_height / patch_size;
|
||||||
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
|
||||||
const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
|
const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int n_embd = hparams.n_embd;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = n_embd / n_head;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||||
|
|
||||||
@ -1114,17 +1105,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||||
inp = ggml_reshape_4d(
|
inp = ggml_reshape_4d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
n_embd * 2, patches_w / 2, patches_h, batch_size);
|
||||||
inp = ggml_reshape_4d(
|
inp = ggml_reshape_4d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||||
inp = ggml_reshape_3d(
|
inp = ggml_reshape_3d(
|
||||||
ctx0, inp,
|
ctx0, inp,
|
||||||
hidden_size, patches_w * patches_h, batch_size);
|
n_embd, patches_w * patches_h, batch_size);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size);
|
||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1137,7 +1128,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
|
|
||||||
// concat class_embeddings and patch_embeddings
|
// concat class_embeddings and patch_embeddings
|
||||||
if (model.class_embedding) {
|
if (model.class_embedding) {
|
||||||
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size);
|
||||||
embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
|
embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
|
||||||
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
||||||
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
||||||
@ -1234,7 +1225,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
|
||||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
|
cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// attention output
|
// attention output
|
||||||
@ -1252,8 +1243,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
|
||||||
|
|
||||||
if (ctx->use_gelu) {
|
if (ctx->use_gelu) {
|
||||||
cur = ggml_gelu_inplace(ctx0, cur);
|
cur = ggml_gelu_inplace(ctx0, cur);
|
||||||
@ -1263,8 +1254,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
|
||||||
|
|
||||||
// residual 2
|
// residual 2
|
||||||
cur = ggml_add(ctx0, embeddings, cur);
|
cur = ggml_add(ctx0, embeddings, cur);
|
||||||
@ -1496,9 +1487,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
}
|
}
|
||||||
|
|
||||||
{ // attention
|
{ // attention
|
||||||
int hidden_size = clip_n_mmproj_embd(ctx);
|
int n_embd = clip_n_mmproj_embd(ctx);
|
||||||
const int d_head = 128;
|
const int d_head = 128;
|
||||||
int n_head = hidden_size/d_head;
|
int n_head = n_embd/d_head;
|
||||||
int num_query = 96;
|
int num_query = 96;
|
||||||
if (ctx->minicpmv_version == 2) {
|
if (ctx->minicpmv_version == 2) {
|
||||||
num_query = 96;
|
num_query = 96;
|
||||||
@ -1528,7 +1519,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||||
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
|
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
|
||||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
|
KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
|
||||||
|
|
||||||
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
|
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
|
||||||
}
|
}
|
||||||
@ -1571,7 +1562,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
|
|||||||
}
|
}
|
||||||
|
|
||||||
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
|
||||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
|
||||||
|
|
||||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||||
@ -1696,9 +1687,9 @@ struct clip_model_loader {
|
|||||||
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
|
||||||
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
|
||||||
|
|
||||||
get_u32(KEY_N_EMBD, hparams.hidden_size);
|
get_u32(KEY_N_EMBD, hparams.n_embd);
|
||||||
get_u32(KEY_N_HEAD, hparams.n_head);
|
get_u32(KEY_N_HEAD, hparams.n_head);
|
||||||
get_u32(KEY_N_FF, hparams.n_intermediate);
|
get_u32(KEY_N_FF, hparams.n_ff);
|
||||||
get_u32(KEY_N_BLOCK, hparams.n_layer);
|
get_u32(KEY_N_BLOCK, hparams.n_layer);
|
||||||
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
|
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
|
||||||
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
|
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
|
||||||
@ -1807,6 +1798,7 @@ struct clip_model_loader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void load_tensors() {
|
void load_tensors() {
|
||||||
|
auto & hparams = ctx_clip.vision_model.hparams;
|
||||||
std::map<std::string, size_t> tensor_offset;
|
std::map<std::string, size_t> tensor_offset;
|
||||||
std::vector<ggml_tensor *> tensors_to_load;
|
std::vector<ggml_tensor *> tensors_to_load;
|
||||||
|
|
||||||
@ -1860,8 +1852,8 @@ struct clip_model_loader {
|
|||||||
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
|
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
|
||||||
|
|
||||||
// layers
|
// layers
|
||||||
vision_model.layers.resize(vision_model.hparams.n_layer);
|
vision_model.layers.resize(hparams.n_layer);
|
||||||
for (int il = 0; il < vision_model.hparams.n_layer; ++il) {
|
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||||
auto & layer = vision_model.layers[il];
|
auto & layer = vision_model.layers[il];
|
||||||
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
|
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
|
||||||
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
|
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
|
||||||
@ -1884,13 +1876,18 @@ struct clip_model_loader {
|
|||||||
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
||||||
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
||||||
|
|
||||||
// legacy naming (the in and out is reversed! don't ask me why)
|
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
|
||||||
layer.ff_i_w = layer.ff_down_w;
|
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
|
||||||
layer.ff_o_w = layer.ff_up_w;
|
if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
|
||||||
layer.ff_g_w = layer.ff_gate_w;
|
// swap up and down weights
|
||||||
layer.ff_i_b = layer.ff_down_b;
|
ggml_tensor * tmp = layer.ff_up_w;
|
||||||
layer.ff_o_b = layer.ff_up_b;
|
layer.ff_up_w = layer.ff_down_w;
|
||||||
layer.ff_g_b = layer.ff_gate_b;
|
layer.ff_down_w = tmp;
|
||||||
|
// swap up and down biases
|
||||||
|
tmp = layer.ff_up_b;
|
||||||
|
layer.ff_up_b = layer.ff_down_b;
|
||||||
|
layer.ff_down_b = tmp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (ctx_clip.proj_type) {
|
switch (ctx_clip.proj_type) {
|
||||||
@ -2904,7 +2901,7 @@ int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
|
int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
|
||||||
return ctx->vision_model.hparams.hidden_size;
|
return ctx->vision_model.hparams.n_embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
|
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
|
||||||
|
@ -92,6 +92,10 @@ struct mtmd_cli_context {
|
|||||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||||
n_batch = params.n_batch;
|
n_batch = params.n_batch;
|
||||||
|
|
||||||
|
if (!model || !lctx) {
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
|
if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
|
||||||
LOG_ERR("Model does not have chat template.\n");
|
LOG_ERR("Model does not have chat template.\n");
|
||||||
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
|
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
|
||||||
|
Reference in New Issue
Block a user