Merge branch 'master' into xsn/private_batch_api

This commit is contained in:
Xuan Son Nguyen
2025-03-13 15:55:18 +01:00
173 changed files with 26425 additions and 16117 deletions

View File

@ -131,9 +131,9 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
auto grammar_triggers = json::array();
for (const auto & trigger : sampling.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json<json>());
}
return json {
@ -170,8 +170,8 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"samplers", samplers},
@ -356,24 +356,6 @@ struct server_task {
}
{
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
common_grammar_trigger trigger;
trigger.word = t.at("word");
trigger.at_start = t.at("at_start");
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
@ -383,12 +365,39 @@ struct server_task {
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) {
auto ct = common_grammar_trigger::from_json(t);
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto & word = ct.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
trigger.value = word;
trigger.token = token;
params.sampling.grammar_triggers.push_back(std::move(trigger));
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(ct);
}
}
}
if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) {
throw std::runtime_error("Error: no triggers set for lazy grammar!");
}
}
@ -742,7 +751,10 @@ struct server_task_result_cmpl_final : server_task_result {
{"name", tc.name},
{"arguments", tc.arguments},
}},
{"id", tc.id},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
@ -1304,7 +1316,7 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}
bool can_batch_with(server_slot & other_slot) {
bool can_batch_with(server_slot & other_slot) const {
return is_non_causal() == other_slot.is_non_causal()
&& are_lora_equal(lora, other_slot.lora);
}
@ -1888,6 +1900,7 @@ struct server_context {
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
@ -2023,6 +2036,18 @@ struct server_context {
return ret;
}
bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & token : tokens) {
if (token < 0 || token >= n_vocab) {
return false;
}
}
return true;
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot.reset();
slot.id_task = task.id;
@ -2037,11 +2062,16 @@ struct server_context {
slot.lora = task.params.lora;
}
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
if (!can_detokenize) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict);
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
}
@ -2077,7 +2107,7 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_kv_self_clear(ctx);
clean_kv_cache = false;
}
@ -2142,14 +2172,6 @@ struct server_context {
}
if (slot.has_new_line) {
// if we have already seen a new line, we stop after a certain time limit
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
@ -2188,6 +2210,14 @@ struct server_context {
// check if there is a new line in the generated text
if (result.text_to_send.find('\n') != std::string::npos) {
slot.has_new_line = true;
// if we have seen a new line, we stop after a certain time limit, but only upon another new line
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
}
}
// if context shift is disabled, we stop when it reaches the context limit
@ -2621,8 +2651,8 @@ struct server_context {
res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res->t_start = metrics.t_start;
res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
res->t_prompt_processing_total = metrics.t_prompt_processing_total;
@ -2738,7 +2768,7 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
auto res = std::make_unique<server_task_result_slot_erase>();
@ -2806,8 +2836,8 @@ struct server_context {
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2998,8 +3028,8 @@ struct server_context {
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@ -3037,9 +3067,9 @@ struct server_context {
}
// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
@ -3271,7 +3301,7 @@ struct server_context {
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;