server : fix cache_tokens bug with no cache_prompt (#13533)
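In short, reading from the hunks below (the commit message itself does not spell this out): slot.cache_tokens used to be updated only when cache_prompt was enabled, yet the slot logic still trimmed and consulted it, so a request with cache_prompt disabled could leave the token mirror out of sync with the KV cache. The change keeps cache_tokens in sync with the KV cache unconditionally, fully clears both when the prompt is not cached, renames server_tokens::resize() to keep_first(), and adds a regression test.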
@@ -2951,7 +2951,8 @@ struct server_context {
                 llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
                 llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
-                if (slot.params.cache_prompt) {
+                // add generated tokens to cache
+                {
                     llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
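For orientation, here is a minimal standalone sketch of the compaction the loop above performs on the mirrored token list after a context shift; it is illustrative only (discard_middle is a made-up name, not part of the server code):

#include <cassert>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Drop n_discard tokens that follow the first n_keep tokens, keeping the
// vector aligned with the KV cache after llama_kv_self_seq_rm/seq_add.
static void discard_middle(std::vector<llama_token> & tokens, size_t n_keep, size_t n_discard) {
    assert(n_keep + n_discard <= tokens.size());
    // shift the tail down over the discarded window ...
    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    // ... and trim the now-duplicated tail
    tokens.resize(tokens.size() - n_discard);
}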
@@ -2996,10 +2997,7 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-
-            if (slot.params.cache_prompt) {
-                slot.cache_tokens.push_back(slot.sampled);
-            }
+            slot.cache_tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                 slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
 
                         SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                     }
+                } else {
+                    // if we don't cache the prompt, we have to remove the entire KV cache
+                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                    slot.n_past = 0;
+                    slot.cache_tokens.clear();
                 }
             }
 
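The else branch above is the heart of the fix. Below is a minimal sketch of the invariant it restores, with illustrative names rather than the real server types: every KV-cache mutation for a slot must be mirrored in cache_tokens even when prompt caching is off, otherwise a later trim such as keep_first() operates on stale tokens left over from a previous request.

#include <cstdint>
#include <vector>

// Illustrative per-slot state; the real slot also tracks sampling, timings, etc.
struct slot_state {
    std::vector<int32_t> cache_tokens; // mirror of the tokens held in the KV cache
    int32_t              n_past = 0;   // number of KV cells in use for this slot

    // What the new else-branch amounts to when the prompt is not cached:
    // the sequence is wiped from the KV cache (llama_kv_self_seq_rm in the
    // server) and the mirror is reset so it cannot go stale.
    void reset_cache() {
        n_past = 0;
        cache_tokens.clear();
    }
};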
@@ -3204,7 +3207,7 @@ struct server_context {
                 SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
                 // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
+                slot.cache_tokens.keep_first(slot.n_past);
 
                 // check if we should process the image
                 if (slot.n_past < slot.n_prompt_tokens
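For context, keep_first(slot.n_past) trims the mirror down to the prefix the cached tokens share with the new prompt; the "non-common part" is then decoded again. Here is a simplified, text-only sketch of that prefix computation (an assumption on my part for illustration; the real logic lives in the server and also handles multimodal chunks):

#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Length of the longest common prefix between the cached tokens and the
// incoming prompt; everything past it is removed from the KV cache and
// from the token mirror before being re-decoded.
static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                const std::vector<llama_token> & prompt) {
    size_t n = 0;
    while (n < cached.size() && n < prompt.size() && cached[n] == prompt[n]) {
        n++;
    }
    return n;
}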
@@ -3221,7 +3224,8 @@ struct server_context {
                         continue;
                     }
 
-                    if (slot.params.cache_prompt) {
+                    // add the image chunk to cache
+                    {
                         const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                         slot.cache_tokens.push_back(chunk.get()); // copy
                     }
@@ -3242,9 +3246,7 @@ struct server_context {
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                    if (slot.params.cache_prompt) {
-                        slot.cache_tokens.push_back(cur_tok);
-                    }
+                    slot.cache_tokens.push_back(cur_tok);
 
                     slot.n_prompt_tokens_processed++;
                     slot.n_past++;
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]
 
 
+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
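A note on the new test, as a plausible reading rather than something stated in the diff: it sends a longer prompt with cache_prompt disabled, the configuration that previously left cache_tokens stale, and only asserts that such a request now completes with HTTP 200.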
@@ -1153,7 +1153,7 @@ public:
         tokens.clear();
     }
 
-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image
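Finally, a minimal text-only sketch of what keep_first() amounts to (illustrative; as the hunk above shows, the real server_tokens::keep_first() also refuses to cut through the middle of an image chunk when multimodal data is present):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Keep only the first n tokens of the mirror. Unlike a plain resize(), the
// new name makes explicit that it can only shrink, never grow.
static void keep_first(std::vector<llama_token> & tokens, size_t n) {
    assert(n <= tokens.size());
    tokens.resize(n);
}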