diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 72b32df9f..709d5ad96 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -165,6 +165,10 @@ if (NOT GGML_BACKEND_DL) llama_build_and_test(test-rope.cpp) endif() +# libmtmd +set(LLAMA_TEST_NAME test-mtmd-c-api) +llama_build_and_test(test-mtmd-c-api.c) +target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd) # dummy executable - not installed get_filename_component(TEST_TARGET test-c.c NAME_WE) diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c new file mode 100644 index 000000000..02e762e6a --- /dev/null +++ b/tests/test-mtmd-c-api.c @@ -0,0 +1,63 @@ +#include +#include + +#include "mtmd.h" + +int main(void) { + printf("\n\nTesting libmtmd C API...\n"); + printf("--------\n\n"); + + struct mtmd_context_params params = mtmd_context_params_default(); + printf("Default image marker: %s\n", params.image_marker); + + mtmd_input_chunks * chunks = mtmd_test_create_input_chunks(); + + if (!chunks) { + fprintf(stderr, "Failed to create input chunks\n"); + return 1; + } + + size_t n_chunks = mtmd_input_chunks_size(chunks); + printf("Number of chunks: %zu\n", n_chunks); + assert(n_chunks > 0); + + for (size_t i = 0; i < n_chunks; i++) { + const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i); + assert(chunk != NULL); + enum mtmd_input_chunk_type type = mtmd_input_chunk_get_type(chunk); + printf("Chunk %zu type: %d\n", i, type); + + if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + printf(" Text chunk with %zu tokens\n", n_tokens); + assert(tokens != NULL); + assert(n_tokens > 0); + for (size_t j = 0; j < n_tokens; j++) { + assert(tokens[j] >= 0); + printf(" > Token %zu: %d\n", j, tokens[j]); + } + + } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); + size_t nx = mtmd_image_tokens_get_nx(image_tokens); + size_t ny = mtmd_image_tokens_get_ny(image_tokens); + const char * id = mtmd_image_tokens_get_id(image_tokens); + assert(n_tokens > 0); + assert(nx > 0); + assert(ny > 0); + assert(id != NULL); + printf(" Image chunk with %zu tokens\n", n_tokens); + printf(" Image size: %zu x %zu\n", nx, ny); + printf(" Image ID: %s\n", id); + } + } + + // Free the chunks + mtmd_input_chunks_free(chunks); + + printf("\n\nDONE: test libmtmd C API...\n"); + + return 0; +} diff --git a/tools/llava/clip-impl.h b/tools/llava/clip-impl.h index b78d930bc..fb780e9de 100644 --- a/tools/llava/clip-impl.h +++ b/tools/llava/clip-impl.h @@ -233,6 +233,15 @@ struct clip_image_u8_batch { struct clip_image_f32_batch { std::vector entries; + + clip_image_f32_batch clone() const { + clip_image_f32_batch new_batch; + new_batch.entries.reserve(entries.size()); + for (const auto & entry : entries) { + new_batch.entries.emplace_back(new clip_image_f32(*entry)); + } + return new_batch; + } }; // diff --git a/tools/llava/clip.h b/tools/llava/clip.h index 0a53bd8eb..0b0eb0295 100644 --- a/tools/llava/clip.h +++ b/tools/llava/clip.h @@ -78,10 +78,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); -CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava +CLIP_API struct clip_image_size * clip_image_size_init(void); +CLIP_API struct clip_image_u8 * clip_image_u8_init (void); +CLIP_API struct clip_image_f32 * clip_image_f32_init(void); +CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava // nx, ny are the output image dimensions CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); diff --git a/tools/llava/mtmd-cli.cpp b/tools/llava/mtmd-cli.cpp index 474e7c4f8..dd18e0fe6 100644 --- a/tools/llava/mtmd-cli.cpp +++ b/tools/llava/mtmd-cli.cpp @@ -63,7 +63,7 @@ static void sigint_handler(int signo) { #endif struct mtmd_cli_context { - mtmd_context_ptr ctx_vision; + mtmd::context_ptr ctx_vision; common_init_result llama_init; llama_model * model; @@ -72,7 +72,7 @@ struct mtmd_cli_context { llama_batch batch; int n_batch; - std::vector bitmaps; + mtmd::bitmaps bitmaps; // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another // so here we don't need to keep track of chat history @@ -115,12 +115,12 @@ struct mtmd_cli_context { void init_vision_context(common_params & params) { const char * clip_path = params.mmproj.path.c_str(); - ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ - /* use_gpu */ params.mmproj_use_gpu, - /* timings */ true, - /* n_threads */ params.cpuparams.n_threads, - /* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO, - })); + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params.mmproj_use_gpu; + mparams.print_timings = true; + mparams.n_threads = params.cpuparams.n_threads; + mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); exit(1); @@ -139,11 +139,11 @@ struct mtmd_cli_context { } bool load_image(const std::string & fname) { - mtmd_bitmap bitmap; - if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str())); + if (!bmp.ptr) { return false; } - bitmaps.push_back(std::move(bitmap)); + bitmaps.entries.push_back(std::move(bmp)); return true; } }; @@ -193,27 +193,40 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_ LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); mtmd_input_text text; - text.text = formatted_chat.prompt; + text.text = formatted_chat.prompt.c_str(); text.add_special = add_bos; text.parse_special = true; - mtmd_input_chunks chunks; if (g_is_interrupted) return 0; - int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps); + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = ctx.bitmaps.c_ptr(); + int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), + chunks.ptr.get(), // output + &text, // text + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); if (res != 0) { LOG_ERR("Unable to tokenize prompt, res = %d\n", res); return 1; } - ctx.bitmaps.clear(); + ctx.bitmaps.entries.clear(); - if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) { + llama_pos new_n_past; + if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(), + ctx.lctx, // lctx + chunks.ptr.get(), // chunks + ctx.n_past, // n_past + 0, // seq_id + ctx.n_batch, // n_batch + true, // logits_last + &new_n_past)) { LOG_ERR("Unable to eval prompt\n"); return 1; } - ctx.n_past += mtmd_helper_get_n_pos(chunks); + ctx.n_past = new_n_past; LOG("\n"); @@ -246,7 +259,7 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling); int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict; - // ctrl+C handling + // Ctrl+C handling { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; diff --git a/tools/llava/mtmd.cpp b/tools/llava/mtmd.cpp index 73abf2ad1..b600e4341 100644 --- a/tools/llava/mtmd.cpp +++ b/tools/llava/mtmd.cpp @@ -12,6 +12,30 @@ #include #include +// represents raw image data, layout is RGBRGBRGB... +// length of data must be nx * ny * 3 +struct mtmd_bitmap { + uint32_t nx; + uint32_t ny; + std::vector data; + std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking +}; + +struct mtmd_image_tokens_deleter { + void operator()(mtmd_image_tokens * val); // forward declaration +}; +using mtmd_image_tokens_ptr = std::unique_ptr; + +struct mtmd_input_chunk { + mtmd_input_chunk_type type; + std::vector tokens_text; + mtmd_image_tokens_ptr tokens_image; +}; + +struct mtmd_input_chunks { + std::vector entries; +}; + // slice template, used by some llava-uhd models to correctly place the special tokens around image embeddings // models not having it (llava-1.6) will process embeddings without any special tokens in-between enum mtmd_slice_tmpl { @@ -21,6 +45,16 @@ enum mtmd_slice_tmpl { // TODO @ngxson : add support for idefics (SmolVLM) }; +mtmd_context_params mtmd_context_params_default() { + mtmd_context_params params; + params.use_gpu = true; + params.print_timings = true; + params.n_threads = 4; + params.verbosity = GGML_LOG_LEVEL_INFO; + params.image_marker = MTMD_DEFAULT_IMAGE_MARKER; + return params; +} + struct mtmd_context { struct clip_ctx * ctx_clip; const struct llama_model * text_model; @@ -132,6 +166,16 @@ struct mtmd_image_tokens { uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches std::string id; // optional user-defined ID, useful for KV cache tracking + + mtmd_image_tokens clone() { + return mtmd_image_tokens{ + nx, + ny, + use_mrope_pos, + batch_f32.clone(), + id + }; + } }; mtmd_context * mtmd_init_from_file(const char * mmproj_fname, @@ -172,12 +216,13 @@ static std::vector mtmd_tokenize_text_internal( } int32_t mtmd_tokenize(mtmd_context * ctx, - std::vector & output, - const mtmd_input_text & text, - const std::vector & bitmaps) { + mtmd_input_chunks * output, + const mtmd_input_text * text, + const mtmd_bitmap ** bitmaps, + size_t n_bitmaps) { auto vocab = llama_model_get_vocab(ctx->text_model); - std::string prompt_modified(text.text); + std::string prompt_modified(text->text); std::string marker_modified(ctx->image_marker); projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); @@ -211,8 +256,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx, // for glm-edge, BOI and EOI token's embeddings are not present in the text model std::vector parts = string_split_str(prompt_modified, ctx->image_marker); - output.clear(); - output.reserve(parts.size()); + output->entries.clear(); + output->entries.reserve(parts.size()); size_t i_img = 0; @@ -223,7 +268,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); }; // utility for splitting batch of multiple images into chunks of batch having single images @@ -251,7 +296,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, for (const auto & part : parts) { // printf("tokenizing part: %s\n", part.c_str()); bool add_bos = &parts.front() == ∂ - auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special); + auto tokens = mtmd_tokenize_text_internal(vocab, part, text->add_special && add_bos, text->parse_special); if (tokens.empty()) { continue; } @@ -260,22 +305,22 @@ int32_t mtmd_tokenize(mtmd_context * ctx, std::move(tokens), {}, }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); if (&parts.back() != &part) { // add image token to middle of 2 parts - if (i_img >= bitmaps.size()) { + if (i_img >= n_bitmaps) { LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); return 1; } // convert mtmd_bitmap to clip_image_u8 clip_image_u8_ptr img_u8(clip_image_u8_init()); - img_u8->nx = bitmaps[i_img].nx; - img_u8->ny = bitmaps[i_img].ny; - img_u8->buf.resize(bitmaps[i_img].data.size()); - std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3); + img_u8->nx = bitmaps[i_img]->nx; + img_u8->ny = bitmaps[i_img]->ny; + img_u8->buf.resize(bitmaps[i_img]->data.size()); + std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3); clip_image_size img_u8_size{img_u8->nx, img_u8->ny}; // preprocess image @@ -288,12 +333,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) { // split batch into chunks of single images - auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img].id); + auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id); GGML_ASSERT(chunks.size() > 0); // add overview image add_text_chunk({ctx->tok_ov_img_start}); - output.emplace_back(std::move(chunks.front())); + output->entries.emplace_back(std::move(chunks.front())); chunks.erase(chunks.begin()); add_text_chunk({ctx->tok_ov_img_end}); @@ -311,7 +356,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_start}); } - output.emplace_back(std::move(chunks[y * n_col + x])); + output->entries.emplace_back(std::move(chunks[y * n_col + x])); if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_end}); } @@ -343,7 +388,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, image_tokens->ny = 1; } image_tokens->batch_f32 = std::move(batch_f32); - image_tokens->id = bitmaps[i_img].id; // optional + image_tokens->id = bitmaps[i_img]->id; // optional LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx); LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny); @@ -354,7 +399,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx, {}, std::move(image_tokens), }; - output.emplace_back(std::move(chunk)); + output->entries.emplace_back(std::move(chunk)); } i_img++; // move to next image @@ -364,35 +409,12 @@ int32_t mtmd_tokenize(mtmd_context * ctx, return 0; } -void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { +static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) { if (image_tokens) { delete image_tokens; } } -size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { - return image_tokens->n_tokens(); -} - -size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { - return image_tokens->nx; -} - -size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { - return image_tokens->ny; -} - -std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { - return image_tokens->id; -} - -llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { - if (image_tokens->use_mrope_pos) { - return 1; // for M-RoPE, the whole image is 1 in temporal dimension - } - return image_tokens->n_tokens(); -} - int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); @@ -432,13 +454,18 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { +size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) { size_t n_tokens = 0; - for (auto & chunk : chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - n_tokens += chunk.tokens_text.size(); - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - n_tokens += mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_tokens += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image); } else { GGML_ASSERT(false && "chunk type not supported"); } @@ -446,13 +473,18 @@ size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) { return n_tokens; } -llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks) { +llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { llama_pos n_pos = 0; - for (auto & chunk : chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - n_pos += chunk.tokens_text.size(); - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - n_pos += mtmd_image_tokens_get_n_pos(chunk.tokens_image.get()); + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_pos += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_pos += mtmd_image_tokens_get_n_pos(tokens_image); } else { GGML_ASSERT(false && "chunk type not supported"); } @@ -548,143 +580,172 @@ struct decode_embd_batch { } }; -int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks & chunks, - llama_pos pos0, +int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, llama_seq_id seq_id, - int32_t n_batch) { + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { int32_t ret; - llama_pos n_past = pos0; llama_batch text_batch = llama_batch_init(n_batch, 0, 1); + auto chunk_type = mtmd_input_chunk_get_type(chunk); int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1; - for (auto & chunk : chunks) { - bool is_last = &chunk == &chunks.back(); - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - text_batch.n_tokens = chunk.tokens_text.size(); - size_t i = 0; - while (i < chunk.tokens_text.size()) { // split into batches - for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) { - text_batch.token [i] = chunk.tokens_text[i]; - text_batch.pos [i] = n_past++; - text_batch.n_seq_id[i] = 1; - text_batch.seq_id [i][0] = seq_id; - text_batch.logits [i] = false; - } - if (is_last) { - // always get logits for last input chunk - text_batch.logits[text_batch.n_tokens - 1] = true; - } - ret = llama_decode(lctx, text_batch); - if (ret != 0) { - LOG_ERR("failed to decode text\n"); - llama_batch_free(text_batch); - return ret; - } + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens); + size_t i = 0; + while (i < n_tokens) { // split into batches + text_batch.n_tokens = 0; // clear the batch + for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) { + text_batch.n_tokens++; + text_batch.token [i] = tokens[i]; + text_batch.pos [i] = n_past++; + text_batch.n_seq_id[i] = 1; + text_batch.seq_id [i][0] = seq_id; + text_batch.logits [i] = false; } - - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported"); - GGML_ASSERT(chunk.tokens_image != nullptr); - int64_t t0 = ggml_time_ms(); - if (ctx->print_timings) { - LOG_INF("encoding image or slice...\n"); + bool is_last_token = (i == n_tokens); + if (logits_last && is_last_token) { + text_batch.logits[text_batch.n_tokens - 1] = true; } - ret = mtmd_encode(ctx, chunk.tokens_image.get()); + ret = llama_decode(lctx, text_batch); if (ret != 0) { - LOG_ERR("failed to encode image\n"); + LOG_ERR("failed to decode text\n"); llama_batch_free(text_batch); return ret; } - if (ctx->print_timings) { - LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - } - - int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get()); - int32_t i_batch = 0; - int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; - float * embd = mtmd_get_output_embd(ctx); - decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd); - - const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get()); - const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get()); - - if (mtmd_decode_use_mrope(ctx)) { - batch_embd.set_position_mrope(n_past, nx, ny, seq_id); - } else { - batch_embd.set_position_normal(n_past, seq_id); - } - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, false); - // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image - } - - while (i_batch < n_img_batches) { // split into batches - int pos_offset = i_batch*n_batch; - int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); - llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); - - LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); - - int64_t t1 = ggml_time_ms(); - ret = llama_decode(lctx, batch_embd_view); - if (ret != 0) { - LOG_ERR("failed to decode image\n"); - llama_set_causal_attn(lctx, true); // restore causal attn - llama_batch_free(text_batch); - return ret; - } - - if (ctx->print_timings) { - LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); - } - - i_batch++; - } - - // for mrope, one image is one single **temporal** position - n_past += mtmd_decode_use_mrope(ctx) ? 1 : n_tokens; - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, true); - } - - } else { - GGML_ASSERT(false && "chunk type not supported"); + *new_n_past += text_batch.n_tokens; } + + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + int64_t t0 = ggml_time_ms(); + if (ctx->print_timings) { + LOG_INF("encoding image or slice...\n"); + } + ret = mtmd_encode(ctx, image_tokens); + if (ret != 0) { + LOG_ERR("failed to encode image\n"); + llama_batch_free(text_batch); + return ret; + } + if (ctx->print_timings) { + LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + } + + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); + int32_t i_batch = 0; + int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + float * embd = mtmd_get_output_embd(ctx); + decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd); + + const int nx = mtmd_image_tokens_get_nx(image_tokens); + const int ny = mtmd_image_tokens_get_ny(image_tokens); + + if (mtmd_decode_use_mrope(ctx)) { + batch_embd.set_position_mrope(n_past, nx, ny, seq_id); + } else { + batch_embd.set_position_normal(n_past, seq_id); + } + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, false); + // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image + } + + while (i_batch < n_img_batches) { // split into batches + int pos_offset = i_batch*n_batch; + int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); + llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); + + LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); + + int64_t t1 = ggml_time_ms(); + ret = llama_decode(lctx, batch_embd_view); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_set_causal_attn(lctx, true); // restore causal attn + llama_batch_free(text_batch); + return ret; + } + + if (ctx->print_timings) { + LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); + } + + i_batch++; + } + + n_past += mtmd_image_tokens_get_n_pos(image_tokens); + *new_n_past = n_past; + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, true); + } + + } else { + GGML_ABORT("chunk type not supported"); } - llama_batch_free(text_batch); return 0; } -int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) { +int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunks * chunks, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { + size_t n_chunks = mtmd_input_chunks_size(chunks); + if (n_chunks == 0) { + LOG_WRN("no chunks to eval\n"); + return 0; + } + + for (size_t i = 0; i < n_chunks; i++) { + bool chunk_logits_last = (i == n_chunks - 1) && logits_last; + auto chunk = mtmd_input_chunks_get(chunks, i); + + int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past); + if (res != 0) { + LOG_ERR("failed to eval chunk %zu\n", i); + return res; + } + *new_n_past = n_past; + } + + return 0; +} + +mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) { clip_image_u8_ptr img_u8(clip_image_u8_init()); bool ok = clip_image_load_from_bytes(buf, len, img_u8.get()); if (!ok) { LOG_ERR("Unable to load image from buffer\n"); - return 1; + return nullptr; } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; + uint32_t nx, ny; + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny); + return mtmd_bitmap_init(nx, ny, data); } -int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) { +mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) { clip_image_u8_ptr img_u8(clip_image_u8_init()); bool ok = clip_image_load_from_file(fname, img_u8.get()); if (!ok) { LOG_ERR("Unable to load image %s\n", fname); - return 1; + return nullptr; } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; + uint32_t nx, ny; + unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny); + return mtmd_bitmap_init(nx, ny, data); } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { @@ -702,3 +763,175 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) { void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { mtmd_image_tokens_free(val); } + + +// +// public API functions +// + +// mtmd_bitmap + +mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, + uint32_t ny, + const unsigned char * data) { + mtmd_bitmap * bitmap = new mtmd_bitmap; + bitmap->nx = nx; + bitmap->ny = ny; + size_t data_size = (size_t)nx * ny * 3; + bitmap->data.resize(data_size); + std::memcpy(bitmap->data.data(), data, data_size); + return bitmap; +} + +uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) { + return bitmap->nx; +} + +uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) { + return bitmap->ny; +} + +const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) { + return bitmap->data.data(); +} + +const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) { + return bitmap->id.c_str(); +} + +void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id) { + if (id) { + bitmap->id = std::string(id); + } else { + bitmap->id.clear(); + } +} + +void mtmd_bitmap_free(mtmd_bitmap * bitmap) { + if (bitmap) { + delete bitmap; + } +} + +// mtmd_input_chunks + +mtmd_input_chunks * mtmd_input_chunks_init() { + return new mtmd_input_chunks; +} + +size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks) { + return chunks->entries.size(); +} + +const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx) { + if (idx >= chunks->entries.size()) { + return nullptr; + } + return &chunks->entries[idx]; +} + +void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { + if (chunks) { + delete chunks; + } +} + +// mtmd_input_chunk + +enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk) { + return chunk->type; +} + +const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output) { + if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + *n_tokens_output = chunk->tokens_text.size(); + return chunk->tokens_text.data(); + } + *n_tokens_output = 0; + return nullptr; +} + +const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk) { + if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + return chunk->tokens_image.get(); + } + return nullptr; +} + +mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) { + mtmd_input_chunk * copy = new mtmd_input_chunk{ + chunk->type, + chunk->tokens_text, + mtmd_image_tokens_ptr(), + }; + if (chunk->tokens_image) { + // copy the image tokens + copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens()); + *copy->tokens_image = chunk->tokens_image->clone(); + } + return copy; +} + +void mtmd_input_chunk_free(mtmd_input_chunk * chunk) { + if (chunk) { + delete chunk; + } +} + +// mtmd_image_tokens + +size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) { + return image_tokens->n_tokens(); +} + +size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) { + return image_tokens->nx; +} + +size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { + return image_tokens->ny; +} + +const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { + return image_tokens->id.c_str(); +} + +llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { + if (image_tokens->use_mrope_pos) { + return 1; // for M-RoPE, the whole image is 1 in temporal dimension + } + return image_tokens->n_tokens(); +} + +// test function + +mtmd_input_chunks * mtmd_test_create_input_chunks() { + mtmd_input_chunks * chunks = mtmd_input_chunks_init(); + if (!chunks) { + return nullptr; + } + + // create a text chunk + std::vector tokens_text = { 1, 2, 3, 4, 5 }; + mtmd_input_chunk chunk_text{ + MTMD_INPUT_CHUNK_TYPE_TEXT, + std::move(tokens_text), + {}, + }; + chunks->entries.emplace_back(std::move(chunk_text)); + + // create an image chunk + mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); + image_tokens->nx = 4; + image_tokens->ny = 4; + image_tokens->batch_f32.entries.resize(16); + image_tokens->id = "image_1"; + mtmd_input_chunk chunk_image{ + MTMD_INPUT_CHUNK_TYPE_IMAGE, + {}, + std::move(image_tokens), + }; + chunks->entries.emplace_back(std::move(chunk_image)); + + return chunks; +} diff --git a/tools/llava/mtmd.h b/tools/llava/mtmd.h index 6805e5e48..e2f76e2e8 100644 --- a/tools/llava/mtmd.h +++ b/tools/llava/mtmd.h @@ -5,9 +5,24 @@ #include "llama.h" #include "clip.h" +#include +#include +#include + +#ifdef __cplusplus #include #include #include +#endif + +/** + * libmtmd: A library for multimodal support in llama.cpp. + * + * WARNING: This API is experimental and subject to many BREAKING CHANGES. + * Issues related to API usage may receive lower priority support. + * + * For the usage, see an example in mtmd-cli.cpp + */ #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) @@ -23,60 +38,118 @@ # define MTMD_API #endif +#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>" + #ifdef __cplusplus +extern "C" { +#endif enum mtmd_input_chunk_type { MTMD_INPUT_CHUNK_TYPE_TEXT, MTMD_INPUT_CHUNK_TYPE_IMAGE, }; +// opaque types struct mtmd_context; +struct mtmd_bitmap; struct mtmd_image_tokens; - -// represents raw image data, layout is RGBRGBRGB... -// length of data must be nx * ny * 3 -struct mtmd_bitmap { - uint32_t nx; - uint32_t ny; - std::vector data; - std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking -}; - -struct mtmd_image_tokens_deleter { - void operator()(mtmd_image_tokens * val); // forward declaration -}; -using mtmd_image_tokens_ptr = std::unique_ptr; - -struct mtmd_input_chunk { - mtmd_input_chunk_type type; - std::vector tokens_text; - mtmd_image_tokens_ptr tokens_image; -}; - -using mtmd_input_chunks = std::vector; - -struct mtmd_context_params { - bool use_gpu = true; - bool print_timings = true; - int n_threads = 4; - enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; - const char * image_marker = "<__image__>"; -}; +struct mtmd_input_chunk; +struct mtmd_input_chunks; struct mtmd_input_text { - std::string text; + const char * text; bool add_special; bool parse_special; }; +// +// C API +// + +typedef struct mtmd_context mtmd_context; +typedef struct mtmd_bitmap mtmd_bitmap; +typedef struct mtmd_image_tokens mtmd_image_tokens; +typedef struct mtmd_input_chunk mtmd_input_chunk; +typedef struct mtmd_input_chunks mtmd_input_chunks; +typedef struct mtmd_input_text mtmd_input_text; + +struct mtmd_context_params { + bool use_gpu; + bool print_timings; + int n_threads; + enum ggml_log_level verbosity; + const char * image_marker; +}; + +MTMD_API struct mtmd_context_params mtmd_context_params_default(void); + // initialize the mtmd context // return nullptr on failure MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, - const llama_model * text_model, - const mtmd_context_params ctx_params); + const struct llama_model * text_model, + const struct mtmd_context_params ctx_params); MTMD_API void mtmd_free(mtmd_context * ctx); +// whether we need to set non-causal mask before llama_decode +MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); + +// whether the current model use M-RoPE for llama_decode +MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); + + +// mtmd_bitmap +// +// length of data must be nx * ny * 3 +// the data is in RGBRGBRGB... format +MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, + uint32_t ny, + const unsigned char * data); +MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap); +MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap); +MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap); +MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap); +// bitmap ID is optional, but useful for KV cache tracking +// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data() +MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap); +MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id); + + +// mtmd_input_chunks +// +// this is simply a list of mtmd_input_chunk +// the elements can only be populated via mtmd_tokenize() +MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void); +MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks); +MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx); +MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); + +// mtmd_input_chunk +// +// the instance will be constructed via mtmd_tokenize() +// it will be freed along with mtmd_input_chunks +MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk); +MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output); +MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk); + +// in case you want to use custom logic to handle the chunk (i.e. KV cache management) +// you can move the chunk ownership to your own code by copying it +// remember to free the chunk when you are done with it +MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk); +MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk); + + +// mtmd_image_tokens +// +// the instance will be constructed via mtmd_tokenize() +// it will be freed along with mtmd_input_chunk +MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens); +MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens); +MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); +// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) +MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); + // tokenize an input text prompt and an image // the prompt must have the input image marker (default: "<__image__>") in it // the marker will be replaced with the image tokens @@ -93,75 +166,152 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // 1 on number of images not matching the number of markers // 2 on image preprocessing error MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx, - std::vector & output, - const mtmd_input_text & text, - const std::vector & bitmaps); - -// access mtmd_image_tokens -MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens); -MTMD_API size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens); -MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens); -MTMD_API llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens); // number of temporal positions (always 1 for M-RoPE, n_tokens otherwise) -MTMD_API void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens); + mtmd_input_chunks * output, + const mtmd_input_text * text, + const mtmd_bitmap ** bitmaps, + size_t n_bitmaps); // returns 0 on success MTMD_API int32_t mtmd_encode(mtmd_context * ctx, - const mtmd_image_tokens * image_tokens); + const mtmd_image_tokens * image_tokens); // get output embeddings from the last encode pass MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); -// whether we need to set non-causal mask before llama_decode -MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx); - -// whether the current model use M-RoPE for llama_decode -MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); - - +///////////////////////////////////////// // -// helper functions (can be implemented based on other functions) +// Helper functions (can be implemented based on other functions) // +// Please note that these helpers are not guaranteed to be stable. +// BREAKING CHANGES are expected. +// + +// helper function to construct a mtmd_bitmap from a file +// returns nullptr on failure +// this function is thread-safe +MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname); + +// helper function to construct a mtmd_bitmap from a buffer containing a file +// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.) +// returns nullptr on failure +// this function is thread-safe +MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len); // helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache -MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks); +MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks); // helper to count the total position of tokens from a list of chunks, useful to keep track of n_past -MTMD_API llama_pos mtmd_helper_get_n_pos(mtmd_input_chunks & chunks); +// normally, n_pos is equal to n_tokens, but for M-RoPE it is different +MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks); // helper function that automatically: // 1. run llama_decode() on text chunks // 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode() // if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error // otherwise, returns 0 on success -MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks & chunks, - llama_pos pos0, - llama_seq_id seq_id, - int32_t n_batch); +// this function is NOT thread-safe +MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunks * chunks, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past); -// helper function to construct a mtmd_bitmap from a file -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output); +// works like mtmd_helper_eval_chunks(), but only for a single chunk +// this function is NOT thread-safe +MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past); -// helper function to construct a mtmd_bitmap from a buffer -// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.) -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output); +///////////////////////////////////////// + +// test function, to be used in test-mtmd-c-api.c +MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void); + +#ifdef __cplusplus +} // extern "C" +#endif + +// +// C++ wrappers +// + +#ifdef __cplusplus + +namespace mtmd { -// convenient unique_ptr wrappers struct mtmd_context_deleter { void operator()(mtmd_context * val) { mtmd_free(val); } }; -using mtmd_context_ptr = std::unique_ptr; +using context_ptr = std::unique_ptr; -#else +struct mtmd_bitmap_deleter { + void operator()(mtmd_bitmap * val) { mtmd_bitmap_free(val); } +}; +using bitmap_ptr = std::unique_ptr; -static_assert(false && "C header is not yet supported by this library"); +struct mtmd_input_chunks_deleter { + void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } +}; +using input_chunks_ptr = std::unique_ptr; + +struct mtmd_input_chunk_deleter { + void operator()(mtmd_input_chunk * val) { mtmd_input_chunk_free(val); } +}; +using input_chunk_ptr = std::unique_ptr; + +struct bitmap { + bitmap_ptr ptr; + bitmap() : ptr(nullptr) {} + bitmap(mtmd_bitmap * bitmap) : ptr(bitmap) {} + bitmap(bitmap && other) noexcept : ptr(std::move(other.ptr)) {} + bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) { + ptr.reset(mtmd_bitmap_init(nx, ny, data)); + } + ~bitmap() = default; + uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); } + uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); } + const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); } + std::string id() { return mtmd_bitmap_get_id(ptr.get()); } + void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); } +}; + +struct bitmaps { + std::vector entries; + ~bitmaps() = default; + // return list of pointers to mtmd_bitmap + // example: + // auto bitmaps_c_ptr = bitmaps.c_ptr(); + // int32_t res = mtmd_tokenize(... bitmaps_c_ptr.data(), bitmaps_c_ptr.size()); + std::vector c_ptr() { + std::vector res(entries.size()); + for (size_t i = 0; i < entries.size(); i++) { + res[i] = entries[i].ptr.get(); + } + return res; + } +}; + +struct input_chunks { + input_chunks_ptr ptr; + input_chunks() = default; + input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {} + ~input_chunks() = default; + size_t size() { return mtmd_input_chunks_size(ptr.get()); } + const mtmd_input_chunk * operator[](size_t idx) { + return mtmd_input_chunks_get(ptr.get(), idx); + } +}; + +} // namespace mtmd #endif