mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 20:25:20 +00:00
remove token_info API
This commit is contained in:
@ -1205,6 +1205,47 @@ struct server_task_result_apply_lora : server_task_result {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct server_batch {
|
||||||
|
llama_batch_ext_ptr batch;
|
||||||
|
struct batch_token {
|
||||||
|
llama_token token;
|
||||||
|
llama_seq_id seq_id;
|
||||||
|
bool logits;
|
||||||
|
};
|
||||||
|
std::vector<batch_token> tokens;
|
||||||
|
server_batch() = default;
|
||||||
|
server_batch(int32_t n_tokens, int32_t n_seq_max) {
|
||||||
|
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
|
||||||
|
tokens.reserve(n_tokens);
|
||||||
|
}
|
||||||
|
void clear() {
|
||||||
|
llama_batch_ext_clear(batch.get());
|
||||||
|
tokens.clear();
|
||||||
|
}
|
||||||
|
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
|
||||||
|
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
|
||||||
|
tokens.push_back({token, seq_id, logits});
|
||||||
|
}
|
||||||
|
void set_logits_last() {
|
||||||
|
if (!tokens.empty()) {
|
||||||
|
llama_batch_ext_set_logits_last(batch.get());
|
||||||
|
tokens.back().logits = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int32_t get_n_tokens() const {
|
||||||
|
return (int32_t)tokens.size();
|
||||||
|
}
|
||||||
|
server_batch get_view(int32_t offset, int32_t n_tokens) {
|
||||||
|
server_batch view;
|
||||||
|
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
|
||||||
|
view.tokens.reserve(n_tokens);
|
||||||
|
for (int32_t i = 0; i < n_tokens; i++) {
|
||||||
|
view.tokens.push_back(tokens[offset + i]);
|
||||||
|
}
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
int id;
|
int id;
|
||||||
int id_task = -1;
|
int id_task = -1;
|
||||||
@ -1212,7 +1253,7 @@ struct server_slot {
|
|||||||
// only used for completion/embedding/infill/rerank
|
// only used for completion/embedding/infill/rerank
|
||||||
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||||
|
|
||||||
llama_batch_ext_ptr batch_spec;
|
server_batch batch_spec;
|
||||||
|
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
llama_context * ctx_dft = nullptr;
|
llama_context * ctx_dft = nullptr;
|
||||||
@ -1784,7 +1825,7 @@ struct server_context {
|
|||||||
|
|
||||||
llama_context_params cparams_dft;
|
llama_context_params cparams_dft;
|
||||||
|
|
||||||
llama_batch_ext_ptr batch;
|
server_batch batch;
|
||||||
|
|
||||||
bool clean_kv_cache = true;
|
bool clean_kv_cache = true;
|
||||||
bool add_bos_token = true;
|
bool add_bos_token = true;
|
||||||
@ -1909,7 +1950,7 @@ struct server_context {
|
|||||||
slot.n_predict = params_base.n_predict;
|
slot.n_predict = params_base.n_predict;
|
||||||
|
|
||||||
if (model_dft) {
|
if (model_dft) {
|
||||||
slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
|
slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
|
||||||
|
|
||||||
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
||||||
if (slot.ctx_dft == nullptr) {
|
if (slot.ctx_dft == nullptr) {
|
||||||
@ -1945,7 +1986,7 @@ struct server_context {
|
|||||||
const int32_t n_batch = llama_n_batch(ctx);
|
const int32_t n_batch = llama_n_batch(ctx);
|
||||||
|
|
||||||
// only a single seq_id per token is needed
|
// only a single seq_id per token is needed
|
||||||
batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
|
batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.init();
|
metrics.init();
|
||||||
@ -2063,7 +2104,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slot.ctx_dft) {
|
if (slot.ctx_dft) {
|
||||||
slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
|
slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.state = SLOT_STATE_STARTED;
|
slot.state = SLOT_STATE_STARTED;
|
||||||
@ -2371,7 +2412,7 @@ struct server_context {
|
|||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
|
void send_embedding(const server_slot & slot, server_batch & batch) {
|
||||||
auto res = std::make_unique<server_task_result_embd>();
|
auto res = std::make_unique<server_task_result_embd>();
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
@ -2382,19 +2423,19 @@ struct server_context {
|
|||||||
|
|
||||||
std::vector<float> embd_res(n_embd, 0.0f);
|
std::vector<float> embd_res(n_embd, 0.0f);
|
||||||
|
|
||||||
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
|
for (int i = 0; i < batch.get_n_tokens(); ++i) {
|
||||||
llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
|
auto tok = batch.tokens[i];
|
||||||
if (!tok.logits || tok.seq_id[0] != slot.id) {
|
if (!tok.logits || tok.seq_id != slot.id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
|
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
|
||||||
|
|
||||||
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||||
continue;
|
continue;
|
||||||
@ -2415,25 +2456,25 @@ struct server_context {
|
|||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
|
void send_rerank(const server_slot & slot, server_batch & batch) {
|
||||||
auto res = std::make_unique<server_task_result_rerank>();
|
auto res = std::make_unique<server_task_result_rerank>();
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
res->n_tokens = slot.n_prompt_tokens;
|
res->n_tokens = slot.n_prompt_tokens;
|
||||||
|
|
||||||
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
|
for (int i = 0; i < batch.get_n_tokens(); ++i) {
|
||||||
llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
|
auto tok = batch.tokens[i];
|
||||||
if (!tok.logits || tok.seq_id[0] != slot.id) {
|
if (!tok.logits || tok.seq_id != slot.id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
|
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
|
||||||
|
|
||||||
res->score = -1e6;
|
res->score = -1e6;
|
||||||
continue;
|
continue;
|
||||||
@ -2824,7 +2865,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// start populating the batch for this iteration
|
// start populating the batch for this iteration
|
||||||
llama_batch_ext_clear(batch.get());
|
batch.clear();
|
||||||
|
|
||||||
// track if given slot can be batched with slots already in the batch
|
// track if given slot can be batched with slots already in the batch
|
||||||
server_slot * slot_batched = nullptr;
|
server_slot * slot_batched = nullptr;
|
||||||
@ -2846,10 +2887,9 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
|
slot.i_batch = batch.get_n_tokens();
|
||||||
|
|
||||||
std::array<llama_token, 1> seq_id = { slot.id };
|
batch.add_text(slot.sampled, slot.n_past, slot.id, true);
|
||||||
llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
|
|
||||||
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
|
|
||||||
@ -2866,7 +2906,7 @@ struct server_context {
|
|||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
|
if (params_base.cont_batching || batch.get_n_tokens() == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
// check if we can batch this slot with the previous one
|
// check if we can batch this slot with the previous one
|
||||||
if (slot.is_processing()) {
|
if (slot.is_processing()) {
|
||||||
@ -3032,7 +3072,7 @@ struct server_context {
|
|||||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
if (slot.is_non_causal()) {
|
if (slot.is_non_causal()) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
|
if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3052,12 +3092,11 @@ struct server_context {
|
|||||||
slot.cache_tokens.resize(slot.n_past);
|
slot.cache_tokens.resize(slot.n_past);
|
||||||
|
|
||||||
// add prompt tokens for processing in the current batch
|
// add prompt tokens for processing in the current batch
|
||||||
while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
|
while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
|
||||||
// without pooling, we want to output the embeddings for all the tokens in the batch
|
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||||
|
|
||||||
std::array<llama_token, 1> seq_id = { slot.id };
|
batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
|
||||||
llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
|
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
if (slot.params.cache_prompt) {
|
||||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||||
@ -3067,13 +3106,13 @@ struct server_context {
|
|||||||
slot.n_past++;
|
slot.n_past++;
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||||
|
|
||||||
// entire prompt has been processed
|
// entire prompt has been processed
|
||||||
if (slot.n_past == slot.n_prompt_tokens) {
|
if (slot.n_past == slot.n_prompt_tokens) {
|
||||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||||
|
|
||||||
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
|
GGML_ASSERT(batch.get_n_tokens() > 0);
|
||||||
|
|
||||||
common_sampler_reset(slot.smpl);
|
common_sampler_reset(slot.smpl);
|
||||||
|
|
||||||
@ -3083,27 +3122,27 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
llama_batch_ext_set_logits_last(batch.get());
|
batch.set_logits_last();
|
||||||
|
|
||||||
slot.n_decoded = 0;
|
slot.n_decoded = 0;
|
||||||
slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
|
slot.i_batch = batch.get_n_tokens() - 1;
|
||||||
|
|
||||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
|
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
|
if (batch.get_n_tokens() >= n_batch) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
|
if (batch.get_n_tokens() == 0) {
|
||||||
SRV_WRN("%s", "no tokens to decode\n");
|
SRV_WRN("%s", "no tokens to decode\n");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
|
SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
|
||||||
|
|
||||||
if (slot_batched) {
|
if (slot_batched) {
|
||||||
// make sure we're in the right embedding mode
|
// make sure we're in the right embedding mode
|
||||||
@ -3113,12 +3152,12 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// process the created batch of tokens
|
// process the created batch of tokens
|
||||||
for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
|
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
|
||||||
const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
|
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
|
||||||
|
|
||||||
llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
|
server_batch batch_view = batch.get_view(i, n_tokens);
|
||||||
|
|
||||||
const int ret = llama_decode_ext(ctx, batch_view.get());
|
const int ret = llama_decode_ext(ctx, batch_view.batch.get());
|
||||||
metrics.on_decoded(slots);
|
metrics.on_decoded(slots);
|
||||||
|
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
@ -3253,17 +3292,16 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// construct the speculation batch
|
// construct the speculation batch
|
||||||
llama_batch_ext_clear(slot.batch_spec.get());
|
slot.batch_spec.clear();
|
||||||
std::array<llama_token, 1> seq_id = { slot.id };
|
slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
|
||||||
llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < draft.size(); ++i) {
|
for (size_t i = 0; i < draft.size(); ++i) {
|
||||||
llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
|
slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
|
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
|
||||||
|
|
||||||
llama_decode_ext(ctx, slot.batch_spec.get());
|
llama_decode_ext(ctx, slot.batch_spec.batch.get());
|
||||||
|
|
||||||
// the accepted tokens from the speculation
|
// the accepted tokens from the speculation
|
||||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||||
|
@ -263,14 +263,6 @@ extern "C" {
|
|||||||
// It can contain text tokens and embeddings for one or many sequences
|
// It can contain text tokens and embeddings for one or many sequences
|
||||||
struct llama_batch_ext;
|
struct llama_batch_ext;
|
||||||
|
|
||||||
struct llama_batch_ext_token_info {
|
|
||||||
llama_token token;
|
|
||||||
llama_pos pos;
|
|
||||||
int32_t n_seq_id;
|
|
||||||
llama_seq_id * seq_id;
|
|
||||||
int8_t logits;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
enum llama_model_kv_override_type {
|
||||||
LLAMA_KV_OVERRIDE_TYPE_INT,
|
LLAMA_KV_OVERRIDE_TYPE_INT,
|
||||||
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
||||||
@ -896,10 +888,6 @@ extern "C" {
|
|||||||
// Get the number of tokens in the batch
|
// Get the number of tokens in the batch
|
||||||
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
||||||
|
|
||||||
LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
|
|
||||||
struct llama_batch_ext * batch,
|
|
||||||
int32_t i);
|
|
||||||
|
|
||||||
// Add text tokens to the batch
|
// Add text tokens to the batch
|
||||||
// Return values:
|
// Return values:
|
||||||
// 0 : success
|
// 0 : success
|
||||||
|
@ -480,19 +480,6 @@ struct llama_batch_ext * llama_batch_ext_get_view(
|
|||||||
return batch_view;
|
return batch_view;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
|
|
||||||
struct llama_batch_ext * batch,
|
|
||||||
int32_t i) {
|
|
||||||
GGML_ASSERT(i >= 0 && i < batch->n_tokens);
|
|
||||||
return llama_batch_ext_token_info{
|
|
||||||
/*token =*/ batch->token [i],
|
|
||||||
/*pos =*/ batch->pos [i],
|
|
||||||
/*n_seq_id =*/ batch->n_seq_id[i],
|
|
||||||
/*seq_id =*/ batch->seq_id [i],
|
|
||||||
/*logits =*/ batch->logits [i],
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_batch_ext_free(struct llama_batch_ext * batch) {
|
void llama_batch_ext_free(struct llama_batch_ext * batch) {
|
||||||
// do not free the members if it's a view
|
// do not free the members if it's a view
|
||||||
if (!batch->is_view) {
|
if (!batch->is_view) {
|
||||||
|
Reference in New Issue
Block a user