apply various in places

This commit is contained in:
Xuan Son Nguyen
2025-03-01 20:42:18 +01:00
parent 1d6ba97789
commit 46596caf6d
12 changed files with 142 additions and 133 deletions

View File

@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result {
}
};
struct server_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id;
bool logits;
};
std::vector<batch_token> tokens;
server_batch() = default;
server_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_logits_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
server_batch get_view(int32_t offset, int32_t n_tokens) {
server_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
}
return view;
}
};
struct server_slot {
int id;
int id_task = -1;
@ -1253,7 +1212,7 @@ struct server_slot {
// only used for completion/embedding/infill/rerank
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
server_batch batch_spec;
common_batch batch_spec;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
@ -1825,7 +1784,7 @@ struct server_context {
llama_context_params cparams_dft;
server_batch batch;
common_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
@ -1950,7 +1909,7 @@ struct server_context {
slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
@ -1986,7 +1945,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
}
metrics.init();
@ -2104,7 +2063,7 @@ struct server_context {
}
if (slot.ctx_dft) {
slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
}
slot.state = SLOT_STATE_STARTED;
@ -2412,7 +2371,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, server_batch & batch) {
void send_embedding(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
@ -2456,7 +2415,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_rerank(const server_slot & slot, server_batch & batch) {
void send_rerank(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
@ -3155,9 +3114,9 @@ struct server_context {
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
server_batch batch_view = batch.get_view(i, n_tokens);
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode_ext(ctx, batch_view.batch.get());
const int ret = llama_decode_ext(ctx, batch_view.get());
metrics.on_decoded(slots);
if (ret != 0) {
@ -3301,7 +3260,7 @@ struct server_context {
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
llama_decode_ext(ctx, slot.batch_spec.batch.get());
llama_decode_ext(ctx, slot.batch_spec.get());
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);