apply to the rest

Author: Xuan Son Nguyen
Date:   2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions


@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
         for (int32_t i = 0; i < n_tokens_system; ++i) {
-            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
         }
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
             common_kv_cache_dump_view_seqs(kvc_view, 40);
         }
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
         // decode any currently ongoing sequences
         for (auto & client : clients) {
@@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
                 continue;
             }
-            client.i_batch = batch.n_tokens;
+            client.i_batch = llama_batch_ext_get_n_tokens(batch);
-            common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
+            llama_seq_id seq_id = client.id + 1;
+            llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
             client.n_decoded += 1;
         }
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
         }
         // insert new sequences for decoding
-        if (cont_batching || batch.n_tokens == 0) {
+        if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
                     client.seq_id = g_seq_id;
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
                     tokens_prompt = common_tokenize(ctx, client.prompt, false);
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                        common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
+                        llama_seq_id seq_id = client.id + 1;
+                        llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
                     }
                     // extract the logits only for the last token
-                    if (batch.n_tokens > 0) {
-                        batch.logits[batch.n_tokens - 1] = true;
+                    if (llama_batch_ext_get_n_tokens(batch) > 0) {
+                        llama_batch_ext_set_output_last(batch);
                     }
                     client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
-                    client.i_batch   = batch.n_tokens - 1;
+                    client.i_batch   = llama_batch_ext_get_n_tokens(batch) - 1;
                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
             }
         }
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
             break;
         }
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
+        for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
             // experiment: process in powers of 2
             //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
             //    n_batch /= 2;
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
            //    continue;
            //}
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-            const int ret = llama_decode(ctx, batch_view);
+            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
+            const int ret = llama_decode_ext(ctx, batch_view);
+            llama_batch_ext_free(batch_view);
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
     // TODO: print sampling/grammar timings for all clients
     llama_perf_context_print(ctx);
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
     llama_backend_free();
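
Taken together, the hunks above swap the old llama_batch struct workflow for the opaque llama_batch_ext handle. Below is a minimal sketch of that pattern for a single sequence, assembled only from the calls visible in this diff (llama_batch_ext_init, llama_batch_ext_clear, llama_batch_ext_add_text, llama_batch_ext_get_n_tokens, llama_batch_ext_set_output_last, llama_batch_ext_get_view, llama_decode_ext, llama_batch_ext_free). The function name, the ctx/tokens/n_ctx/n_batch parameters, and the includes are placeholders for illustration and are not part of this change.

// Minimal sketch (not from the commit): decode one sequence via the batch_ext API.
// Assumes the llama_batch_ext declarations from this branch's llama.h; "ctx" and
// "tokens" are caller-supplied placeholders.
#include "llama.h"

#include <algorithm>
#include <cstdint>
#include <vector>

static int decode_one_sequence(llama_context * ctx, const std::vector<llama_token> & tokens,
                               int32_t n_ctx, int32_t n_batch) {
    // capacity of n_ctx tokens, at most 1 sequence id per token, as in the diff above
    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);

    llama_batch_ext_clear(batch);
    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_seq_id seq_id = 0; // sequence ids are now passed as pointer + count
        llama_batch_ext_add_text(batch, tokens[i], (int32_t) i, &seq_id, 1, false);
    }

    // request logits only for the last token, replacing batch.logits[batch.n_tokens - 1] = true
    if (llama_batch_ext_get_n_tokens(batch) > 0) {
        llama_batch_ext_set_output_last(batch);
    }

    int ret = 0;

    // process in chunks of n_batch; views replace the hand-built llama_batch batch_view struct
    const int32_t n_tokens_total = llama_batch_ext_get_n_tokens(batch);
    for (int32_t i = 0; i < n_tokens_total && ret == 0; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

        llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
        ret = llama_decode_ext(ctx, batch_view);
        llama_batch_ext_free(batch_view); // views are freed independently of the parent batch

        // as in the example above, a non-zero return typically means the KV cache is full
    }

    llama_batch_ext_free(batch);
    return ret;
}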