diff --git a/common/common.cpp b/common/common.cpp index 4cc40ed8b..218f1e1dc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -934,7 +934,7 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } - if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; } @@ -1041,7 +1041,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } - llama_kv_self_clear(lctx); + llama_memory_clear(llama_get_memory(lctx), true); llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); diff --git a/common/speculative.cpp b/common/speculative.cpp index ccad70fa9..843bd1ddb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft( auto & smpl = spec->smpl; auto & prompt = spec->prompt; + auto * mem = llama_get_memory(ctx); + int reuse_i = 0; int reuse_n = 0; @@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft( result.reserve(params.n_draft); if (reuse_n == 0) { - llama_kv_self_clear(ctx); + llama_memory_clear(mem, false); prompt.clear(); } else { @@ -192,14 +194,14 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - llama_kv_self_seq_rm (ctx, 0, 0, reuse_i); - llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + llama_memory_seq_rm (mem, 0, 0, reuse_i); + llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i); } if (reuse_n < (int) prompt.size()) { - llama_kv_self_seq_rm (ctx, 0, reuse_n, -1); + llama_memory_seq_rm (mem, 0, reuse_n, -1); prompt.erase(prompt.begin() + reuse_n, prompt.end()); } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 514989e34..fd90bbec5 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens) } if n_parallel > 1 { diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 8bef7f8f6..681929d27 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 539bc4d60..041da61c7 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -45,7 +45,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); @@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_token eos_token = llama_vocab_eos(vocab); - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 9654cd53c..711ddc5d1 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_self_clear(reinterpret_cast(context)); + llama_memory_clear(llama_get_memory(reinterpret_cast(context)), true); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index f6e31abc9..dc2bafc88 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -210,7 +210,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000; @@ -223,7 +223,7 @@ actor LlamaContext { // bench text generation - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000; @@ -242,7 +242,7 @@ actor LlamaContext { let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000; - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -292,7 +292,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), true) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 5f8620973..1e26d8221 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -60,6 +60,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); // Tokenize the prompt @@ -94,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_self_seq_cp(ctx, 0, s, -1, -1); + llama_memory_seq_cp(mem, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -427,17 +429,17 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_self_seq_rm(ctx, -1, n_past, -1); + llama_memory_seq_rm(mem, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_self_seq_keep(ctx, seq_id_best); - llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1); + llama_memory_seq_keep(mem, seq_id_best); + llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1); + llama_memory_seq_rm (mem, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_self_seq_cp(ctx, 0, s, -1, -1); + llama_memory_seq_cp(mem, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 2ee502939..2bfa26b55 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -181,7 +181,7 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_self_seq_rm(ctx, 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1); common_batch_clear(batch_tgt); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index cd85bea9a..d53e089a4 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -194,6 +194,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); // load the prompts from an external file if there are any @@ -259,7 +261,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_self_seq_cp(ctx, 0, i, -1, -1); + llama_memory_seq_cp(mem, 0, i, -1, -1); } LOG_INF("\n"); @@ -286,9 +288,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_self_seq_rm(ctx, i, -1, -1); + llama_memory_seq_rm(mem, i, -1, -1); // but keep the system prompt - llama_kv_self_seq_cp(ctx, 0, i, -1, -1); + llama_memory_seq_cp(mem, 0, i, -1, -1); } LOG_INF("%s: clearing the KV cache\n", __func__); @@ -447,8 +449,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_memory_seq_rm(mem, client.id + 1, -1, -1); + llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 5ac881b45..8a4faa383 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,6 +126,8 @@ int main(int argc, char ** argv) { int n_past = 0; + auto * mem = llama_get_memory(ctx); + // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { if (i > 0 && n_grp > 1) { @@ -133,10 +135,10 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd); + llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; } common_batch_clear(batch); @@ -166,10 +168,10 @@ int main(int argc, char ** argv) { LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; common_batch_clear(batch); @@ -195,10 +197,10 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; } } diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 754da1411..042e12c2b 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), false); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 760ebbbf0..db79588f1 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -196,7 +196,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_self_clear(ctx3); + llama_memory_clear(llama_get_memory(ctx3), true); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 6608d4bea..2aee0a919 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -98,7 +98,7 @@ int main(int argc, char ** argv) { auto generate = [&](const std::string & prompt) { std::string response; - const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0; + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0; // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); - int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0); + int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 0783ed4a4..99196c9d0 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -217,7 +217,7 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 561c30883..0adffdb00 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -142,6 +142,8 @@ int main(int argc, char ** argv) { } } + auto * mem_tgt = llama_get_memory(ctx_tgt); + auto * mem_dft = llama_get_memory(ctx_dft); // Tokenize the prompt std::vector inp; @@ -420,14 +422,14 @@ int main(int argc, char ** argv) { { LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_self_seq_keep(ctx_dft, s_keep); - llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_self_seq_keep(ctx_dft, 0); + llama_memory_seq_keep(mem_dft, s_keep); + llama_memory_seq_cp (mem_dft, s_keep, 0, -1, -1); + llama_memory_seq_keep(mem_dft, 0); - llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_self_seq_keep(ctx_tgt, s_keep); - llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_self_seq_keep(ctx_tgt, 0); + llama_memory_seq_rm (mem_tgt, s_keep, n_past_tgt, -1); + llama_memory_seq_keep(mem_tgt, s_keep); + llama_memory_seq_cp (mem_tgt, s_keep, 0, -1, -1); + llama_memory_seq_keep(mem_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -444,7 +446,7 @@ int main(int argc, char ** argv) { common_batch_clear(batch_dft); common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); + llama_memory_seq_rm(mem_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -503,8 +505,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_memory_seq_rm(mem_dft, n_seq_cur, -1, -1); + llama_memory_seq_cp(mem_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -585,9 +587,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_self_seq_keep(ctx_tgt, 0); + llama_memory_seq_keep(mem_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_memory_seq_cp(mem_tgt, 0, s, -1, -1); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/include/llama.h b/include/llama.h index aa5330e2a..015a57898 100644 --- a/include/llama.h +++ b/include/llama.h @@ -625,7 +625,10 @@ extern "C" { // // Clear the memory contents - LLAMA_API void llama_memory_clear(llama_memory_t mem); + // If data == true, the data buffers will also be cleared together with the metadata + LLAMA_API void llama_memory_clear( + llama_memory_t mem, + bool data); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails @@ -705,74 +708,82 @@ extern "C" { "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx); + DEPRECATED(LLAMA_API void llama_kv_self_clear( + struct llama_context * ctx), + "Use llama_memory_clear() instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_self_seq_rm( + DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_rm() instead"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_cp( + DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_cp() instead"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_self_seq_keep( + DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_keep() instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_add( + DEPRECATED(LLAMA_API void llama_kv_self_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - llama_pos delta); + llama_pos delta), + "Use llama_memory_seq_add() instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_div( + DEPRECATED(void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - int d); + int d), + "Use llama_memory_seq_div() instead"); // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_min( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_min() instead"); // Returns the largest position present in the KV cache for the specified sequence // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_max( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_max() instead"); // Defragment the KV cache // This will be applied: @@ -781,7 +792,8 @@ extern "C" { "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting - LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); + DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), + "use llama_memory_can_shift() instead"); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ea1910685..b130b484b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -422,6 +422,7 @@ llama_memory_t llama_context::get_memory() const { return memory.get(); } +// deprecated void llama_context::kv_self_defrag_sched() { if (!memory) { return; @@ -430,6 +431,7 @@ void llama_context::kv_self_defrag_sched() { memory_force_optimize = true; } +// deprecated bool llama_context::kv_self_update(bool optimize) { if (!memory) { return false; @@ -2053,7 +2055,7 @@ void llama_context::opt_epoch_iter( const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); - memory->clear(); + memory->clear(true); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -2426,8 +2428,12 @@ llama_memory_t llama_get_memory(const struct llama_context * ctx) { return ctx->get_memory(); } -void llama_memory_clear(llama_memory_t mem) { - mem->clear(); +void llama_memory_clear(llama_memory_t mem, bool data) { + if (!mem) { + return; + } + + mem->clear(data); } bool llama_memory_seq_rm( @@ -2435,6 +2441,10 @@ bool llama_memory_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (!mem) { + return true; + } + return mem->seq_rm(seq_id, p0, p1); } @@ -2444,12 +2454,20 @@ void llama_memory_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (!mem) { + return; + } + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); } void llama_memory_seq_keep( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return; + } + mem->seq_keep(seq_id); } @@ -2459,6 +2477,10 @@ void llama_memory_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { + if (!mem) { + return; + } + mem->seq_add(seq_id, p0, p1, delta); } @@ -2468,22 +2490,38 @@ void llama_memory_seq_div( llama_pos p0, llama_pos p1, int d) { + if (!mem) { + return; + } + mem->seq_div(seq_id, p0, p1, d); } llama_pos llama_memory_seq_pos_min( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return -1; + } + return mem->seq_pos_min(seq_id); } llama_pos llama_memory_seq_pos_max( llama_memory_t mem, llama_seq_id seq_id) { + if (!mem) { + return -1; + } + return mem->seq_pos_max(seq_id); } bool llama_memory_can_shift(llama_memory_t mem) { + if (!mem) { + return false; + } + return mem->get_can_shift(); } @@ -2534,15 +2572,17 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) { return res; } +// deprecated void llama_kv_self_clear(llama_context * ctx) { auto * kv = llama_get_memory(ctx); if (!kv) { return; } - llama_memory_clear(kv); + llama_memory_clear(kv, true); } +// deprecated bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id seq_id, @@ -2556,6 +2596,7 @@ bool llama_kv_self_seq_rm( return llama_memory_seq_rm(kv, seq_id, p0, p1); } +// deprecated void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, @@ -2570,6 +2611,7 @@ void llama_kv_self_seq_cp( llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); } +// deprecated void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2579,6 +2621,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { llama_memory_seq_keep(kv, seq_id); } +// deprecated void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, @@ -2593,6 +2636,7 @@ void llama_kv_self_seq_add( llama_memory_seq_add(kv, seq_id, p0, p1, delta); } +// deprecated void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, @@ -2607,6 +2651,7 @@ void llama_kv_self_seq_div( llama_memory_seq_div(kv, seq_id, p0, p1, d); } +// deprecated llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2616,6 +2661,7 @@ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { return llama_memory_seq_pos_min(kv, seq_id); } +// deprecated llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { auto * kv = llama_get_memory(ctx); if (!kv) { @@ -2631,6 +2677,7 @@ void llama_kv_self_defrag(llama_context * ctx) { ctx->kv_self_defrag_sched(); } +// deprecated bool llama_kv_self_can_shift(const llama_context * ctx) { auto * kv = llama_get_memory(ctx); if (!kv) { diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 77bd57065..f5c6dcd66 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -117,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( } } -void llama_kv_cache_recurrent::clear() { +void llama_kv_cache_recurrent::clear(bool data) { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); cells[i].src = -1; cells[i].tail = -1; } + head = 0; used = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -723,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -880,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { kv_cell & cell = cells[i]; diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index cb813dfe8..d1da12256 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -39,7 +39,7 @@ public: llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 3aa606c84..28d182654 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( hparams.n_swa, hparams.swa_type); } -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); +void llama_kv_cache_unified_iswa::clear(bool data) { + kv_base->clear(data); + kv_swa ->clear(data); } bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 3fabcd6b8..3dbf33ed7 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -43,7 +43,7 @@ public: bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 5354f808c..3a40463fd 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -129,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified( } } -void llama_kv_cache_unified::clear() { +void llama_kv_cache_unified::clear(bool data) { cells.reset(); head = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -1319,7 +1321,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -1500,7 +1502,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index d01a9abd7..49f410ef6 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -68,7 +68,7 @@ public: bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; diff --git a/src/llama-memory.h b/src/llama-memory.h index 5993b59be..991aae781 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -90,7 +90,8 @@ struct llama_memory_i { // ops // - virtual void clear() = 0; + // if data == true, the data buffers will also be cleared together with the metadata + virtual void clear(bool data) = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 119df471b..a0a2e5ac5 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -57,6 +57,8 @@ int main(int argc, char ** argv) { return 1; } + auto * mem = llama_get_memory(ctx); + const int32_t n_kv_max = llama_n_ctx(ctx); llama_batch batch = llama_batch_init(n_kv_max, 0, 1); @@ -132,7 +134,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_self_clear(ctx); + llama_memory_clear(mem, false); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_ERR("%s: llama_decode() failed\n", __func__); @@ -141,7 +143,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_self_seq_cp(ctx, 0, i, -1, -1); + llama_memory_seq_cp(mem, 0, i, -1, -1); } } diff --git a/tools/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp index 2a9071550..d2d97e05c 100644 --- a/tools/cvector-generator/cvector-generator.cpp +++ b/tools/cvector-generator/cvector-generator.cpp @@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { } static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 81d0404d6..daad44e59 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -498,7 +498,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); llama_batch batch = llama_batch_init(n_batch, 0, 1); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 803630d26..e59d61f19 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -1900,7 +1900,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), false); // cool off before the test if (params.delay) { @@ -1948,7 +1948,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), false); if (t.n_depth > 0) { if (params.progress) { diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 1bd2be2d9..19b247b0d 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -147,6 +147,8 @@ int main(int argc, char ** argv) { return 1; } + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); auto chat_templates = common_chat_templates_init(model, params.chat_template); @@ -351,7 +353,7 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1); + llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1); } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", @@ -599,8 +601,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard); + llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -623,9 +625,9 @@ int main(int argc, char ** argv) { LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_memory_seq_add(mem, 0, ga_i, n_past, ib*bd); + llama_memory_seq_div(mem, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_memory_seq_add(mem, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 40deab5ab..599e682e0 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -342,7 +342,7 @@ int main(int argc, char ** argv) { } if (line == "/clear") { ctx.n_past = 0; - llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS + llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS LOG("Chat history cleared\n\n"); continue; } diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index b5cdf5beb..189dcb3d7 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -361,7 +361,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); llama_batch batch = llama_batch_init(n_batch, 0, 1); @@ -547,7 +547,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -924,7 +924,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { return; } - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1217,7 +1217,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) return; } - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1592,7 +1592,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par return; } - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1782,7 +1782,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } // clear the KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); llama_batch batch = llama_batch_init(n_batch, 0, 1); diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 4aef93863..c65afd61e 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -939,7 +939,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama // Function to tokenize the prompt static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, std::vector & prompt_tokens, const LlamaData & llama_data) { - const bool is_first = llama_kv_self_seq_pos_max(llama_data.context.get(), 0) == 0; + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == 0; const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); prompt_tokens.resize(n_prompt_tokens); @@ -955,7 +955,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt // Check if we have enough space in the context to evaluate this batch static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { const int n_ctx = llama_n_ctx(ctx.get()); - const int n_ctx_used = llama_kv_self_seq_pos_max(ctx.get(), 0); + const int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx.get()), 0); if (n_ctx_used + batch.n_tokens > n_ctx) { printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 9038df4c3..2e78dcd7b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2006,7 +2006,7 @@ struct server_context { } } - if (!llama_kv_self_can_shift(ctx)) { + if (!llama_memory_can_shift(llama_get_memory(ctx))) { if (params_base.ctx_shift) { params_base.ctx_shift = false; SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); @@ -2224,7 +2224,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); clean_kv_cache = false; } @@ -2910,7 +2910,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_self_seq_rm(ctx, slot->id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); slot->cache_tokens.clear(); auto res = std::make_unique(); @@ -2985,8 +2985,8 @@ struct server_context { SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); // add generated tokens to cache { @@ -3189,8 +3189,8 @@ struct server_context { const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c); - llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); @@ -3212,7 +3212,7 @@ struct server_context { } if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { - const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); @@ -3247,9 +3247,9 @@ struct server_context { } // keep only the common part - if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) { + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_self_seq_rm(ctx, slot.id, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left slot.n_past = 0; @@ -3589,7 +3589,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); - llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result;