mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 03:55:20 +00:00
mtmd-cli : fix out_of_range when input image path is empty (#13244)
* fix out_of_range error to keep the chat loop running * Update examples/llava/mtmd-cli.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * mtmd-cli : load image right away * add a new line for readability * rm printf * Update examples/llava/mtmd-cli.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update examples/llava/mtmd-cli.cpp --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
This commit is contained in:
@ -72,6 +72,8 @@ struct mtmd_cli_context {
|
|||||||
llama_batch batch;
|
llama_batch batch;
|
||||||
int n_batch;
|
int n_batch;
|
||||||
|
|
||||||
|
std::vector<mtmd_bitmap> bitmaps;
|
||||||
|
|
||||||
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
|
||||||
// so here we don't need to keep track of chat history
|
// so here we don't need to keep track of chat history
|
||||||
common_chat_templates_ptr tmpls;
|
common_chat_templates_ptr tmpls;
|
||||||
@ -135,13 +137,22 @@ struct mtmd_cli_context {
|
|||||||
antiprompt_tokens.begin()
|
antiprompt_tokens.begin()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool load_image(const std::string & fname) {
|
||||||
|
mtmd_bitmap bitmap;
|
||||||
|
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bitmaps.push_back(std::move(bitmap));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
|
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
|
||||||
llama_tokens generated_tokens;
|
llama_tokens generated_tokens;
|
||||||
for (int i = 0; i < n_predict; i++) {
|
for (int i = 0; i < n_predict; i++) {
|
||||||
if (i > n_predict || !g_is_generating || g_is_interrupted) {
|
if (i > n_predict || !g_is_generating || g_is_interrupted) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
|
|||||||
common_sampler_accept(smpl, token_id, true);
|
common_sampler_accept(smpl, token_id, true);
|
||||||
|
|
||||||
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
|
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break; // end of generation
|
break; // end of generation
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
if (g_is_interrupted) {
|
if (g_is_interrupted) {
|
||||||
printf("\n");
|
LOG("\n");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,9 +184,7 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
|
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
|
||||||
std::vector<mtmd_bitmap> bitmaps;
|
|
||||||
|
|
||||||
common_chat_templates_inputs tmpl_inputs;
|
common_chat_templates_inputs tmpl_inputs;
|
||||||
tmpl_inputs.messages = {msg};
|
tmpl_inputs.messages = {msg};
|
||||||
tmpl_inputs.add_generation_prompt = true;
|
tmpl_inputs.add_generation_prompt = true;
|
||||||
@ -183,15 +192,6 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
|||||||
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
|
||||||
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
|
||||||
|
|
||||||
for (auto & fname : images_fname) {
|
|
||||||
mtmd_bitmap bitmap;
|
|
||||||
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
|
|
||||||
LOG_ERR("Unable to load image %s\n", fname.c_str());
|
|
||||||
return 2; // image not found
|
|
||||||
}
|
|
||||||
bitmaps.push_back(std::move(bitmap));
|
|
||||||
}
|
|
||||||
|
|
||||||
mtmd_input_text text;
|
mtmd_input_text text;
|
||||||
text.text = formatted_chat.prompt;
|
text.text = formatted_chat.prompt;
|
||||||
text.add_special = add_bos;
|
text.add_special = add_bos;
|
||||||
@ -200,12 +200,14 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
|||||||
|
|
||||||
if (g_is_interrupted) return 0;
|
if (g_is_interrupted) return 0;
|
||||||
|
|
||||||
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
|
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
|
||||||
if (res != 0) {
|
if (res != 0) {
|
||||||
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.bitmaps.clear();
|
||||||
|
|
||||||
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
|
||||||
LOG_ERR("Unable to eval prompt\n");
|
LOG_ERR("Unable to eval prompt\n");
|
||||||
return 1;
|
return 1;
|
||||||
@ -213,6 +215,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
|
|||||||
|
|
||||||
ctx.n_past += mtmd_helper_get_n_pos(chunks);
|
ctx.n_past += mtmd_helper_get_n_pos(chunks);
|
||||||
|
|
||||||
|
LOG("\n");
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,7 +239,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mtmd_cli_context ctx(params);
|
mtmd_cli_context ctx(params);
|
||||||
printf("%s: %s\n", __func__, params.model.path.c_str());
|
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||||
|
|
||||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||||
|
|
||||||
@ -268,7 +272,12 @@ int main(int argc, char ** argv) {
|
|||||||
common_chat_msg msg;
|
common_chat_msg msg;
|
||||||
msg.role = "user";
|
msg.role = "user";
|
||||||
msg.content = params.prompt;
|
msg.content = params.prompt;
|
||||||
if (eval_message(ctx, msg, params.image, true)) {
|
for (const auto & image : params.image) {
|
||||||
|
if (!ctx.load_image(image)) {
|
||||||
|
return 1; // error is already printed by libmtmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (eval_message(ctx, msg, true)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
|
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
|
||||||
@ -283,7 +292,6 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("\n");
|
LOG("\n");
|
||||||
|
|
||||||
bool is_first_msg = true;
|
bool is_first_msg = true;
|
||||||
std::vector<std::string> images_fname;
|
|
||||||
std::string content;
|
std::string content;
|
||||||
|
|
||||||
while (!g_is_interrupted) {
|
while (!g_is_interrupted) {
|
||||||
@ -308,10 +316,17 @@ int main(int argc, char ** argv) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
g_is_generating = true;
|
g_is_generating = true;
|
||||||
if (line.find("/image") == 0) {
|
if (line == "/image" || line.find("/image ") == 0) {
|
||||||
|
if (line.size() < 8) {
|
||||||
|
LOG_ERR("ERR: Missing image filename\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::string image = line.substr(7);
|
std::string image = line.substr(7);
|
||||||
images_fname.push_back(string_strip(image));
|
if (ctx.load_image(image)) {
|
||||||
content += "<__image__>";
|
LOG("Image %s loaded\n", image.c_str());
|
||||||
|
content += "<__image__>";
|
||||||
|
}
|
||||||
|
// else, error is already printed by libmtmd
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
content += line;
|
content += line;
|
||||||
@ -319,21 +334,14 @@ int main(int argc, char ** argv) {
|
|||||||
common_chat_msg msg;
|
common_chat_msg msg;
|
||||||
msg.role = "user";
|
msg.role = "user";
|
||||||
msg.content = content;
|
msg.content = content;
|
||||||
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
|
int ret = eval_message(ctx, msg, is_first_msg);
|
||||||
if (g_is_interrupted) break;
|
|
||||||
if (ret == 2) {
|
|
||||||
// non-fatal error
|
|
||||||
images_fname.clear();
|
|
||||||
content.clear();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ret) {
|
if (ret) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
if (g_is_interrupted) break;
|
||||||
if (generate_response(ctx, smpl, n_predict)) {
|
if (generate_response(ctx, smpl, n_predict)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
images_fname.clear();
|
|
||||||
content.clear();
|
content.clear();
|
||||||
is_first_msg = false;
|
is_first_msg = false;
|
||||||
}
|
}
|
||||||
|
@ -590,7 +590,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||||
GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
|
GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
|
||||||
GGML_ASSERT(chunk.tokens_image != nullptr);
|
GGML_ASSERT(chunk.tokens_image != nullptr);
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
if (ctx->print_timings) {
|
if (ctx->print_timings) {
|
||||||
|
Reference in New Issue
Block a user