diff --git a/examples/llava/mtmd-cli.cpp b/examples/llava/mtmd-cli.cpp
index aa52d92ca..474e7c4f8 100644
--- a/examples/llava/mtmd-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -72,6 +72,8 @@ struct mtmd_cli_context {
     llama_batch batch;
     int n_batch;
 
+    std::vector<mtmd_bitmap> bitmaps;
+
     // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
     // so here we don't need to keep track of chat history
     common_chat_templates_ptr tmpls;
@@ -135,13 +137,22 @@ struct mtmd_cli_context {
             antiprompt_tokens.begin()
         );
     }
+
+    bool load_image(const std::string & fname) {
+        mtmd_bitmap bitmap;
+        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            return false;
+        }
+        bitmaps.push_back(std::move(bitmap));
+        return true;
+    }
 };
 
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
-            printf("\n");
+            LOG("\n");
             break;
         }
 
@@ -150,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         common_sampler_accept(smpl, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
-            printf("\n");
+            LOG("\n");
             break; // end of generation
         }
 
-        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
         if (g_is_interrupted) {
-            printf("\n");
+            LOG("\n");
             break;
         }
 
@@ -173,9 +184,7 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
     return 0;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
-    std::vector<mtmd_bitmap> bitmaps;
-
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
     common_chat_templates_inputs tmpl_inputs;
     tmpl_inputs.messages = {msg};
     tmpl_inputs.add_generation_prompt = true;
@@ -183,15 +192,6 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
-    for (auto & fname : images_fname) {
-        mtmd_bitmap bitmap;
-        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
-            LOG_ERR("Unable to load image %s\n", fname.c_str());
-            return 2; // image not found
-        }
-        bitmaps.push_back(std::move(bitmap));
-    }
-
     mtmd_input_text text;
     text.text = formatted_chat.prompt;
     text.add_special = add_bos;
@@ -200,12 +200,14 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
 
     if (g_is_interrupted) return 0;
 
-    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
+    ctx.bitmaps.clear();
+
     if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
@@ -213,6 +215,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
 
     ctx.n_past += mtmd_helper_get_n_pos(chunks);
 
+    LOG("\n");
+
     return 0;
 }
 
@@ -235,7 +239,7 @@ int main(int argc, char ** argv) {
     }
 
     mtmd_cli_context ctx(params);
-    printf("%s: %s\n", __func__, params.model.path.c_str());
+    LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
@@ -268,7 +272,12 @@ int main(int argc, char ** argv) {
         common_chat_msg msg;
         msg.role = "user";
         msg.content = params.prompt;
-        if (eval_message(ctx, msg, params.image, true)) {
+        for (const auto & image : params.image) {
+            if (!ctx.load_image(image)) {
+                return 1; // error is already printed by libmtmd
+            }
+        }
+        if (eval_message(ctx, msg, true)) {
             return 1;
         }
         if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
@@ -283,7 +292,6 @@ int main(int argc, char ** argv) {
         LOG("\n");
 
         bool is_first_msg = true;
-        std::vector<std::string> images_fname;
         std::string content;
 
         while (!g_is_interrupted) {
@@ -308,10 +316,17 @@ int main(int argc, char ** argv) {
                 continue;
             }
             g_is_generating = true;
-            if (line.find("/image") == 0) {
+            if (line == "/image" || line.find("/image ") == 0) {
+                if (line.size() < 8) {
+                    LOG_ERR("ERR: Missing image filename\n");
+                    continue;
+                }
                 std::string image = line.substr(7);
-                images_fname.push_back(string_strip(image));
-                content += "<__image__>";
+                if (ctx.load_image(image)) {
+                    LOG("Image %s loaded\n", image.c_str());
+                    content += "<__image__>";
+                }
+                // else, error is already printed by libmtmd
                 continue;
             } else {
                 content += line;
@@ -319,21 +334,14 @@ int main(int argc, char ** argv) {
             common_chat_msg msg;
             msg.role = "user";
             msg.content = content;
-            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
-            if (g_is_interrupted) break;
-            if (ret == 2) {
-                // non-fatal error
-                images_fname.clear();
-                content.clear();
-                continue;
-            }
+            int ret = eval_message(ctx, msg, is_first_msg);
             if (ret) {
                 return 1;
             }
+            if (g_is_interrupted) break;
             if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
-            images_fname.clear();
             content.clear();
             is_first_msg = false;
         }
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 7081fd735..d1d7530fe 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -590,7 +590,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
         }
 
     } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-        GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
+        GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
         GGML_ASSERT(chunk.tokens_image != nullptr);
         int64_t t0 = ggml_time_ms();
         if (ctx->print_timings) {
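
The net effect in mtmd-cli.cpp: image loading moves out of eval_message() into the new mtmd_cli_context::load_image(), which stages decoded bitmaps on the context; mtmd_tokenize() then consumes ctx.bitmaps, and eval_message() clears them once tokenization succeeds. A minimal sketch of the resulting call order, mirroring the single-turn path in the patch (main() scaffolding, params setup, and the sampler/generation step are elided; this is illustrative, not part of the patch):

    // stage every bitmap before evaluating the message that references them
    mtmd_cli_context ctx(params);
    for (const auto & image : params.image) {
        if (!ctx.load_image(image)) {
            return 1; // load_image() returns false on failure; libmtmd already printed the error
        }
    }
    common_chat_msg msg;
    msg.role    = "user";
    msg.content = params.prompt;                    // may contain one <__image__> marker per staged bitmap
    if (eval_message(ctx, msg, /*add_bos=*/true)) { // tokenizes against ctx.bitmaps, then clears them
        return 1;
    }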