llama-mtmd-cli: Sigint rework in mtmd vision example (#13080)

* Sigint rework in mtmd vision example

* Applied suggestions on mtmd-cli PR

* Forgot to invert one of the conditions

* Update examples/llava/mtmd-cli.cpp

* Removed redundant exit check

---------

Co-authored-by: pl752 <maximpl752@gmail.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Author: pl752 <maximpl752@gmail.com>
Date:   2025-04-24 02:32:35 +05:00
Committed by: GitHub
Parent: ecda2ec4b3
Commit: 5630406959

examples/llava/mtmd-cli.cpp

@@ -24,7 +24,9 @@
 #include <signal.h>
 #endif
 
-static bool g_is_generating = false;
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
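A portability note on the new globals: plain volatile bool works on mainstream toolchains, but the only object type the C and C++ standards actually guarantee to be writable from a signal handler is volatile std::sig_atomic_t (a lock-free std::atomic also qualifies). A sketch of the strictly portable spelling, which is not what this commit uses:

    #include <csignal>

    // Strictly portable variant: the standards only guarantee that stores to
    // volatile sig_atomic_t are well-defined inside a signal handler.
    static volatile std::sig_atomic_t g_is_generating  = 0;
    static volatile std::sig_atomic_t g_is_interrupted = 0;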
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
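The reworked handler is the core of the commit: it turns a one-shot "print and die" handler into a two-stage interrupt. Ctrl+C during generation only clears g_is_generating so the decode loop can stop cleanly; Ctrl+C at the prompt sets g_is_interrupted to request an orderly shutdown; a second Ctrl+C while that shutdown is still pending hard-exits via _exit, which is async-signal-safe. A minimal self-contained sketch of the pattern (POSIX-only; console::cleanup() and the Windows setup branch are omitted, and the busy-wait stands in for the readline loop):

    #include <csignal>
    #include <unistd.h>

    static volatile bool g_is_generating  = false;
    static volatile bool g_is_interrupted = false;

    static void sigint_handler(int signo) {
        if (signo == SIGINT) {
            if (g_is_generating) {
                g_is_generating = false;  // Ctrl+C mid-generation: stop decoding, keep the REPL alive
            } else if (g_is_interrupted) {
                _exit(1);                 // second Ctrl+C while a shutdown is pending: hard exit
            } else {
                g_is_interrupted = true;  // first Ctrl+C at the prompt: request a clean shutdown
            }
        }
    }

    int main() {
        std::signal(SIGINT, sigint_handler);  // the handler sticks to async-signal-safe operations
        while (!g_is_interrupted) {
            // read input / generate here, polling g_is_interrupted between steps
        }
        return 130;  // 128 + SIGINT, the conventional "killed by Ctrl+C" status
    }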
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
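Taken together, the two hunks above make the decode loop poll for interruption twice per token: once in the loop condition and once between printing a sampled token and evaluating it, so a Ctrl+C lands within one token instead of running out the n_predict budget. A compilable skeleton of that polling shape, with sample_next and eval_token as hypothetical stand-ins for the real sampler and batch decode calls:

    #include <cstdio>

    static volatile bool g_is_generating  = true;
    static volatile bool g_is_interrupted = false;

    // Hypothetical stand-ins for the sampler call and the llama_decode step.
    static int  sample_next()         { return 0; }
    static void eval_token(int /*t*/) { }

    int main() {
        const int n_predict = 8;
        for (int i = 0; i < n_predict; i++) {
            if (!g_is_generating || g_is_interrupted) {  // top-of-loop check
                break;
            }
            int token = sample_next();
            printf("token %d\n", token);
            if (g_is_interrupted) {                      // re-check before the expensive eval
                break;
            }
            eval_token(token);
        }
        return 0;
    }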
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }
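A note on the exit statuses: 130 follows the shell convention of 128 plus the signal number for a process killed by a signal, and SIGINT is signal 2 on POSIX systems, so an interrupted run reports the same status a shell would show for a ^C-killed child. For illustration:

    #include <csignal>

    // Shell convention: status 128 + N means "terminated by signal N".
    // SIGINT == 2 on POSIX, hence the 130 returned above.
    constexpr int exit_sigint = 128 + SIGINT;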