mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 04:35:05 +00:00
rework, targeting llama-server
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -98,6 +98,7 @@ examples/server/*.css.hpp
|
|||||||
examples/server/*.html.hpp
|
examples/server/*.html.hpp
|
||||||
examples/server/*.js.hpp
|
examples/server/*.js.hpp
|
||||||
examples/server/*.mjs.hpp
|
examples/server/*.mjs.hpp
|
||||||
|
examples/server/*.gz.hpp
|
||||||
!build_64.sh
|
!build_64.sh
|
||||||
!examples/*.bat
|
!examples/*.bat
|
||||||
!examples/*/*.kts
|
!examples/*/*.kts
|
||||||
|
@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
|
|||||||
return buf.str();
|
return buf.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
|
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
|
||||||
std::stringstream buf;
|
std::stringstream buf;
|
||||||
|
|
||||||
@ -614,6 +615,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
|||||||
|
|
||||||
return buf.str();
|
return buf.str();
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
void string_process_escapes(std::string & input) {
|
void string_process_escapes(std::string & input) {
|
||||||
std::size_t input_len = input.length();
|
std::size_t input_len = input.length();
|
||||||
@ -1608,27 +1610,20 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
|
|||||||
// Batch utils
|
// Batch utils
|
||||||
//
|
//
|
||||||
|
|
||||||
void common_batch_clear(struct llama_batch & batch) {
|
void common_batch_clear(struct llama_batch * batch) {
|
||||||
batch.n_tokens = 0;
|
llama_batch_clear(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_batch_add(
|
void common_batch_add(
|
||||||
struct llama_batch & batch,
|
struct llama_batch * batch,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
const std::vector<llama_seq_id> & seq_ids,
|
||||||
bool logits) {
|
bool logits) {
|
||||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
|
||||||
|
if (res == -1) {
|
||||||
batch.token [batch.n_tokens] = id;
|
LOG_ERR("%s: llama_batch size exceeded\n", __func__);
|
||||||
batch.pos [batch.n_tokens] = pos;
|
|
||||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
|
||||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
|
||||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
|
||||||
}
|
}
|
||||||
batch.logits [batch.n_tokens] = logits;
|
|
||||||
|
|
||||||
batch.n_tokens++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
|
|||||||
// Batch utils
|
// Batch utils
|
||||||
//
|
//
|
||||||
|
|
||||||
void common_batch_clear(struct llama_batch & batch);
|
void common_batch_clear(struct llama_batch * batch);
|
||||||
|
|
||||||
void common_batch_add(
|
void common_batch_add(
|
||||||
struct llama_batch & batch,
|
struct llama_batch * batch,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
const std::vector<llama_seq_id> & seq_ids,
|
||||||
|
@ -13,7 +13,7 @@ struct common_speculative {
|
|||||||
struct llama_context * ctx;
|
struct llama_context * ctx;
|
||||||
struct common_sampler * smpl;
|
struct common_sampler * smpl;
|
||||||
|
|
||||||
llama_batch batch;
|
llama_batch * batch;
|
||||||
llama_tokens prompt;
|
llama_tokens prompt;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
|
|||||||
auto * result = new common_speculative {
|
auto * result = new common_speculative {
|
||||||
/* .ctx = */ ctx_dft,
|
/* .ctx = */ ctx_dft,
|
||||||
/* .smpl = */ nullptr,
|
/* .smpl = */ nullptr,
|
||||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
|
||||||
/* .prompt = */ {},
|
/* .prompt = */ {},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we should rarely end-up here during normal decoding
|
// we should rarely end-up here during normal decoding
|
||||||
if (batch.n_tokens > 0) {
|
if (llama_batch_get_n_tokens(batch) > 0) {
|
||||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||||
|
|
||||||
llama_decode(ctx, batch);
|
llama_decode(ctx, batch);
|
||||||
|
@ -1215,7 +1215,7 @@ struct server_slot {
|
|||||||
// only used for completion/embedding/infill/rerank
|
// only used for completion/embedding/infill/rerank
|
||||||
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||||
|
|
||||||
llama_batch batch_spec = {};
|
llama_batch_ptr batch_spec;
|
||||||
|
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
llama_context * ctx_dft = nullptr;
|
llama_context * ctx_dft = nullptr;
|
||||||
@ -1787,7 +1787,7 @@ struct server_context {
|
|||||||
|
|
||||||
llama_context_params cparams_dft;
|
llama_context_params cparams_dft;
|
||||||
|
|
||||||
llama_batch batch = {};
|
llama_batch_ptr batch;
|
||||||
|
|
||||||
bool clean_kv_cache = true;
|
bool clean_kv_cache = true;
|
||||||
bool add_bos_token = true;
|
bool add_bos_token = true;
|
||||||
@ -1820,11 +1820,7 @@ struct server_context {
|
|||||||
|
|
||||||
common_speculative_free(slot.spec);
|
common_speculative_free(slot.spec);
|
||||||
slot.spec = nullptr;
|
slot.spec = nullptr;
|
||||||
|
|
||||||
llama_batch_free(slot.batch_spec);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool load_model(const common_params & params) {
|
bool load_model(const common_params & params) {
|
||||||
@ -1944,7 +1940,7 @@ struct server_context {
|
|||||||
slot.n_predict = params_base.n_predict;
|
slot.n_predict = params_base.n_predict;
|
||||||
|
|
||||||
if (model_dft) {
|
if (model_dft) {
|
||||||
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
|
slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
|
||||||
|
|
||||||
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
|
||||||
if (slot.ctx_dft == nullptr) {
|
if (slot.ctx_dft == nullptr) {
|
||||||
@ -1969,7 +1965,7 @@ struct server_context {
|
|||||||
|
|
||||||
slot.reset();
|
slot.reset();
|
||||||
|
|
||||||
slots.push_back(slot);
|
slots.push_back(std::move(slot));
|
||||||
}
|
}
|
||||||
|
|
||||||
default_generation_settings_for_props = slots[0].to_json();
|
default_generation_settings_for_props = slots[0].to_json();
|
||||||
@ -1980,7 +1976,7 @@ struct server_context {
|
|||||||
const int32_t n_batch = llama_n_batch(ctx);
|
const int32_t n_batch = llama_n_batch(ctx);
|
||||||
|
|
||||||
// only a single seq_id per token is needed
|
// only a single seq_id per token is needed
|
||||||
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.init();
|
metrics.init();
|
||||||
@ -2098,9 +2094,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slot.ctx_dft) {
|
if (slot.ctx_dft) {
|
||||||
llama_batch_free(slot.batch_spec);
|
slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
|
||||||
|
|
||||||
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.state = SLOT_STATE_STARTED;
|
slot.state = SLOT_STATE_STARTED;
|
||||||
@ -2408,7 +2402,7 @@ struct server_context {
|
|||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_embedding(const server_slot & slot, const llama_batch & batch) {
|
void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
|
||||||
auto res = std::make_unique<server_task_result_embd>();
|
auto res = std::make_unique<server_task_result_embd>();
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
@ -2419,18 +2413,19 @@ struct server_context {
|
|||||||
|
|
||||||
std::vector<float> embd_res(n_embd, 0.0f);
|
std::vector<float> embd_res(n_embd, 0.0f);
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
|
||||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
|
||||||
|
if (!tok.logits || tok.seq_id[0] != slot.id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
|
||||||
|
|
||||||
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
|
||||||
continue;
|
continue;
|
||||||
@ -2451,24 +2446,25 @@ struct server_context {
|
|||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
|
||||||
auto res = std::make_unique<server_task_result_rerank>();
|
auto res = std::make_unique<server_task_result_rerank>();
|
||||||
res->id = slot.id_task;
|
res->id = slot.id_task;
|
||||||
res->index = slot.index;
|
res->index = slot.index;
|
||||||
res->n_tokens = slot.n_prompt_tokens;
|
res->n_tokens = slot.n_prompt_tokens;
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
|
||||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
|
||||||
|
if (!tok.logits || tok.seq_id[0] != slot.id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (embd == NULL) {
|
if (embd == NULL) {
|
||||||
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
|
||||||
|
|
||||||
res->score = -1e6;
|
res->score = -1e6;
|
||||||
continue;
|
continue;
|
||||||
@ -2859,7 +2855,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// start populating the batch for this iteration
|
// start populating the batch for this iteration
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch.get());
|
||||||
|
|
||||||
// track if given slot can be batched with slots already in the batch
|
// track if given slot can be batched with slots already in the batch
|
||||||
server_slot * slot_batched = nullptr;
|
server_slot * slot_batched = nullptr;
|
||||||
@ -2881,9 +2877,9 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.i_batch = batch.n_tokens;
|
slot.i_batch = llama_batch_get_n_tokens(batch.get());
|
||||||
|
|
||||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
|
||||||
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
|
|
||||||
@ -2900,7 +2896,7 @@ struct server_context {
|
|||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
// check if we can batch this slot with the previous one
|
// check if we can batch this slot with the previous one
|
||||||
if (slot.is_processing()) {
|
if (slot.is_processing()) {
|
||||||
@ -3066,7 +3062,7 @@ struct server_context {
|
|||||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
if (slot.is_non_causal()) {
|
if (slot.is_non_causal()) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3086,11 +3082,11 @@ struct server_context {
|
|||||||
slot.cache_tokens.resize(slot.n_past);
|
slot.cache_tokens.resize(slot.n_past);
|
||||||
|
|
||||||
// add prompt tokens for processing in the current batch
|
// add prompt tokens for processing in the current batch
|
||||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
|
||||||
// without pooling, we want to output the embeddings for all the tokens in the batch
|
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||||
|
|
||||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
if (slot.params.cache_prompt) {
|
||||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||||
@ -3100,13 +3096,13 @@ struct server_context {
|
|||||||
slot.n_past++;
|
slot.n_past++;
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||||
|
|
||||||
// entire prompt has been processed
|
// entire prompt has been processed
|
||||||
if (slot.n_past == slot.n_prompt_tokens) {
|
if (slot.n_past == slot.n_prompt_tokens) {
|
||||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||||
|
|
||||||
GGML_ASSERT(batch.n_tokens > 0);
|
GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
|
||||||
|
|
||||||
common_sampler_reset(slot.smpl);
|
common_sampler_reset(slot.smpl);
|
||||||
|
|
||||||
@ -3116,27 +3112,27 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
llama_batch_set_logits_last(batch.get());
|
||||||
|
|
||||||
slot.n_decoded = 0;
|
slot.n_decoded = 0;
|
||||||
slot.i_batch = batch.n_tokens - 1;
|
slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;
|
||||||
|
|
||||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch.n_tokens >= n_batch) {
|
if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch.n_tokens == 0) {
|
if (llama_batch_get_n_tokens(batch.get()) == 0) {
|
||||||
SRV_WRN("%s", "no tokens to decode\n");
|
SRV_WRN("%s", "no tokens to decode\n");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
|
||||||
|
|
||||||
if (slot_batched) {
|
if (slot_batched) {
|
||||||
// make sure we're in the right embedding mode
|
// make sure we're in the right embedding mode
|
||||||
@ -3146,20 +3142,12 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// process the created batch of tokens
|
// process the created batch of tokens
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
|
||||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
|
||||||
|
|
||||||
llama_batch batch_view = {
|
llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
|
||||||
n_tokens,
|
|
||||||
batch.token + i,
|
|
||||||
nullptr,
|
|
||||||
batch.pos + i,
|
|
||||||
batch.n_seq_id + i,
|
|
||||||
batch.seq_id + i,
|
|
||||||
batch.logits + i,
|
|
||||||
};
|
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view.get());
|
||||||
metrics.on_decoded(slots);
|
metrics.on_decoded(slots);
|
||||||
|
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
@ -3294,16 +3282,16 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// construct the speculation batch
|
// construct the speculation batch
|
||||||
common_batch_clear(slot.batch_spec);
|
common_batch_clear(slot.batch_spec.get());
|
||||||
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
|
common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
|
||||||
|
|
||||||
for (size_t i = 0; i < draft.size(); ++i) {
|
for (size_t i = 0; i < draft.size(); ++i) {
|
||||||
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
|
||||||
|
|
||||||
llama_decode(ctx, slot.batch_spec);
|
llama_decode(ctx, slot.batch_spec.get());
|
||||||
|
|
||||||
// the accepted tokens from the speculation
|
// the accepted tokens from the speculation
|
||||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||||
|
@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter {
|
|||||||
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct llama_batch_deleter {
|
||||||
|
void operator()(llama_batch * batch) { llama_batch_free(batch); }
|
||||||
|
};
|
||||||
|
|
||||||
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
||||||
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
||||||
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
||||||
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
||||||
|
typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
|
||||||
|
@ -233,6 +233,14 @@ extern "C" {
|
|||||||
|
|
||||||
struct llama_batch;
|
struct llama_batch;
|
||||||
|
|
||||||
|
struct llama_batch_token_info {
|
||||||
|
llama_token token;
|
||||||
|
llama_pos pos;
|
||||||
|
int32_t n_seq_id;
|
||||||
|
llama_seq_id * seq_id;
|
||||||
|
int8_t logits;
|
||||||
|
};
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
enum llama_model_kv_override_type {
|
||||||
LLAMA_KV_OVERRIDE_TYPE_INT,
|
LLAMA_KV_OVERRIDE_TYPE_INT,
|
||||||
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
||||||
@ -837,34 +845,44 @@ extern "C" {
|
|||||||
int32_t pos0,
|
int32_t pos0,
|
||||||
int32_t seq_id);
|
int32_t seq_id);
|
||||||
|
|
||||||
|
// Get the number of tokens in the batch
|
||||||
|
LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch);
|
||||||
|
|
||||||
|
LLAMA_API struct llama_batch_token_info llama_batch_get_token_info(
|
||||||
|
struct llama_batch * batch,
|
||||||
|
int32_t i);
|
||||||
|
|
||||||
// Add text tokens to the batch
|
// Add text tokens to the batch
|
||||||
// First token in the list starts at position pos0
|
|
||||||
// Return values:
|
// Return values:
|
||||||
// 0 : success
|
// 0 : success
|
||||||
// -1 : not enough space in the batch
|
// -1 : not enough space in the batch
|
||||||
// -2 : embd is already set, cannot add text tokens
|
// -2 : embd is already set, cannot add text tokens
|
||||||
LLAMA_API int32_t llama_batch_add_text(
|
LLAMA_API int32_t llama_batch_add_text_token(
|
||||||
struct llama_batch * batch,
|
struct llama_batch * batch,
|
||||||
llama_token * tokens,
|
llama_token token,
|
||||||
size_t n_tokens,
|
llama_pos pos,
|
||||||
int32_t pos0,
|
const llama_seq_id * seq_ids,
|
||||||
int32_t seq_id);
|
size_t n_seq_ids,
|
||||||
|
float logits);
|
||||||
// Same as llama_batch_add_text, but accepts multiple sequences
|
|
||||||
LLAMA_API int32_t llama_batch_add_text(
|
|
||||||
struct llama_batch * batch,
|
|
||||||
llama_token * tokens,
|
|
||||||
size_t n_tokens,
|
|
||||||
int32_t pos0,
|
|
||||||
int32_t * seq_ids,
|
|
||||||
size_t n_seq_ids);
|
|
||||||
|
|
||||||
// Set logits for the token in the ith sequence
|
// Set logits for the token in the ith sequence
|
||||||
// If pos == -1, logits will be set for the all tokens
|
// If pos == -1, logits will be set for the all tokens
|
||||||
|
// Returns -1 if the token is not in the batch
|
||||||
LLAMA_API int32_t llama_batch_set_logits(
|
LLAMA_API int32_t llama_batch_set_logits(
|
||||||
struct llama_batch * batch,
|
struct llama_batch * batch,
|
||||||
int32_t pos,
|
llama_pos pos,
|
||||||
int32_t seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Set logits for the last added token
|
||||||
|
// Returns -1 if there is no tokens in the batch
|
||||||
|
LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch);
|
||||||
|
|
||||||
|
// Get a "view" from a number of tokens offset
|
||||||
|
// Return returned batch must be freed with llama_batch_free()
|
||||||
|
LLAMA_API struct llama_batch * llama_batch_get_view(
|
||||||
|
struct llama_batch * batch,
|
||||||
|
int32_t offset,
|
||||||
|
int32_t n_tokens);
|
||||||
|
|
||||||
// Remove everything from the batch
|
// Remove everything from the batch
|
||||||
LLAMA_API void llama_batch_clear(struct llama_batch * batch);
|
LLAMA_API void llama_batch_clear(struct llama_batch * batch);
|
||||||
@ -878,7 +896,7 @@ extern "C" {
|
|||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// < 0 - error. the KV cache state is restored to the state before this call
|
||||||
LLAMA_API int32_t llama_encode(
|
LLAMA_API int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch);
|
struct llama_batch * batch);
|
||||||
|
|
||||||
// Positive return values does not mean a fatal error, but rather a warning.
|
// Positive return values does not mean a fatal error, but rather a warning.
|
||||||
// 0 - success
|
// 0 - success
|
||||||
@ -886,7 +904,7 @@ extern "C" {
|
|||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// < 0 - error. the KV cache state is restored to the state before this call
|
||||||
LLAMA_API int32_t llama_decode(
|
LLAMA_API int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch);
|
struct llama_batch * batch);
|
||||||
|
|
||||||
// Set the number of threads used for decoding
|
// Set the number of threads used for decoding
|
||||||
// n_threads is the number of threads used for generation (single token)
|
// n_threads is the number of threads used for generation (single token)
|
||||||
|
@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one(
|
|||||||
int32_t n_tokens) {
|
int32_t n_tokens) {
|
||||||
return new llama_batch{
|
return new llama_batch{
|
||||||
/*n_tokens =*/ n_tokens,
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*max_tokens =*/ n_tokens,
|
||||||
|
/*is_view =*/ false,
|
||||||
/*tokens =*/ tokens,
|
/*tokens =*/ tokens,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
@ -326,6 +328,8 @@ struct llama_batch * llama_batch_get_one(
|
|||||||
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||||
llama_batch * batch = new llama_batch{
|
llama_batch * batch = new llama_batch{
|
||||||
/*n_tokens =*/ 0,
|
/*n_tokens =*/ 0,
|
||||||
|
/*max_tokens =*/ n_tokens_alloc,
|
||||||
|
/*is_view =*/ false,
|
||||||
/*tokens =*/ nullptr,
|
/*tokens =*/ nullptr,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd(
|
|||||||
int32_t seq_id) {
|
int32_t seq_id) {
|
||||||
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
|
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
|
||||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
||||||
for (int32_t i = 0; i < n_embd; i++) {
|
for (size_t i = 0; i < n_embd; i++) {
|
||||||
batch->pos [i] = pos0 + i;
|
batch->pos [i] = pos0 + i;
|
||||||
batch->n_seq_id[i] = 1;
|
batch->n_seq_id[i] = 1;
|
||||||
batch->seq_id [i][0] = seq_id;
|
batch->seq_id [i][0] = seq_id;
|
||||||
}
|
}
|
||||||
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_batch_add_text(
|
int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
|
||||||
|
return batch->n_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t llama_batch_add_text_token(
|
||||||
struct llama_batch * batch,
|
struct llama_batch * batch,
|
||||||
llama_token * tokens,
|
llama_token token,
|
||||||
size_t n_tokens,
|
llama_pos pos,
|
||||||
int32_t pos0,
|
const llama_seq_id * seq_ids,
|
||||||
int32_t * seq_ids,
|
size_t n_seq_ids,
|
||||||
size_t n_seq_ids) {
|
float logits) {
|
||||||
if (batch->n_tokens + n_tokens > batch->n_tokens) {
|
if (batch->n_tokens + 1 > batch->max_tokens) {
|
||||||
return -1;
|
return -1; // llama_batch size exceeded
|
||||||
}
|
}
|
||||||
if (batch->embd) {
|
if (batch->embd) {
|
||||||
return -2;
|
return -2; // embd is already set, cannot add text tokens
|
||||||
}
|
}
|
||||||
for (int32_t i = 0; i < n_tokens; i++) {
|
batch->token [batch->n_tokens] = token;
|
||||||
batch->token [batch->n_tokens + i] = tokens[i];
|
batch->pos [batch->n_tokens] = pos;
|
||||||
batch->pos [batch->n_tokens + i] = pos0 + i;
|
batch->n_seq_id[batch->n_tokens] = n_seq_ids;
|
||||||
batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
|
for (size_t j = 0; j < n_seq_ids; j++) {
|
||||||
for (int32_t j = 0; j < n_seq_ids; j++) {
|
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
|
||||||
batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
|
|
||||||
}
|
}
|
||||||
}
|
batch->logits [batch->n_tokens] = logits;
|
||||||
}
|
batch->n_tokens++;
|
||||||
|
return 0;
|
||||||
int32_t llama_batch_add_text(
|
|
||||||
struct llama_batch * batch,
|
|
||||||
llama_token * tokens,
|
|
||||||
size_t n_tokens,
|
|
||||||
int32_t pos0,
|
|
||||||
int32_t seq_id) {
|
|
||||||
std::array<int32_t, 1> seq_ids = { seq_id };
|
|
||||||
return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_batch_set_logits(
|
int32_t llama_batch_set_logits(
|
||||||
struct llama_batch * batch,
|
struct llama_batch * batch,
|
||||||
int32_t pos,
|
llama_pos pos,
|
||||||
int32_t seq_id) {
|
llama_seq_id seq_id) {
|
||||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||||
// find the token having seq_id
|
// find the token having seq_id
|
||||||
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
|
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
|
||||||
@ -415,18 +415,63 @@ int32_t llama_batch_set_logits(
|
|||||||
// found the sequence
|
// found the sequence
|
||||||
if (pos == -1 || pos == batch->pos[i]) {
|
if (pos == -1 || pos == batch->pos[i]) {
|
||||||
batch->logits[i] = true;
|
batch->logits[i] = true;
|
||||||
break;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return -1; // not found
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||||
|
if (batch->n_tokens == 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
batch->logits[batch->n_tokens - 1] = true;
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_batch_clear(struct llama_batch * batch) {
|
void llama_batch_clear(struct llama_batch * batch) {
|
||||||
batch->n_tokens = 0;
|
batch->n_tokens = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_batch * llama_batch_get_view(
|
||||||
|
struct llama_batch * batch,
|
||||||
|
int32_t offset,
|
||||||
|
int32_t n_tokens) {
|
||||||
|
if (batch->embd) {
|
||||||
|
return nullptr; // not yet supported
|
||||||
|
}
|
||||||
|
llama_batch * batch_view = new llama_batch{
|
||||||
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*max_tokens =*/ n_tokens,
|
||||||
|
/*is_view =*/ true,
|
||||||
|
/*tokens =*/ batch->token + offset,
|
||||||
|
/*embd =*/ nullptr,
|
||||||
|
/*pos =*/ batch->pos + offset,
|
||||||
|
/*n_seq_id =*/ batch->n_seq_id + offset,
|
||||||
|
/*seq_id =*/ batch->seq_id + offset,
|
||||||
|
/*logits =*/ batch->logits + offset,
|
||||||
|
};
|
||||||
|
return batch_view;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_batch_token_info llama_batch_get_token_info(
|
||||||
|
struct llama_batch * batch,
|
||||||
|
int32_t i) {
|
||||||
|
GGML_ASSERT(i >= 0 && i < batch->n_tokens);
|
||||||
|
return llama_batch_token_info{
|
||||||
|
/*token =*/ batch->token [i],
|
||||||
|
/*pos =*/ batch->pos [i],
|
||||||
|
/*n_seq_id =*/ batch->n_seq_id[i],
|
||||||
|
/*seq_id =*/ batch->seq_id [i],
|
||||||
|
/*logits =*/ batch->logits [i],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
void llama_batch_free(struct llama_batch * batch) {
|
void llama_batch_free(struct llama_batch * batch) {
|
||||||
|
// do not free the members if it's a view
|
||||||
|
if (!batch->is_view) {
|
||||||
if (batch->token) free(batch->token);
|
if (batch->token) free(batch->token);
|
||||||
if (batch->embd) free(batch->embd);
|
if (batch->embd) free(batch->embd);
|
||||||
if (batch->pos) free(batch->pos);
|
if (batch->pos) free(batch->pos);
|
||||||
@ -438,5 +483,6 @@ void llama_batch_free(struct llama_batch * batch) {
|
|||||||
free(batch->seq_id);
|
free(batch->seq_id);
|
||||||
}
|
}
|
||||||
if (batch->logits) free(batch->logits);
|
if (batch->logits) free(batch->logits);
|
||||||
|
}
|
||||||
delete batch;
|
delete batch;
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
//
|
//
|
||||||
struct llama_batch {
|
struct llama_batch {
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
int32_t max_tokens;
|
||||||
|
bool is_view;
|
||||||
|
|
||||||
llama_token * token;
|
llama_token * token;
|
||||||
float * embd;
|
float * embd;
|
||||||
|
@ -9978,8 +9978,8 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
|
|||||||
|
|
||||||
int32_t llama_encode(
|
int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch * batch) {
|
||||||
const int ret = llama_encode_impl(*ctx, batch);
|
const int ret = llama_encode_impl(*ctx, *batch);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
@ -9989,8 +9989,8 @@ int32_t llama_encode(
|
|||||||
|
|
||||||
int32_t llama_decode(
|
int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch * batch) {
|
||||||
const int ret = llama_decode_impl(*ctx, batch);
|
const int ret = llama_decode_impl(*ctx, *batch);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user