diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 40ff64838..9ae7e4dbb 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -81,6 +81,14 @@ int main(int argc, char ** argv) {
 
     params.embedding = true;
 
+    // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
+    // --parallel argument accordingly. for convenience, if not specified, we fall back to a unified KV cache
+    // in order to support any number of prompts
+    if (params.n_parallel == 1) {
+        LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
+        params.kv_unified = true;
+    }
+
     // utilize the full context
     if (params.n_batch < params.n_ctx) {
         LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index db79588f1..1065ec6bb 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -15,6 +15,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_parallel == 1) {
+        // the example uses 2 sequences, so when n_parallel == 1, we need to enable the unified KV cache
+        printf("%s: n_parallel == 1, enabling unified KV cache\n", __func__);
+        params.kv_unified = true;
+    }
+
     common_init();
 
     if (params.n_predict < 0) {
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index a546063c0..8698d89ac 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -59,7 +59,7 @@ bool llama_batch_allocr::init(
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
             if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                 return false;
             }
         }
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index fc1557a2d..9658abf96 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -185,7 +185,7 @@ llama_build_and_test(test-json-partial.cpp)
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)
 
-llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4)
+llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
diff --git a/tests/test-thread-safety.cpp b/tests/test-thread-safety.cpp
index d525b7430..853495b00 100644
--- a/tests/test-thread-safety.cpp
+++ b/tests/test-thread-safety.cpp
@@ -34,6 +34,9 @@ int main(int argc, char ** argv) {
 
     auto cparams = common_context_params_to_llama(params);
 
+    // each context has a single sequence
+    cparams.n_seq_max = 1;
+
     int dev_count = ggml_backend_dev_count();
     int gpu_dev_count = 0;
     for (int i = 0; i < dev_count; ++i) {
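
Note on the unified KV cache fallback in the two examples: the patch comment spells out the trade-off — when the number of sequences is known up front, sizing the cache via --parallel is more efficient, and the unified cache is the convenience fallback. A minimal sketch (not part of this patch) of the same decision expressed directly against llama_context_params; n_seq is a hypothetical placeholder for the number of sequences the caller expects:

    #include "llama.h"

    // sketch only: picks context params depending on whether the sequence
    // count is known in advance
    static llama_context_params make_cparams(int n_seq) {
        llama_context_params cparams = llama_context_default_params();

        if (n_seq == 1) {
            // sequence count not known / not specified: fall back to a unified
            // KV cache (one buffer shared across sequences), as the examples do
            cparams.kv_unified = true;
        } else {
            // sequence count known in advance: size the cache accordingly,
            // which is the more efficient option (what --parallel maps to)
            cparams.n_seq_max = n_seq;
        }

        return cparams;
    }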
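
The src/llama-batch.cpp hunk only corrects the log text: the guard rejects any seq id outside [0, n_seq_max), so the message now prints ">=" to match the comparison it actually reports. A hedged sketch of a batch that can trip the check, using the llama_batch_init and common_batch_add helpers (token and position values are placeholders):

    #include "common.h"
    #include "llama.h"

    static void build_batch() {
        llama_batch batch = llama_batch_init(/*n_tokens*/ 8, /*embd*/ 0, /*n_seq_max*/ 2);

        common_batch_add(batch, /*id*/ 0, /*pos*/ 0, /*seq_ids*/ { 0 }, /*logits*/ false); // valid for any context
        common_batch_add(batch, /*id*/ 0, /*pos*/ 1, /*seq_ids*/ { 1 }, /*logits*/ true);  // valid only if the
        // context was created with n_seq_max >= 2; with n_seq_max == 1, decoding this
        // batch now logs "invalid seq_id[1][0] = 1 >= 1" and fails

        llama_batch_free(batch);
    }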
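
In tests/test-thread-safety.cpp, pinning cparams.n_seq_max = 1 matches how the test uses its contexts: each thread drives one context with a single sequence (the added -t 2 in tests/CMakeLists.txt presumably just limits CPU oversubscription on the hosted runners). A rough sketch of that per-thread pattern, assuming one shared model; loading and the decode loop are elided:

    #include "llama.h"

    #include <thread>
    #include <vector>

    static void worker(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_seq_max = 1; // this context only ever decodes sequence 0

        llama_context * ctx = llama_init_from_model(model, cparams);
        // ... generate on ctx, always with seq id 0 ...
        llama_free(ctx);
    }

    static void run_threads(llama_model * model, int n_threads) {
        std::vector<std::thread> threads;
        for (int i = 0; i < n_threads; ++i) {
            threads.emplace_back(worker, model);
        }
        for (auto & t : threads) {
            t.join();
        }
    }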