server : avoid common_batch
ggml-ci
@@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-//
-// Batch utils
-//
-
-// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
-// this is meant to be temporary
-struct common_batch {
-    llama_batch_ext_ptr batch;
-    struct batch_token {
-        llama_token  token;
-        llama_seq_id seq_id; // only support single seq for now
-        bool         logits;
-    };
-    std::vector<batch_token> tokens;
-    int n_outputs = 0;
-    common_batch() = default;
-    common_batch(int32_t n_tokens, int32_t n_seq_max) {
-        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
-        tokens.reserve(n_tokens);
-    }
-    void clear() {
-        llama_batch_ext_clear(batch.get());
-        tokens.clear();
-    }
-    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
-        tokens.push_back({token, seq_ids[0], logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void set_logits_last() {
-        if (!tokens.empty()) {
-            llama_batch_ext_set_output_last(batch.get());
-            tokens.back().logits = true;
-        }
-    }
-    int32_t get_n_tokens() const {
-        return (int32_t)tokens.size();
-    }
-    llama_batch_ext * get() {
-        return batch.get();
-    }
-    common_batch get_view(int32_t offset, int32_t n_tokens) {
-        common_batch view;
-        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
-        view.tokens.reserve(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
-            view.tokens.push_back(tokens[offset + i]);
-            if (tokens[offset + i].logits) {
-                view.n_outputs++;
-            }
-        }
-        return view;
-    }
-};
-
 //
 // Token utils
 //
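The removed common_batch wrapper existed only to remember each token's seq_id and output flag; after this commit the server talks to llama_batch_ext directly through the llama_batch_ext_ptr smart pointer. A minimal sketch of that direct usage, not part of the commit, assuming the llama_batch_ext API and llama_batch_ext_ptr wrapper as they appear on this branch (the helper name eval_tokens and the header choice are illustrative):

    // sketch: feed one token sequence to the model with llama_batch_ext, no wrapper struct
    #include "llama-cpp.h" // assumed: declares llama_batch_ext_* and llama_batch_ext_ptr on this branch

    #include <vector>

    static int eval_tokens(llama_context * ctx, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
        // capacity: one slot per token, a single seq_id per token
        llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) tokens.size(), 1));

        for (size_t i = 0; i < tokens.size(); ++i) {
            // output = false: logits/embeddings are requested explicitly below
            llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, false);
        }

        // ask for output only at the last position
        llama_batch_ext_set_output_last(batch.get());

        // llama_batch_ext_clear(batch.get()) would reset the batch for reuse
        return llama_decode_ext(ctx, batch.get());
    }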
@@ -1224,7 +1224,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    common_batch batch_spec;
+    llama_batch_ext_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1248,7 +1248,7 @@ struct server_slot {
     int32_t n_past = 0;
     int32_t n_decoded = 0;
     int32_t n_remaining = -1;
-    int32_t i_batch = -1;
+    int32_t i_batch = -1; // TODO: remove and use only sequence-based sampling
     int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
 
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
@@ -1796,7 +1796,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    common_batch batch;
+    llama_batch_ext_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1922,7 +1922,7 @@ struct server_context {
            slot.n_predict = params_base.n_predict;
 
            if (model_dft) {
-               slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
+               slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
 
                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
@@ -1958,7 +1958,7 @@ struct server_context {
            const int32_t n_batch = llama_n_batch(ctx);
 
            // only a single seq_id per token is needed
-           batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
+           batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
        }
 
        metrics.init();
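A short sketch, not part of the commit, of the sizing rule used for the shared batch above: prompt chunks can fill up to n_batch tokens, generation adds at most one token per parallel slot, and every token carries a single seq_id (ctx and n_parallel stand in for the server's own state):

    #include "llama-cpp.h" // assumed header for llama_batch_ext_ptr on this branch

    #include <algorithm>

    static llama_batch_ext_ptr make_server_batch(llama_context * ctx, int32_t n_parallel) {
        const int32_t n_batch = llama_n_batch(ctx); // max tokens per llama_decode_ext call

        // room for either a full prompt chunk or one sampled token per parallel sequence
        return llama_batch_ext_ptr(llama_batch_ext_init(std::max(n_batch, n_parallel), 1));
    }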
@@ -2093,7 +2093,7 @@ struct server_context {
        }
 
        if (slot.ctx_dft) {
-           slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
+           slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
        }
 
        slot.state = SLOT_STATE_STARTED;
@@ -2401,7 +2401,7 @@ struct server_context {
        queue_results.send(std::move(res));
    }
 
-   void send_embedding(const server_slot & slot, common_batch & batch) {
+   void send_embedding(const server_slot & slot) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id    = slot.id_task;
        res->index = slot.index;
@@ -2410,34 +2410,40 @@ struct server_context {
 
        const int n_embd = llama_model_n_embd(model);
 
+       const llama_seq_id seq_id = slot.id;
+
        std::vector<float> embd_res(n_embd, 0.0f);
 
-       for (int i = 0; i < batch.get_n_tokens(); ++i) {
-           auto tok = batch.tokens[i];
-           if (!tok.logits || tok.seq_id != slot.id) {
-               continue;
-           }
-
-           const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
-           if (embd == NULL) {
-               embd = llama_get_embeddings_ith(ctx, i);
-           }
+       if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+           const float * embd = llama_get_embeddings_seq(ctx, seq_id);
 
            if (embd == NULL) {
-               SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
+               SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);
 
                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
-               continue;
            }
 
            // normalize only when there is pooling
-           // TODO: configurable
-           if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-               common_embd_normalize(embd, embd_res.data(), n_embd, 2);
-               res->embedding.push_back(embd_res);
-           } else {
-               res->embedding.push_back({ embd, embd + n_embd });
-           }
+           common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+           res->embedding.push_back(embd_res);
+       } else {
+           GGML_ABORT("embeddings without pooling is not supported yet");
+           //for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+           //    auto tok = batch.tokens[i];
+           //    if (!tok.logits || tok.seq_id != slot.id) {
+           //        continue;
+           //    }
+
+           //    const float * embd = llama_get_embeddings_ith(ctx, tok.seq_id);
+           //    if (embd == NULL) {
+           //        SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
+
+           //        res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
+           //        continue;
+           //    }
+
+           //    res->embedding.push_back({ embd, embd + n_embd });
+           //}
        }
 
        SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2445,30 +2451,20 @@ struct server_context {
        queue_results.send(std::move(res));
    }
 
-   void send_rerank(const server_slot & slot, common_batch & batch) {
+   void send_rerank(const server_slot & slot) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id    = slot.id_task;
        res->index = slot.index;
        res->n_tokens = slot.n_prompt_tokens;
 
-       for (int i = 0; i < batch.get_n_tokens(); ++i) {
-           auto tok = batch.tokens[i];
-           if (!tok.logits || tok.seq_id != slot.id) {
-               continue;
-           }
+       const llama_seq_id seq_id = slot.id;
 
-           const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
-           if (embd == NULL) {
-               embd = llama_get_embeddings_ith(ctx, i);
-           }
-
-           if (embd == NULL) {
-               SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
-
-               res->score = -1e6;
-               continue;
-           }
+       const float * embd = llama_get_embeddings_seq(ctx, seq_id);
+       if (embd == NULL) {
+           SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);
+
+           res->score = -1e6;
+       } else {
            res->score = embd[0];
        }
 
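With the batch argument gone, both send_embedding() and send_rerank() read the pooled result for the slot's single sequence via llama_get_embeddings_seq(). A sketch of that retrieval path, not part of the commit; common_embd_normalize() and llama_model_n_embd() are existing llama.cpp helpers, the function name is illustrative:

    #include "common.h" // assumed: provides common_embd_normalize()

    #include <vector>

    // read back a pooled embedding for one sequence after llama_decode_ext()
    static bool get_pooled_embd(llama_context * ctx, const llama_model * model, llama_seq_id seq_id, std::vector<float> & out) {
        const int n_embd = llama_model_n_embd(model);

        const float * embd = llama_get_embeddings_seq(ctx, seq_id); // one pooled vector per sequence
        if (embd == NULL) {
            return false; // no pooled output, e.g. pooling type 'none'
        }

        out.resize(n_embd);
        common_embd_normalize(embd, out.data(), n_embd, 2); // 2 = L2 norm, as in send_embedding() above

        // for reranking, embd[0] alone is used as the relevance score
        return true;
    }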
@@ -2854,7 +2850,7 @@ struct server_context {
        }
 
        // start populating the batch for this iteration
-       batch.clear();
+       llama_batch_ext_clear(batch.get());
 
        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
@@ -2876,9 +2872,9 @@
                continue;
            }
 
-           slot.i_batch = batch.get_n_tokens();
+           slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
-           batch.add_text(slot.sampled, slot.n_past, slot.id, true);
+           llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, &slot.id, 1, true);
 
            slot.n_past += 1;
 
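The i_batch bookkeeping above records the row at which the slot's sampled token is added, so its logits can be located after the decode. A sketch of the pattern, not part of the commit; llama_get_logits_ith() is the long-standing llama.h accessor and is assumed to keep working with llama_batch_ext on this branch, and the helper name is made up:

    // add one token with output enabled, decode, then read its logits back by row index
    static const float * decode_one(llama_context * ctx, llama_batch_ext * batch,
                                    llama_token token, llama_pos pos, llama_seq_id seq_id) {
        const int32_t i_batch = llama_batch_ext_get_n_tokens(batch); // row this token will occupy

        llama_batch_ext_add_text(batch, token, pos, &seq_id, 1, true);

        if (llama_decode_ext(ctx, batch) != 0) {
            return NULL;
        }

        return llama_get_logits_ith(ctx, i_batch); // assumption: unchanged accessor from llama.h
    }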
@@ -2895,7 +2891,7 @@
        int32_t n_ubatch = llama_n_ubatch(ctx);
 
        // next, batch any pending prompts without exceeding n_batch
-       if (params_base.cont_batching || batch.get_n_tokens() == 0) {
+       if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3061,7 +3057,7 @@
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                       if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
+                       if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3081,11 +3077,12 @@
                    slot.cache_tokens.resize(slot.n_past);
 
                    // add prompt tokens for processing in the current batch
-                   while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
+                   while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                       batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
+                       //batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
+                       llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, &slot.id, 1, need_embd);
 
                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3095,13 +3092,14 @@
                        slot.n_past++;
                    }
 
-                   SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                   SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
+                           slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                       GGML_ASSERT(batch.get_n_tokens() > 0);
+                       GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
 
                        common_sampler_reset(slot.smpl);
 
@@ -3111,27 +3109,28 @@
                        }
 
                        // extract the logits only for the last token
-                       batch.set_logits_last();
+                       //batch.set_logits_last();
+                       llama_batch_ext_set_output_last(batch.get());
 
                        slot.n_decoded = 0;
-                       slot.i_batch   = batch.get_n_tokens() - 1;
+                       slot.i_batch   = llama_batch_ext_get_n_tokens(batch.get()) - 1;
 
-                       SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
+                       SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
                    }
                }
 
-               if (batch.get_n_tokens() >= n_batch) {
+               if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
                    break;
                }
            }
        }
 
-       if (batch.get_n_tokens() == 0) {
+       if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-       SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
+       SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3141,10 +3140,10 @@
        }
 
        // process the created batch of tokens
-       for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
-           const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
+       for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
+           const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
 
-           common_batch batch_view = batch.get_view(i, n_tokens);
+           llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
 
            const int ret = llama_decode_ext(ctx, batch_view.get());
            metrics.on_decoded(slots);
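When the accumulated batch holds more than n_batch tokens, it is decoded in chunks: each chunk is a view into the same batch, wrapped in llama_batch_ext_ptr exactly as in the loop above. A compact sketch, not part of the commit, with an illustrative function name:

    #include "llama-cpp.h" // assumed header for llama_batch_ext_ptr on this branch

    #include <algorithm>

    static int decode_in_chunks(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch); i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch) - i);

            // view over rows [i, i + n_tokens), wrapped for automatic cleanup as in the server code
            llama_batch_ext_ptr view(llama_batch_ext_get_view(batch, i, n_tokens));

            const int ret = llama_decode_ext(ctx, view.get());
            if (ret != 0) {
                return ret; // the caller decides how to recover (the server itself retries with a smaller n_batch)
            }
        }

        return 0;
    }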
@@ -3177,14 +3176,14 @@
                if (slot.state == SLOT_STATE_DONE_PROMPT) {
                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                        // prompt evaluated for embedding
-                       send_embedding(slot, batch_view);
+                       send_embedding(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }
 
                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
-                       send_rerank(slot, batch_view);
+                       send_rerank(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
@@ -3281,14 +3280,17 @@
                }
 
                // construct the speculation batch
-               slot.batch_spec.clear();
-               slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
+               //slot.batch_spec.clear();
+               //slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
+               llama_batch_ext_clear(slot.batch_spec.get());
+               llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true);
 
                for (size_t i = 0; i < draft.size(); ++i) {
-                   slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
+                   //slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
+                   llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true);
                }
 
-               SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
+               SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
 
                llama_decode_ext(ctx, slot.batch_spec.get());
 
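The speculative batch above is rebuilt on every iteration: cleared, seeded with the token the target model just sampled, then filled with the draft tokens, all with output enabled so every drafted position can be verified. A sketch of the same construction, not part of the commit, with an illustrative helper name:

    #include "llama.h" // assumed: declares the llama_batch_ext API on this branch

    #include <vector>

    static void decode_draft(llama_context * ctx, llama_batch_ext * batch_spec,
                             llama_token id_sampled, const std::vector<llama_token> & draft,
                             llama_pos n_past, llama_seq_id seq_id) {
        llama_batch_ext_clear(batch_spec);

        // the freshly sampled target token goes first, followed by the drafted tokens;
        // output = true everywhere so each position's logits are available for verification
        llama_batch_ext_add_text(batch_spec, id_sampled, n_past, &seq_id, 1, true);
        for (size_t i = 0; i < draft.size(); ++i) {
            llama_batch_ext_add_text(batch_spec, draft[i], n_past + 1 + (llama_pos) i, &seq_id, 1, true);
        }

        llama_decode_ext(ctx, batch_spec);
    }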
@@ -4147,6 +4149,11 @@ int main(int argc, char ** argv) {
            return;
        }
 
+       if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+           res_error(res, format_error_response("Pooling type 'none' is not yet supported. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
@@ -4241,6 +4248,11 @@ int main(int argc, char ** argv) {
            return;
        }
 
+       if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+           res_error(res, format_error_response("Pooling type 'none' cannot be used with reranking. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
        const json body = json::parse(req.body);
 
        // TODO: implement
@@ -88,13 +88,19 @@ def test_embedding_pooling_none():
     res = server.make_request("POST", "/embeddings", data={
         "input": "hello hello hello",
     })
-    assert res.status_code == 200
-    assert 'embedding' in res.body[0]
-    assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
 
-    # make sure embedding vector is not normalized
-    for x in res.body[0]['embedding']:
-        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
+    # /embeddings does not support pooling type 'none'
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    # TODO: re-enable when we figure out how to support pooling type 'none'
+    #assert res.status_code == 200
+    #assert 'embedding' in res.body[0]
+    #assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
+
+    ## make sure embedding vector is not normalized
+    #for x in res.body[0]['embedding']:
+    #    assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
 
 
 def test_embedding_pooling_none_oai():