From b9154ecff93ff54dc554411eb844a2a654be49f2 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 18 Apr 2025 10:04:51 +0200 Subject: [PATCH] mtmd : add methods to access `mtmd_image_tokens` (#12906) * mtmd : add more api around mtmd_image_tokens * mtmd : ability to calc image hash * shared_ptr for mtmd_image_tokens * move hash to user-define ID (fixed) * fix prompt_modified * rm redundant data member --- examples/llava/gemma3-cli.cpp | 11 +++-- examples/llava/mtmd.cpp | 88 ++++++++++++++++++++++++----------- examples/llava/mtmd.h | 37 ++++++++++----- 3 files changed, 92 insertions(+), 44 deletions(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 693b738a8..3d5664750 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -184,18 +184,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector text.text = formatted_chat.prompt; text.add_special = add_bos; text.parse_special = true; - mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps)); - if (chunks == nullptr) { - LOG_ERR("Unable to tokenize prompt\n"); + mtmd_input_chunks chunks; + int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps); + if (res != 0) { + LOG_ERR("Unable to tokenize prompt, res = %d\n", res); return 1; } - if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) { + if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) { LOG_ERR("Unable to eval prompt\n"); return 1; } - ctx.n_past += mtmd_helper_get_n_tokens(chunks.get()); + ctx.n_past += mtmd_helper_get_n_tokens(chunks); return 0; } diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp index 114c274bc..3fd5bebc6 100644 --- a/examples/llava/mtmd.cpp +++ b/examples/llava/mtmd.cpp @@ -16,6 +16,7 @@ struct mtmd_context { struct clip_ctx * ctx_clip; const struct llama_model * text_model; std::vector image_embd_v; // image embedding vector + bool print_timings; int n_threads; std::string image_marker; @@ -24,7 +25,11 @@ struct mtmd_context { mtmd_context(const char * mmproj_fname, const llama_model * text_model, - const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) { + const mtmd_context_params & ctx_params) : + print_timings(ctx_params.print_timings), + n_threads (ctx_params.n_threads), + image_marker (ctx_params.image_marker) + { clip_context_params ctx_clip_params; ctx_clip_params.use_gpu = ctx_params.use_gpu; ctx_clip_params.verbosity = ctx_params.verbosity; @@ -49,6 +54,7 @@ struct mtmd_image_tokens { uint32_t ny; // number of tokens in y direction uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches + std::string id; // optional user-defined ID, useful for KV cache tracking }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -88,10 +94,10 @@ static std::vector mtmd_tokenize_text_internal( return result; } -mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, - const mtmd_input_text & text, - const std::vector & bitmaps) { - mtmd_input_chunks * output = new mtmd_input_chunks; +int32_t mtmd_tokenize(mtmd_context * ctx, + std::vector & output, + const mtmd_input_text & text, + const std::vector & bitmaps) { auto vocab = llama_model_get_vocab(ctx->text_model); std::string prompt_modified(text.text); @@ -105,9 +111,9 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, string_replace_all(prompt_modified, ctx->image_marker, marker_modified); } - std::vector parts = string_split_str(text.text, ctx->image_marker); - output->clear(); - output->reserve(parts.size()); + std::vector parts = string_split_str(prompt_modified, ctx->image_marker); + output.clear(); + output.reserve(parts.size()); size_t i_img = 0; @@ -123,14 +129,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output->emplace_back(std::move(chunk)); + output.emplace_back(std::move(chunk)); if (&parts.back() != &part) { // add image token to middle of 2 parts if (i_img >= bitmaps.size()) { LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); - return nullptr; + return 1; } // shim layer @@ -145,34 +151,48 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); if (!ok) { LOG_ERR("Unable to preprocess image\n"); - return nullptr; + return 2; } - mtmd_image_tokens * image_tokens = new mtmd_image_tokens; + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image image_tokens->ny = 1; // TODO image_tokens->batch_f32 = std::move(batch_f32); + image_tokens->id = bitmaps[i_img].id; // optional mtmd_input_chunk chunk{ MTMD_INPUT_CHUNK_TYPE_IMAGE, {}, - image_tokens, + std::move(image_tokens), }; - output->emplace_back(std::move(chunk)); + output.emplace_back(std::move(chunk)); i_img++; } } - return output; + return 0; } -void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { - delete chunk.tokens_image; - } +void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { + if (image_tokens) { + delete image_tokens; } - delete chunks; +} + +size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { + return image_tokens->n_tokens(); +} + +size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { + return image_tokens->nx; +} + +size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { + return image_tokens->ny; +} + +std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { + return image_tokens->id; } int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { @@ -190,9 +210,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) { +size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { size_t n_tokens = 0; - for (auto & chunk : *chunks) { + for (auto & chunk : chunks) { if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { n_tokens += chunk.tokens_text.size(); } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { @@ -241,7 +261,7 @@ struct decode_embd_batch { int32_t mtmd_helper_eval(mtmd_context * ctx, llama_context * lctx, - mtmd_input_chunks * chunks, + mtmd_input_chunks & chunks, llama_pos pos0, llama_seq_id seq_id, int32_t n_batch) { @@ -249,8 +269,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, llama_pos n_past = pos0; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); - for (auto & chunk : *chunks) { - bool is_last = &chunk == &chunks->back(); + for (auto & chunk : chunks) { + bool is_last = &chunk == &chunks.back(); if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { // TODO @ngxson : may need to split into smaller batches text_batch.n_tokens = chunk.tokens_text.size(); @@ -279,7 +299,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, if (ctx->print_timings) { LOG_INF("encoding image...\n"); } - ret = mtmd_encode(ctx, chunk.tokens_image); + ret = mtmd_encode(ctx, chunk.tokens_image.get()); if (ret != 0) { LOG_ERR("failed to encode image\n"); llama_batch_free(text_batch); @@ -289,7 +309,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx, LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); } - int32_t n_tokens = chunk.tokens_image->n_tokens(); + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); float * embd = mtmd_get_output_embd(ctx); decode_embd_batch batch_img(embd, n_tokens, n_past, 0); int64_t t1 = ggml_time_ms(); @@ -339,3 +359,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp std::memcpy(output.data.data(), data, output.nx * output.ny * 3); return 0; } + +bool mtmd_decode_use_non_causal(mtmd_context * ctx) { + projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + return true; + } + return false; +} + +void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { + mtmd_image_tokens_free(val); +} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h index 598f6947b..78be192dd 100644 --- a/examples/llava/mtmd.h +++ b/examples/llava/mtmd.h @@ -39,12 +39,18 @@ struct mtmd_bitmap { uint32_t nx; uint32_t ny; std::vector data; + std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking }; +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val); // forward declaration +}; +using mtmd_image_tokens_ptr = std::unique_ptr; + struct mtmd_input_chunk { mtmd_input_chunk_type type; std::vector tokens_text; - mtmd_image_tokens * tokens_image = nullptr; + mtmd_image_tokens_ptr tokens_image; }; using mtmd_input_chunks = std::vector; @@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 3. "\ndescribe it in detail." // number of bitmaps must be equal to the number of image markers in the prompt // this function is thread-safe (shared ctx) -MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, +// return values: +// 0 on success +// 1 on number of images not matching the number of markers +// 2 on image preprocessing error +MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, + std::vector & output, const mtmd_input_text & text, const std::vector & bitmaps); -// free image chunk data -MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); +// access mtmd_image_tokens +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); +MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens); +MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, @@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx, // get output embeddings from the last encode pass MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); +// whether we need to set non-causal mask before llama_decode +MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); + + + // // helper functions (can be implemented based on other functions) // // helper to count the total number of tokens from a list of chunks, useful to keep track of n_past -MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); +MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks); // helper function that automatically: // 1. run llama_decode() on text chunks @@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); // otherwise, returns 0 on success MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, llama_context * lctx, - mtmd_input_chunks * chunks, + mtmd_input_chunks & chunks, llama_pos pos0, llama_seq_id seq_id, int32_t n_batch); @@ -132,11 +152,6 @@ struct mtmd_context_deleter { }; using mtmd_context_ptr = std::unique_ptr; -struct mtmd_input_chunks_deleter { - void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } -}; -using mtmd_input_chunks_ptr = std::unique_ptr; - #else static_assert(false && "C header is not yet supported by this library");