diff --git a/common/arg.cpp b/common/arg.cpp
index 40af7e574..ec50bda57 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--attn-streams", "-as"},
+        string_format("use multiple streams when computing the attention (default: %s)\n"
+                      "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.attn_streams ? "true" : "false"),
+        [](common_params & params) {
+            params.attn_streams = true;
+        }
+    ).set_env("LLAMA_ARG_ATTN_STREAMS"));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
diff --git a/common/common.cpp b/common/common.cpp
index e4e71ad13..7dcc277fd 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1157,6 +1157,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.no_perf      = params.no_perf;
     cparams.op_offload   = !params.no_op_offload;
     cparams.swa_full     = params.swa_full;
+    cparams.attn_streams = params.attn_streams;

     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
diff --git a/common/common.h b/common/common.h
index 8922090e7..44a14c1a4 100644
--- a/common/common.h
+++ b/common/common.h
@@ -330,6 +330,7 @@ struct common_params {
     bool no_perf          = false; // disable performance metrics
     bool ctx_shift        = true;  // context shift on inifinite text generation
     bool swa_full         = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+    bool attn_streams     = false; // multi-stream attention and KV cache buffers

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap         = true;  // use mmap for faster loads
diff --git a/include/llama.h b/include/llama.h
index 3eda9bc68..c16b72a6f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -374,6 +374,10 @@ extern "C" {
         bool swa_full;     // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                            // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                            //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+
+        bool attn_streams; // if enabled, use multiple streams during the attention (determined by n_seq_max)
+                           // NOTE: this requires support for the ggml_set_rows() operator
+                           //       this flag can improve the performance for parallel, multi-sequence use cases
     };

     // model quantization parameters
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ebec21faf..2fbdf0bc9 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -33,9 +33,6 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }

-    const char * LLAMA_HT = getenv("LLAMA_HT");
-    cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;
-
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -104,7 +101,8 @@ llama_context::llama_context(
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

-    cparams.op_offload = params.op_offload;
+    cparams.op_offload   = params.op_offload;
+    cparams.attn_streams = params.attn_streams;

     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -115,6 +113,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: attn_streams = %s\n", __func__, cparams.attn_streams ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -270,7 +269,7 @@ llama_context::llama_context(
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -314,6 +313,10 @@ llama_context::llama_context(
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -478,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }

-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -2192,6 +2195,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf      =*/ true,
         /*.op_offload   =*/ true,
         /*.swa_full     =*/ true,
+        /*.attn_streams =*/ false,
     };

     return result;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 38750affc..18cb7786c 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -33,7 +33,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
-    bool kv_unified;
+    bool attn_streams;

    enum llama_pooling_type pooling_type;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 373ae53f4..0e65d0a56 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1001,7 +1001,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

     const auto n_kv     = inp->mctx->get_attn()->get_n_kv();
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

     inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
@@ -1212,7 +1212,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

     const auto n_kv     = mctx_cur->get_n_kv();
-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1459,7 +1459,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

-    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const auto n_stream = cparams.attn_streams ? ubatch.n_seqs_unq : 1;

     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index c5eb0fb85..3874777c8 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -317,14 +317,23 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         // TODO: do we need synchronization here?
     }

-    // TODO: support this:
-    GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
-
     v_cells[s1].reset();

     for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
         if (v_cells[s0].seq_has(i, seq_id_src)) {
-            v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
+            llama_pos pos   = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
             v_cells[s1].seq_add(i, seq_id_dst);
+
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
         }
     }
@@ -1057,7 +1066,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends

-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");

     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
@@ -1101,7 +1110,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends

-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");
+    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");

     ggml_tensor * v_view = nullptr;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c1ea3e9eb..e143a7d8f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -14712,7 +14712,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             uint32_t n_ctx_per_stream = cparams.n_ctx;

-            if (!cparams.kv_unified) {
+            if (cparams.attn_streams) {
                 n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
                 n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
@@ -14735,7 +14735,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
                     params.swa_full,
-                    cparams.kv_unified,
+                    !cparams.attn_streams,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     cparams.n_ubatch,
@@ -14750,7 +14750,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     params.type_v,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
-                    cparams.kv_unified,
+                    !cparams.attn_streams,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     padding,
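For reference, a minimal usage sketch (illustrative, not part of the patch) of how the new option would be enabled through the public API once this diff is applied. Only the attn_streams field and the --attn-streams / LLAMA_ARG_ATTN_STREAMS flag come from this diff; the loading and teardown calls (llama_model_default_params, llama_model_load_from_file, llama_init_from_model, llama_free, llama_model_free) are the existing llama.h entry points, and the model path is a placeholder.

// usage sketch (not part of the patch): enable multi-stream attention for a
// parallel, multi-sequence context; one stream per sequence, as determined by n_seq_max
#include "llama.h"

int main(void) {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max    = 4;    // number of parallel sequences -> number of attention/KV streams
    cparams.attn_streams = true; // new field added by this diff (default: false)
                                 // NOTE: requires backend support for ggml_set_rows()

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... decode batches across the 4 sequences ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}

On the command line the same option is exposed as --attn-streams (or -as) and via the LLAMA_ARG_ATTN_STREAMS environment variable, as added in common/arg.cpp above.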