apply to the rest

Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions


@@ -5,6 +5,7 @@
 #include "clip.h"
 #include "stb_image.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "ggml.h"
 #include "console.h"
@@ -63,7 +64,7 @@ struct gemma3_context {
     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
     int n_threads = 1;
     llama_pos n_past = 0;
@@ -73,7 +74,7 @@ struct gemma3_context {
        lctx = llama_init.context.get();
        vocab = llama_model_get_vocab(model);
        n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch.reset(llama_batch_ext_init(params.n_batch, 1));
        init_clip_model(params);
    }
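The newly included llama-cpp.h supplies the smart-pointer wrapper used above. The diff only shows the .reset() and .get() calls, so the following is a minimal sketch of how such a wrapper is typically defined in that header; the deleter name llama_batch_ext_free is an assumption, not something this commit confirms.

#include <memory>
#include "llama.h"

// Assumed RAII wrapper: a std::unique_ptr whose deleter frees the batch.
// llama_batch_ext_free() is hypothetical here; the real name may differ.
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};
typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;

With this wrapper the batch no longer needs an explicit free when gemma3_context is torn down: the unique_ptr releases it automatically.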
@@ -87,50 +88,18 @@ struct gemma3_context {
     }
 };

-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
 static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
+    llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false);
     }
     if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(ctx.batch.get());
     }
     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("Failed to decode text\n");
         return 1;
     }
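Taken together, the eval_text changes follow one template: clear the persistent batch, append tokens one by one, mark the last token for output, decode. A self-contained sketch of that template, using only the calls visible in this hunk; signatures are inferred from the diff, not from a published header.

#include <vector>
#include "llama.h"

// Hypothetical helper showing the batch-ext decode pattern from eval_text.
// Returns 0 on success, 1 on failure, mirroring the code above.
static int decode_tokens(llama_context * lctx, llama_batch_ext * batch,
                         const std::vector<llama_token> & tokens, llama_pos & n_past) {
    llama_batch_ext_clear(batch);           // reuse the long-lived batch
    llama_seq_id seq_id = 0;                // single-sequence decode
    for (llama_token t : tokens) {
        // each token gets the next position; per-token output stays off
        llama_batch_ext_add_text(batch, t, n_past++, &seq_id, 1, false);
    }
    llama_batch_ext_set_output_last(batch); // logits only for the final token
    return llama_decode_ext(lctx, batch) ? 1 : 0;
}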
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
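The one-line replacement above is the payoff of the new API: the thirty-odd lines of decode_embd_batch bookkeeping collapse into llama_batch_ext_init_from_embd. A hedged sketch of the surrounding pattern; the argument order (embd, n_tokens, pos0, seq_id) is inferred from this hunk, and re-enabling causal attention afterwards is an assumption about the unshown remainder of eval_image.

// Hypothetical wrapper around the image-embedding decode shown above.
static int decode_image_embd(llama_context * lctx, float * embd,
                             int32_t n_tokens, llama_pos pos0) {
    llama_set_causal_attn(lctx, false); // image tokens attend non-causally
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_tokens, pos0, 0));
    int ret = llama_decode_ext(lctx, batch.get());
    llama_set_causal_attn(lctx, true);  // assumed: restore causal attention for text
    return ret;
}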
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
         fflush(stdout);

         // eval the token
-        common_batch_clear(ctx.batch);
-        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
-        if (llama_decode(ctx.lctx, ctx.batch)) {
+        llama_batch_ext_clear(ctx.batch.get());
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+        if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
             LOG_ERR("failed to decode token\n");
             return 1;
         }
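The generation loop uses the same pattern with a single token and output enabled, so the next iteration has logits to sample from. A sketch of one full step under the new API; common_sampler_sample and common_sampler_accept are existing helpers from llama.cpp's common library, and the rest mirrors this hunk.

// Hypothetical single generation step: sample, accept, feed back, decode.
static int generate_step(gemma3_context & ctx, common_sampler * smpl) {
    llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
    common_sampler_accept(smpl, token_id, true);

    llama_batch_ext_clear(ctx.batch.get());
    llama_seq_id seq_id = 0;
    // output = true: request logits for this token so sampling can continue
    llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
    return llama_decode_ext(ctx.lctx, ctx.batch.get()) ? 1 : 0;
}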