apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@@ -2,6 +2,7 @@
#include "llava.h"
#include "llama.h"
#include "llama-cpp.h"
#include <algorithm>
#include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true;
}
struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}