mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 20:45:04 +00:00
apply to the rest
This commit is contained in:
@ -5,6 +5,7 @@
|
||||
#include "clip.h"
|
||||
#include "stb_image.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "ggml.h"
|
||||
#include "console.h"
|
||||
|
||||
@ -63,7 +64,7 @@ struct gemma3_context {
|
||||
llama_model * model;
|
||||
llama_context * lctx;
|
||||
const llama_vocab * vocab;
|
||||
llama_batch batch;
|
||||
llama_batch_ext_ptr batch;
|
||||
|
||||
int n_threads = 1;
|
||||
llama_pos n_past = 0;
|
||||
@ -73,7 +74,7 @@ struct gemma3_context {
|
||||
lctx = llama_init.context.get();
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
batch.reset(llama_batch_ext_init(params.n_batch, 1));
|
||||
init_clip_model(params);
|
||||
}
|
||||
|
||||
@ -87,50 +88,18 @@ struct gemma3_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct decode_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
|
||||
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
|
||||
common_batch_clear(ctx.batch);
|
||||
llama_batch_ext_clear(ctx.batch.get());
|
||||
for (llama_token & t : tokens) {
|
||||
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
|
||||
}
|
||||
if (logits_last) {
|
||||
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_output_last(ctx.batch.get());
|
||||
}
|
||||
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
|
||||
LOG_ERR("Failed to decode text\n");
|
||||
return 1;
|
||||
}
|
||||
@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
|
||||
int64_t t1 = ggml_time_ms();
|
||||
eval_text(ctx, "<start_of_image>");
|
||||
llama_set_causal_attn(ctx.lctx, false);
|
||||
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
|
||||
if (llama_decode(ctx.lctx, batch_img.batch)) {
|
||||
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
|
||||
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
return 1;
|
||||
}
|
||||
@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
|
||||
fflush(stdout);
|
||||
|
||||
// eval the token
|
||||
common_batch_clear(ctx.batch);
|
||||
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
llama_batch_ext_clear(ctx.batch.get());
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
|
||||
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
|
||||
LOG_ERR("failed to decode token\n");
|
||||
return 1;
|
||||
}
|
||||
|
Reference in New Issue
Block a user