llava : introduce libmtmd (#12849)
* wip llava2
* migrated gemma3 to llava2
* add timings
* correct pre/postfix
* fix missing include
* fix compilation unused var warn
* update llava2_tokenize
* change name llava2 --> mtmd
* improve api
* refine helpers
* Update examples/llava/mtmd.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
examples/llava/CMakeLists.txt

@@ -1,3 +1,5 @@
+# llava (legacy)
+
 add_library(llava OBJECT
             llava.cpp
             llava.h
@@ -22,12 +24,41 @@ if (BUILD_SHARED_LIBS)
     install(TARGETS llava_shared LIBRARY)
 endif()

+# mtmd
+
+add_library(mtmd OBJECT
+            mtmd.cpp
+            mtmd.h
+            clip.cpp
+            clip.h
+            clip-impl.h
+            )
+
+target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+
+target_include_directories(mtmd PUBLIC .)
+target_include_directories(mtmd PRIVATE ../..)
+target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h
+
+target_compile_features(mtmd PRIVATE cxx_std_17)
+
+add_library(mtmd_static STATIC $<TARGET_OBJECTS:mtmd>)
+if (BUILD_SHARED_LIBS)
+    set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
+    add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
+    target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
+    install(TARGETS mtmd_shared LIBRARY)
+endif()
+
 if (NOT MSVC)
     target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h
+    target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
 endif()

 if(TARGET BUILD_INFO)
     add_dependencies(llava BUILD_INFO)
+    add_dependencies(mtmd BUILD_INFO)
 endif()

 set(TARGET llama-llava-cli)
@@ -55,7 +86,7 @@ set(TARGET llama-gemma3-cli)
 add_executable(${TARGET} gemma3-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)

 set(TARGET llama-llava-clip-quantize-cli)
examples/llava/clip-impl.h

@@ -1,12 +1,15 @@
 #include "ggml.h"
 #include "gguf.h"

+#include "clip.h"
+
 #include <climits>
 #include <cstdarg>
 #include <string>
 #include <map>
 #include <sstream>
 #include <vector>
+#include <memory>

 // Internal header for clip.cpp

@@ -120,6 +123,23 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
     return PROJECTOR_TYPE_UNKNOWN;
 }

+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector<uint8_t> buf;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector<float> buf;
+};
+
 //
 // logging
 //
@@ -178,6 +198,28 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
 #define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
 #define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  __VA_ARGS__)

+//
+// cpp wrappers
+//
+
+struct clip_image_u8_deleter {
+    void operator()(clip_image_u8 * val) { clip_image_u8_free(val); }
+};
+
+struct clip_image_f32_deleter {
+    void operator()(clip_image_f32 * val) { clip_image_f32_free(val); }
+};
+
+struct clip_image_f32_batch_deleter {
+    void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
+};
+
+typedef std::unique_ptr<clip_image_u8, clip_image_u8_deleter> clip_image_u8_ptr;
+typedef std::unique_ptr<clip_image_f32, clip_image_f32_deleter> clip_image_f32_ptr;
+typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
+
+// TODO @ngxson : we're currently having a naming clash between struct clip_image_size and function clip_image_size()
+
 //
 // common utils
 //
@@ -214,6 +256,20 @@ static void string_replace_all(std::string & s, const std::string & search, cons
     s = std::move(builder);
 }

+// split string by a `std::string delim` instead of `char delim`
+static std::vector<std::string> string_split_str(std::string s, const std::string & delimiter) {
+    std::vector<std::string> tokens;
+    size_t pos = 0;
+    std::string token;
+    while ((pos = s.find(delimiter)) != std::string::npos) {
+        token = s.substr(0, pos);
+        tokens.push_back(token);
+        s.erase(0, pos + delimiter.length());
+    }
+    tokens.push_back(s);
+    return tokens;
+}
+
 //
 // gguf utils
 //
@@ -271,3 +327,9 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
             return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
     }
 }
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx);
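The "cpp wrappers" block added above is the usual custom-deleter std::unique_ptr idiom for C-style resources. A self-contained sketch of the same pattern, using a stand-in type rather than the clip structs (illustrative only, not part of the commit):

#include <cstdio>
#include <memory>

// stand-in for a C-style resource such as clip_image_u8
struct raw_image { int nx; int ny; };
static raw_image * raw_image_init() { return new raw_image{0, 0}; }
static void raw_image_free(raw_image * img) { delete img; }

// deleter struct + unique_ptr alias, mirroring clip_image_u8_deleter / clip_image_u8_ptr above
struct raw_image_deleter {
    void operator()(raw_image * val) { raw_image_free(val); }
};
typedef std::unique_ptr<raw_image, raw_image_deleter> raw_image_ptr;

int main() {
    raw_image_ptr img(raw_image_init()); // raw_image_free() runs automatically on scope exit
    img->nx = 64;
    img->ny = 48;
    std::printf("%d x %d\n", img->nx, img->ny);
    return 0;
}

The same ownership style is what mtmd.cpp (below) relies on when it juggles preprocessed image batches without manual free calls.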
examples/llava/clip.cpp

@@ -32,23 +32,6 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac

 //#define CLIP_DEBUG_FUNCTIONS

-// RGB uint8 image
-struct clip_image_u8 {
-    int nx;
-    int ny;
-
-    std::vector<uint8_t> buf;
-};
-
-// RGB float32 image (NHWC)
-// Memory layout: RGBRGBRGB...
-struct clip_image_f32 {
-    int nx;
-    int ny;
-
-    std::vector<float> buf;
-};
-
 #ifdef CLIP_DEBUG_FUNCTIONS
 static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
     std::ofstream file(filename, std::ios::binary);
@@ -1614,6 +1597,12 @@ struct clip_image_f32 * clip_image_f32_init() {
     return new clip_image_f32();
 }

+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
 void clip_image_size_free(struct clip_image_size * load_image_size) {
     if (load_image_size == nullptr) {
         return;
@@ -2346,6 +2335,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
         n_patches = x_patch * y_patch;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        n_patches = 256;
     }

     return n_patches;
@@ -2893,3 +2884,11 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
     clip_image_encode(ctx, n_threads, &clip_img, vec);
     return true;
 }
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
+    return ctx->proj_type;
+}
examples/llava/clip.h

@@ -77,6 +77,9 @@ CLIP_API struct clip_image_size * clip_image_size_init();
 CLIP_API struct clip_image_u8 * clip_image_u8_init ();
 CLIP_API struct clip_image_f32 * clip_image_f32_init();

+// nx, ny are the output image dimensions
+CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
+
 CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
 CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
 CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
examples/llava/gemma3-cli.cpp

@@ -2,11 +2,11 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
-#include "clip.h"
-#include "stb_image.h"
 #include "llama.h"
 #include "ggml.h"
 #include "console.h"
+#include "chat.h"
+#include "mtmd.h"

 #include <vector>
 #include <limits.h>
@@ -57,13 +57,18 @@ static void sigint_handler(int signo) {
 #endif

 struct gemma3_context {
-    struct clip_ctx * ctx_clip = NULL;
+    mtmd_context_ptr ctx_vision;
     common_init_result llama_init;

     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
     llama_batch batch;
+    int n_batch;
+
+    // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
+    // so here we don't need to keep track of chat history
+    common_chat_templates_ptr tmpls;

     int n_threads = 1;
     llama_pos n_past = 0;
@@ -74,21 +79,24 @@ struct gemma3_context {
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch = llama_batch_init(params.n_batch, 0, 1);
-        init_clip_model(params);
+        n_batch = params.n_batch;
+        tmpls = common_chat_templates_init(model, params.chat_template);
+        init_vision_context(params);
     }

-    void init_clip_model(common_params & params) {
+    void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
-        ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
-        if (!ctx_clip) {
-            LOG_ERR("Failed to load CLIP model from %s\n", clip_path);
+        ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
+            /* use_gpu */   true,
+            /* timings */   true,
+            /* n_threads */ params.cpuparams.n_threads,
+            /* verbosity */ GGML_LOG_LEVEL_INFO,
+        }));
+        if (!ctx_vision.get()) {
+            LOG_ERR("Failed to load vision model from %s\n", clip_path);
             exit(1);
         }
     }

-    ~gemma3_context() {
-        clip_free(ctx_clip);
-    }
 };

 struct decode_embd_batch {
@@ -124,77 +132,6 @@ struct decode_embd_batch {
     }
 };

-static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
-    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
-    for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
-    }
-    if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
-    }
-    // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
-        LOG_ERR("Failed to decode text\n");
-        return 1;
-    }
-    return 0;
-}
-
-static int eval_image(gemma3_context & ctx, std::string & fname) {
-    std::vector<float> image_embd_v;
-    int n_embd = llama_model_n_embd(ctx.model);
-    int n_tokens = 256;
-    image_embd_v.resize(n_tokens * n_embd);
-
-    bool ok;
-    struct clip_image_u8 * img_u8 = clip_image_u8_init();
-    ok = clip_image_load_from_file(fname.c_str(), img_u8);
-    if (!ok) {
-        LOG_ERR("Unable to load image %s\n", fname.c_str());
-        clip_image_u8_free(img_u8);
-        return 2; // non-fatal error
-    }
-
-    clip_image_f32_batch batch_f32;
-    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
-    if (!ok) {
-        LOG_ERR("Unable to preprocess image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-
-    int64_t t0 = ggml_time_ms();
-    LOG("Encoding image %s\n", fname.c_str());
-    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
-    if (!ok) {
-        LOG_ERR("Unable to encode image\n");
-        clip_image_f32_batch_free(&batch_f32);
-        clip_image_u8_free(img_u8);
-        return 1;
-    }
-    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
-
-    clip_image_f32_batch_free(&batch_f32);
-    clip_image_u8_free(img_u8);
-
-    // decode image embeddings
-    int64_t t1 = ggml_time_ms();
-    eval_text(ctx, "<start_of_image>");
-    llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
-        LOG_ERR("failed to decode image\n");
-        return 1;
-    }
-    ctx.n_past += n_tokens;
-    llama_set_causal_attn(ctx.lctx, true);
-    eval_text(ctx, "<end_of_image>");
-    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
-    return 0;
-}
-
 static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
@@ -224,6 +161,45 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     return 0;
 }

+static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+    std::vector<mtmd_bitmap> bitmaps;
+
+    common_chat_templates_inputs tmpl_inputs;
+    tmpl_inputs.messages = {msg};
+    tmpl_inputs.add_generation_prompt = true;
+    tmpl_inputs.use_jinja = false; // jinja is buggy here
+    auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
+    LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
+
+    for (auto & fname : images_fname) {
+        mtmd_bitmap bitmap;
+        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            LOG_ERR("Unable to load image %s\n", fname.c_str());
+            return 2; // image not found
+        }
+        bitmaps.push_back(std::move(bitmap));
+    }
+
+    mtmd_input_text text;
+    text.text = formatted_chat.prompt;
+    text.add_special = add_bos;
+    text.parse_special = true;
+    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
+    if (chunks == nullptr) {
+        LOG_ERR("Unable to tokenize prompt\n");
+        return 1;
+    }
+
+    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
+        LOG_ERR("Unable to eval prompt\n");
+        return 1;
+    }
+
+    ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
+
+    return 0;
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();

@@ -265,21 +241,15 @@
 #endif
     }

-    if (eval_text(ctx, "<bos>")) {
-        return 1;
-    }
-
     if (is_single_turn) {
         g_is_generating = true;
-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
+        if (params.prompt.find("<__image__>") == std::string::npos) {
+            params.prompt += " <__image__>";
         }
-        for (auto & fname : params.image) {
-            if (eval_image(ctx, fname)) {
-                return 1;
-            }
-        }
-        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+        common_chat_msg msg;
+        msg.role = "user";
+        msg.content = params.prompt;
+        if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
         if (generate_response(ctx, smpl, n_predict)) {
@@ -293,9 +263,9 @@
         LOG("\n /quit or /exit exit the program");
         LOG("\n");

-        if (eval_text(ctx, "<start_of_turn>user\n")) {
-            return 1;
-        }
+        bool is_first_msg = true;
+        std::vector<std::string> images_fname;
+        std::string content;

         while (true) {
             g_is_generating = false;
@@ -320,24 +290,31 @@
             g_is_generating = true;
             if (line.find("/image") == 0) {
                 std::string image = line.substr(7);
-                int res = eval_image(ctx, image);
-                if (res == 2) {
-                    continue; // image not found
-                }
-                if (res) {
-                    return 1;
-                }
+                images_fname.push_back(string_strip(image));
+                content += "<__image__>";
                 continue;
+            } else {
+                content += line;
             }
-            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = content;
+            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (ret == 2) {
+                // non-fatal error
+                images_fname.clear();
+                content.clear();
+                continue;
+            }
+            if (ret) {
                 return 1;
             }
             if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
-            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
-                return 1;
-            }
+            images_fname.clear();
+            content.clear();
+            is_first_msg = false;
         }
     }

examples/llava/mtmd.cpp (new file, 341 lines)

@@ -0,0 +1,341 @@
#include "clip.h"
#include "clip-impl.h"
#include "mtmd.h"

#include "llama.h"

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <vector>

struct mtmd_context {
    struct clip_ctx * ctx_clip;
    const struct llama_model * text_model;
    std::vector<float> image_embd_v; // image embedding vector
    bool print_timings;
    int n_threads;
    std::string image_marker;

    // TODO @ngxson : add timings

    mtmd_context(const char * mmproj_fname,
            const llama_model * text_model,
            const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
        clip_context_params ctx_clip_params;
        ctx_clip_params.use_gpu   = ctx_params.use_gpu;
        ctx_clip_params.verbosity = ctx_params.verbosity;
        ctx_clip = clip_init(mmproj_fname, ctx_clip_params);
        if (!ctx_clip) {
            throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
        }
        this->text_model = text_model;
    }

    ~mtmd_context() {
        clip_free(ctx_clip);
    }
};

struct mtmd_image_tokens_data {
    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
};

struct mtmd_image_tokens {
    uint32_t nx; // number of tokens in x direction
    uint32_t ny; // number of tokens in y direction
    uint32_t n_tokens() const { return nx * ny; }
    clip_image_f32_batch_ptr batch_f32; // preprocessed image patches
};

mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
        const struct llama_model * text_model,
        const struct mtmd_context_params ctx_params) {
    try {
        return new mtmd_context(mmproj_fname, text_model, ctx_params);
    } catch (const std::exception & e) {
        LOG_ERR("%s: error: %s\n", __func__, e.what());
        return nullptr;
    }
}

void mtmd_free(mtmd_context * ctx) {
    if (ctx) {
        delete ctx;
    }
}

// copied from common_tokenize
static std::vector<llama_token> mtmd_tokenize_text_internal(
    const struct llama_vocab * vocab,
           const std::string & text,
                        bool   add_special,
                        bool   parse_special) {
    // upper limit for the number of tokens
    int n_tokens = text.length() + 2 * add_special;
    std::vector<llama_token> result(n_tokens);
    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
    if (n_tokens < 0) {
        result.resize(-n_tokens);
        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
        GGML_ASSERT(check == -n_tokens);
    } else {
        result.resize(n_tokens);
    }
    return result;
}

mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
                                  const mtmd_input_text & text,
                                  const std::vector<mtmd_bitmap> & bitmaps) {
    mtmd_input_chunks * output = new mtmd_input_chunks;
    auto vocab = llama_model_get_vocab(ctx->text_model);

    std::string prompt_modified(text.text);
    std::string marker_modified(ctx->image_marker);
    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
    // a bit hacky here, but works for now
    // for some models, we need to add prefix and suffix to the image embeddings
    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
        // <start_of_image> ... (image embeddings) ... <end_of_image>
        marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
    }

    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
    output->clear();
    output->reserve(parts.size());

    size_t i_img = 0;

    for (const auto & part : parts) {
        //printf("tokenizing part: %s\n", part.c_str());
        bool add_bos = &parts.front() == &part;
        auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special);
        if (tokens.empty()) {
            continue;
        }
        mtmd_input_chunk chunk{
            MTMD_INPUT_CHUNK_TYPE_TEXT,
            std::move(tokens),
            {},
        };
        output->emplace_back(std::move(chunk));

        if (&parts.back() != &part) {
            // add image token to middle of 2 parts

            if (i_img >= bitmaps.size()) {
                LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
                return nullptr;
            }

            // shim layer
            clip_image_u8_ptr img_u8(clip_image_u8_init());
            img_u8->nx = bitmaps[i_img].nx;
            img_u8->ny = bitmaps[i_img].ny;
            img_u8->buf.resize(bitmaps[i_img].data.size());
            std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);

            // preprocess image
            clip_image_f32_batch_ptr batch_f32(new clip_image_f32_batch);
            bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), batch_f32.get());
            if (!ok) {
                LOG_ERR("Unable to preprocess image\n");
                return nullptr;
            }

            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
            image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
            image_tokens->ny = 1; // TODO
            image_tokens->batch_f32 = std::move(batch_f32);

            mtmd_input_chunk chunk{
                MTMD_INPUT_CHUNK_TYPE_IMAGE,
                {},
                image_tokens,
            };
            output->emplace_back(std::move(chunk));
            i_img++;
        }
    }

    return output;
}

void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
    for (auto & chunk : *chunks) {
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
            delete chunk.tokens_image;
        }
    }
    delete chunks;
}

int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
    ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
    bool ok = clip_image_batch_encode(
        ctx->ctx_clip,
        ctx->n_threads,
        image_tokens->batch_f32.get(),
        ctx->image_embd_v.data());
    return ok ? 0 : 1;
}

float * mtmd_get_output_embd(mtmd_context * ctx) {
    return ctx->image_embd_v.data();
}

size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
    size_t n_tokens = 0;
    for (auto & chunk : *chunks) {
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            n_tokens += chunk.tokens_text.size();
        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            n_tokens += chunk.tokens_image->n_tokens();
        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }
    return n_tokens;
}

// helper struct to make working with embd batch easier
// note: this will be removed after llama_batch_ext refactoring
struct decode_embd_batch {
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id>   seq_id_0;
    std::vector<llama_seq_id *> seq_ids;
    std::vector<int8_t>         logits;
    llama_batch batch;
    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        pos     .resize(n_tokens);
        n_seq_id.resize(n_tokens);
        seq_ids .resize(n_tokens + 1);
        logits  .resize(n_tokens);
        seq_id_0.resize(1);
        seq_id_0[0] = seq_id;
        seq_ids [n_tokens] = nullptr;
        batch = {
            /*n_tokens       =*/ n_tokens,
            /*tokens         =*/ nullptr,
            /*embd           =*/ embd,
            /*pos            =*/ pos.data(),
            /*n_seq_id       =*/ n_seq_id.data(),
            /*seq_id         =*/ seq_ids.data(),
            /*logits         =*/ logits.data(),
        };
        for (int i = 0; i < n_tokens; i++) {
            batch.pos     [i] = pos_0 + i;
            batch.n_seq_id[i] = 1;
            batch.seq_id  [i] = seq_id_0.data();
            batch.logits  [i] = false;
        }
    }
};

int32_t mtmd_helper_eval(mtmd_context * ctx,
        llama_context * lctx,
        mtmd_input_chunks * chunks,
        llama_pos pos0,
        llama_seq_id seq_id,
        int32_t n_batch) {
    int32_t ret;
    llama_pos n_past = pos0;
    llama_batch text_batch = llama_batch_init(n_batch, 0, 1);

    for (auto & chunk : *chunks) {
        bool is_last = &chunk == &chunks->back();
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            // TODO @ngxson : may need to split into smaller batches
            text_batch.n_tokens = chunk.tokens_text.size();
            for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
                text_batch.token   [i] = chunk.tokens_text[i];
                text_batch.pos     [i] = n_past++;
                text_batch.n_seq_id[i] = 1;
                text_batch.seq_id  [i][0] = seq_id;
                text_batch.logits  [i] = false;
            }
            if (is_last) {
                // always get logits for last input chunk
                text_batch.logits[text_batch.n_tokens - 1] = true;
            }
            ret = llama_decode(lctx, text_batch);
            if (ret != 0) {
                LOG_ERR("failed to decode text\n");
                llama_batch_free(text_batch);
                return ret;
            }

        } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
            GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
            GGML_ASSERT(chunk.tokens_image != nullptr);
            int64_t t0 = ggml_time_ms();
            if (ctx->print_timings) {
                LOG_INF("encoding image...\n");
            }
            ret = mtmd_encode(ctx, chunk.tokens_image);
            if (ret != 0) {
                LOG_ERR("failed to encode image\n");
                llama_batch_free(text_batch);
                return ret;
            }
            if (ctx->print_timings) {
                LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
            }

            int32_t n_tokens = chunk.tokens_image->n_tokens();
            float * embd = mtmd_get_output_embd(ctx);
            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
            int64_t t1 = ggml_time_ms();
            ret = llama_decode(lctx, batch_img.batch);
            if (ret != 0) {
                LOG_ERR("failed to decode image\n");
                llama_batch_free(text_batch);
                return ret;
            }
            if (ctx->print_timings) {
                LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
            }

            n_past += n_tokens;

        } else {
            GGML_ASSERT(false && "chunk type not supported");
        }
    }

    llama_batch_free(text_batch);
    return 0;
}

int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) {
    clip_image_u8_ptr img_u8(clip_image_u8_init());
    bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
    if (!ok) {
        LOG_ERR("Unable to load image from buffer\n");
        return 1;
    }
    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
    output.data.resize(output.nx * output.ny * 3);
    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
    return 0;
}

int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) {
    clip_image_u8_ptr img_u8(clip_image_u8_init());
    bool ok = clip_image_load_from_file(fname, img_u8.get());
    if (!ok) {
        LOG_ERR("Unable to load image %s\n", fname);
        return 1;
    }
    unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny);
    output.data.resize(output.nx * output.ny * 3);
    std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
    return 0;
}
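As a rough illustration of what mtmd_tokenize() produces (a sketch, not part of the commit; it assumes an already-initialized mtmd_context * ctx and an mtmd_bitmap obtained from one of the helpers): a gemma3-style prompt containing one "<__image__>" marker comes back as a text chunk, an image chunk, and a trailing text chunk, and mtmd_helper_get_n_tokens() gives the total by which n_past should advance.

#include "mtmd.h"
#include <cstdio>

// assumes: ctx from mtmd_init_from_file(), bitmap from mtmd_helper_bitmap_init_from_file()
static void inspect_chunks(mtmd_context * ctx, const mtmd_bitmap & bitmap) {
    mtmd_input_text text;
    text.text          = "here is an image: <__image__>\ndescribe it in detail.";
    text.add_special   = true;   // first turn: allow BOS
    text.parse_special = true;

    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx, text, {bitmap}));
    if (!chunks) {
        std::printf("tokenization failed\n");
        return;
    }
    for (const auto & chunk : *chunks) {
        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
            std::printf("text chunk: %zu tokens\n", chunk.tokens_text.size());
        } else {
            std::printf("image chunk\n"); // token count is internal; use the helper below for totals
        }
    }
    std::printf("total tokens (advance n_past by this): %zu\n", mtmd_helper_get_n_tokens(chunks.get()));
}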
examples/llava/mtmd.h (new file, 146 lines)

@@ -0,0 +1,146 @@
#ifndef MTMD_H
#define MTMD_H

#include "ggml.h"
#include "llama.h"
#include "clip.h"

#include <vector>
#include <cinttypes>
#include <memory>

#ifdef LLAMA_SHARED
#    if defined(_WIN32) && !defined(__MINGW32__)
#        ifdef LLAMA_BUILD
#            define MTMD_API __declspec(dllexport)
#        else
#            define MTMD_API __declspec(dllimport)
#        endif
#    else
#        define MTMD_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define MTMD_API
#endif

#ifdef __cplusplus

enum mtmd_input_chunk_type {
    MTMD_INPUT_CHUNK_TYPE_TEXT,
    MTMD_INPUT_CHUNK_TYPE_IMAGE,
};

struct mtmd_context;
struct mtmd_image_tokens;

// represents raw image data, layout is RGBRGBRGB...
// length of data must be nx * ny * 3
struct mtmd_bitmap {
    uint32_t nx;
    uint32_t ny;
    std::vector<unsigned char> data;
};

struct mtmd_input_chunk {
    mtmd_input_chunk_type type;
    std::vector<llama_token> tokens_text;
    mtmd_image_tokens * tokens_image = nullptr;
};

using mtmd_input_chunks = std::vector<mtmd_input_chunk>;

struct mtmd_context_params {
    bool use_gpu = true;
    bool print_timings = true;
    int n_threads = 4;
    enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO;
    const char * image_marker = "<__image__>";
};

struct mtmd_input_text {
    std::string text;
    bool add_special;
    bool parse_special;
};

// initialize the mtmd context
// return nullptr on failure
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
                                            const llama_model * text_model,
                                            const mtmd_context_params ctx_params);

MTMD_API void mtmd_free(mtmd_context * ctx);

// tokenize an input text prompt and an image
// the prompt must have the input image marker (default: "<__image__>") in it
// the marker will be replaced with the image tokens
// for example:
//   "here is an image: <__image__>\ndescribe it in detail."
//   this will gives 3 chunks:
//   1. "here is an image: <start_of_image>"
//   2. (image tokens)
//   3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of image markers in the prompt
// this function is thread-safe (shared ctx)
MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
                                           const mtmd_input_text & text,
                                           const std::vector<mtmd_bitmap> & bitmaps);

// free image chunk data
MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);

// returns 0 on success
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
                             const mtmd_image_tokens * image_tokens);

// get output embeddings from the last encode pass
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);

//
// helper functions (can be implemented based on other functions)
//

// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);

// helper function that automatically:
// 1. run llama_decode() on text chunks
// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error
// otherwise, returns 0 on success
MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
                                  llama_context * lctx,
                                  mtmd_input_chunks * chunks,
                                  llama_pos pos0,
                                  llama_seq_id seq_id,
                                  int32_t n_batch);

// helper function to construct a mtmd_bitmap from a file
// returns 0 on success
// this function is thread-safe
MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output);

// helper function to construct a mtmd_bitmap from a buffer
// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
// returns 0 on success
// this function is thread-safe
MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output);

// convenient unique_ptr wrappers
struct mtmd_context_deleter {
    void operator()(mtmd_context * val) { mtmd_free(val); }
};
using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_input_chunks_deleter {
    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
};
using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

#else

static_assert(false && "C header is not yet supported by this library");

#endif

#endif
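For orientation, a minimal end-to-end sketch of how a client is expected to drive this API, modeled on the gemma3-cli.cpp changes above (the mmproj/image paths, the prompt, and the surrounding llama_model / llama_context setup are placeholders, and error handling is abbreviated):

#include "mtmd.h"
#include "llama.h"

// assumes an already-loaded llama_model * model and llama_context * lctx, plus a batch size
static int run_single_turn(llama_model * model, llama_context * lctx, int n_batch) {
    mtmd_context_params mparams;   // defaults: use_gpu, print_timings, 4 threads, "<__image__>" marker
    mtmd_context_ptr ctx(mtmd_init_from_file("mmproj.gguf", model, mparams)); // path is a placeholder
    if (!ctx) return 1;

    mtmd_bitmap bitmap;            // RGB bytes, nx * ny * 3
    if (mtmd_helper_bitmap_init_from_file("image.jpg", bitmap)) return 1;     // path is a placeholder

    mtmd_input_text text;
    text.text          = "describe this image: <__image__>";
    text.add_special   = true;     // add BOS for the first turn
    text.parse_special = true;

    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.get(), text, {bitmap}));
    if (!chunks) return 1;

    // decodes text chunks and encodes + decodes image chunks, in order
    if (mtmd_helper_eval(ctx.get(), lctx, chunks.get(), /*pos0*/ 0, /*seq_id*/ 0, n_batch)) return 1;

    llama_pos n_past = mtmd_helper_get_n_tokens(chunks.get());
    (void) n_past;                 // sampling and further llama_decode calls continue from here
    return 0;
}

Keeping tokenization, image encoding, and decoding behind the mtmd_* calls is what lets gemma3-cli.cpp drop its hand-rolled eval_text()/eval_image() helpers in the diff above.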