diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 5744e00a8..f6a553239 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -130,6 +130,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
@@ -751,15 +758,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
-    if (kv_idxs) {
-        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (kv_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }
 
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
@@ -769,37 +782,43 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+    if (kv_idxs && supports_set_rows) {
+        if (!v_trans) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        // the row becomes a single element and we repeat the KV indices d_head times
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+        v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+        // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
+        kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
 
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
-        if (kv_idxs) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
-        }
-
         v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
     } else {
         v_cur = ggml_transpose(ctx, v_cur);
 
-        // note: the V cache is transposed when not using flash attention
-        if (kv_idxs) {
-            // the row becomes a single element and we repeat the KV indices d_head times
-            // TODO: this seems not very optimal - can we do something better?
-            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
-
-            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
-
-            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
-
-            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
-        }
-
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
     }
@@ -808,6 +827,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 5b5da3dc2..6f8efeb79 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -158,8 +158,13 @@ private:
     // SWA
     const uint32_t n_swa = 0;
 
+    // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    int supports_set_rows = false;
+
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr> ctxs;
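Note on the transposed V-cache path above: when v_trans is set, the cache is stored element-major (each embedding component occupies a contiguous run of kv_size cells), so one token's values are scattered across many rows rather than written as a single contiguous row. The patch handles this by viewing the cache as rows of one element and repeating the KV indices once per component, which turns the scatter into a plain ggml_set_rows(). Below is a minimal standalone sketch of that index expansion in plain C++ (illustrative sizes and names, not ggml code):

#include <cstdio>
#include <vector>

// Standalone model of the transposed V-cache write. The cache buffer holds
// d_head rows of kv_size elements each, i.e. component e of cache position p
// lives at v[e*kv_size + p].
int main() {
    const int kv_size = 8; // illustrative cache size
    const int d_head  = 4; // illustrative n_embd_v_gqa

    std::vector<float> v(d_head * kv_size, 0.0f);

    // two new tokens to be stored at cache positions 5 and 6
    const std::vector<int> kv_idxs = {5, 6};
    const int n_tokens = (int) kv_idxs.size();

    // v_cur after ggml_transpose() + ggml_cont_3d(): n_tokens*d_head rows of
    // one element each, ordered with all tokens of component 0 first, then
    // component 1, and so on
    std::vector<float> v_cur(n_tokens * d_head);
    for (int e = 0; e < d_head; ++e) {
        for (int t = 0; t < n_tokens; ++t) {
            v_cur[e*n_tokens + t] = 100.0f*e + float(t); // dummy value
        }
    }

    // set-rows on the 1-element-row view: source row (e, t) goes to
    // destination row e*kv_size + kv_idxs[t], which is why the KV indices
    // are repeated once per component in the patch
    for (int e = 0; e < d_head; ++e) {
        for (int t = 0; t < n_tokens; ++t) {
            v[e*kv_size + kv_idxs[t]] = v_cur[e*n_tokens + t];
        }
    }

    // each component row now has the new tokens at columns 5 and 6
    for (int e = 0; e < d_head; ++e) {
        for (int p = 0; p < kv_size; ++p) {
            printf("%6.1f", v[e*kv_size + p]);
        }
        printf("\n");
    }
    return 0;
}

With this patch, the ggml_set_rows() path is only taken when the process is started with LLAMA_SET_ROWS=1; otherwise the constructor prints the warning shown above, the old ggml_cpy() path is used, and set_input_kv_idxs() returns early since the indices tensor is not consumed.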