apply to the rest

commit 47086fa82d
parent 4aabf4e8f4
Author: Xuan Son Nguyen
Date:   2025-03-13 22:36:27 +01:00

18 changed files with 242 additions and 323 deletions


@@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
 }
 
 // Check if we have enough space in the context to evaluate this batch
-static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
+static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
     const int n_ctx = llama_n_ctx(ctx.get());
     const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
-    if (n_ctx_used + batch.n_tokens > n_ctx) {
+    if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
         printf(LOG_COL_DEFAULT "\n");
         printe("context size exceeded\n");
         return 1;
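
The batch parameter is now the opaque llama_batch_ext_ptr, so the helper can no longer read a public n_tokens field and has to go through the accessor instead. For reference, the post-change helper as reconstructed from the hunk above; llama_context_ptr, printe, and LOG_COL_DEFAULT are existing helpers in this file, and the trailing return 0 is assumed from the function's int return type:

    // Check if we have enough space in the context to evaluate this batch.
    // The ext batch is opaque, so the token count comes from the accessor
    // rather than a public n_tokens field.
    static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
        const int n_ctx      = llama_n_ctx(ctx.get());
        const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
        if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
            printf(LOG_COL_DEFAULT "\n");
            printe("context size exceeded\n");
            return 1;
        }
        return 0; // assumed: the hunk cuts off before the success path
    }
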
@@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }
 
     // prepare a batch for the prompt
-    llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
-        if (llama_decode(llama_data.context.get(), batch)) {
+        if (llama_decode_ext(llama_data.context.get(), batch.get())) {
             printe("failed to decode\n");
             return 1;
         }
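
On the construction side, the prompt batch is no longer a plain llama_batch value from llama_batch_get_one(); it is created with llama_batch_ext_init_from_text() and owned by llama_batch_ext_ptr, which by its .get()/.reset() interface is a smart-pointer wrapper over the ext batch. A minimal sketch of the pattern; the two trailing (0, 0) arguments are presumably a starting position and a sequence id, which is an assumption read off the literals in the hunk:

    // build an ext batch covering the whole tokenized prompt; the wrapper
    // frees the underlying batch when it goes out of scope
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
    // decoding takes the raw ext handle, not the wrapper itself
    if (llama_decode_ext(llama_data.context.get(), batch.get())) {
        printe("failed to decode\n");
        return 1;
    }
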
@@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);
 
         // prepare the next batch with the sampled token
-        batch = llama_batch_get_one(&new_token_id, 1);
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
     }
 
     printf(LOG_COL_DEFAULT);
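
After sampling, the old code rebuilt the batch by value with llama_batch_get_one(); the new code re-primes the existing wrapper with reset(), which releases the previous ext batch and takes ownership of a fresh single-token one. A condensed sketch of the resulting loop under the same assumptions as above, with the sampling and end-of-generation logic elided since they are unchanged elsewhere in the file:

    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
    llama_token new_token_id;
    while (true) {
        check_context_size(llama_data.context, batch);
        if (llama_decode_ext(llama_data.context.get(), batch.get())) {
            printe("failed to decode\n");
            return 1;
        }
        // ... sample new_token_id and break on end of generation (elided) ...
        // reset() frees the previous batch and adopts the new single-token one
        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
    }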