Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-28 04:15:21 +00:00)
apply various in places
@@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
-struct server_batch {
-    llama_batch_ext_ptr batch;
-    struct batch_token {
-        llama_token token;
-        llama_seq_id seq_id;
-        bool logits;
-    };
-    std::vector<batch_token> tokens;
-    server_batch() = default;
-    server_batch(int32_t n_tokens, int32_t n_seq_max) {
-        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
-        tokens.reserve(n_tokens);
-    }
-    void clear() {
-        llama_batch_ext_clear(batch.get());
-        tokens.clear();
-    }
-    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
-    }
-    void set_logits_last() {
-        if (!tokens.empty()) {
-            llama_batch_ext_set_logits_last(batch.get());
-            tokens.back().logits = true;
-        }
-    }
-    int32_t get_n_tokens() const {
-        return (int32_t)tokens.size();
-    }
-    server_batch get_view(int32_t offset, int32_t n_tokens) {
-        server_batch view;
-        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
-        view.tokens.reserve(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
-            view.tokens.push_back(tokens[offset + i]);
-        }
-        return view;
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1253,7 +1212,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    server_batch batch_spec;
+    common_batch batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1825,7 +1784,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    server_batch batch;
+    common_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1950,7 +1909,7 @@ struct server_context {
         slot.n_predict = params_base.n_predict;
 
         if (model_dft) {
-            slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
+            slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
 
             slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
             if (slot.ctx_dft == nullptr) {
@@ -1986,7 +1945,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
+            batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
         }
 
         metrics.init();
@@ -2104,7 +2063,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
+            slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2412,7 +2371,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, server_batch & batch) {
+    void send_embedding(const server_slot & slot, common_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2456,7 +2415,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, server_batch & batch) {
+    void send_rerank(const server_slot & slot, common_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -3155,9 +3114,9 @@ struct server_context {
         for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
 
-            server_batch batch_view = batch.get_view(i, n_tokens);
+            common_batch batch_view = batch.get_view(i, n_tokens);
 
-            const int ret = llama_decode_ext(ctx, batch_view.batch.get());
+            const int ret = llama_decode_ext(ctx, batch_view.get());
 
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3301,7 +3260,7 @@ struct server_context {
 
             SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
 
-            llama_decode_ext(ctx, slot.batch_spec.batch.get());
+            llama_decode_ext(ctx, slot.batch_spec.get());
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
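
For reference, the renamed common_batch keeps the interface of the server_batch wrapper removed in the first hunk. Below is a minimal sketch, not part of the commit, of the chunked-decode pattern used in the last two hunks; it assumes common_batch exposes the same add_text / set_logits_last / get_n_tokens / get_view members plus a get() accessor for the underlying llama_batch_ext, and that ctx, n_batch, and the prompt tokens are already set up as elsewhere in server.cpp. The helper name decode_prompt is hypothetical.

// Sketch only: assumes llama.h and the common_batch wrapper from this branch are included.
static int decode_prompt(llama_context * ctx, common_batch & batch,
                         const std::vector<llama_token> & prompt_tokens,
                         llama_seq_id seq_id, int32_t n_batch) {
    batch.clear();

    // queue the whole prompt for one sequence; logits are only needed for the last token
    for (size_t i = 0; i < prompt_tokens.size(); i++) {
        batch.add_text(prompt_tokens[i], (llama_pos) i, seq_id, /* logits */ false);
    }
    batch.set_logits_last();

    // decode in chunks of at most n_batch tokens, mirroring the loop in update_slots above
    for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);

        common_batch batch_view = batch.get_view(i, n_tokens);

        const int ret = llama_decode_ext(ctx, batch_view.get());
        if (ret != 0) {
            return ret; // let the caller decide how to recover
        }
    }

    return 0;
}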