Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-28 20:25:20 +00:00)
kv-cache : rework kv_idxs, support seq_cp
ggml-ci
@@ -290,11 +290,9 @@ int main(int argc, char ** argv) {
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
 
-                if (is_sp_shared) {
-                    // but keep the system prompt
-                    llama_memory_seq_cp(mem, 0, i, -1, -1);
-                }
+                // but keep the system prompt
+                llama_memory_seq_cp(mem, 0, i, -1, -1);
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
         }
@@ -453,10 +451,7 @@ int main(int argc, char ** argv) {
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
                 llama_memory_seq_rm(mem, client.id + 1, -1, -1);
-
-                if (is_sp_shared) {
-                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
-                }
+                llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
 
                 const auto t_main_end = ggml_time_us();
 
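For reference, here is a minimal usage sketch (not part of the commit) of the calling pattern that the two hunks above simplify: with seq_cp now implemented for split KV buffers, the example can always re-link the shared system prompt from sequence 0 instead of guarding on is_sp_shared. The llama_memory_t handle is assumed to come from the usual setup (for example via llama_get_memory on the context), and the helper name is invented for illustration.

```cpp
#include "llama.h"

// hypothetical helper mirroring the parallel example: wipe the per-client
// sequences 1..n_clients and re-attach the shared system prompt from seq 0
static void reset_clients(llama_memory_t mem, int n_clients) {
    for (int i = 1; i <= n_clients; ++i) {
        // drop everything that was decoded for this client ...
        llama_memory_seq_rm(mem, i, -1, -1);

        // ... but keep the system prompt by copying it back from sequence 0
        llama_memory_seq_cp(mem, 0, i, -1, -1);
    }
}
```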
@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
@@ -1209,8 +1221,8 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     const auto n_kv = mctx_cur->get_n_kv();
     const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     ggml_set_input(inp->self_kq_mask);
@@ -1243,10 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1299,10 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1444,8 +1458,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
@@ -1458,8 +1472,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
@@ -248,11 +248,13 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    // TODO: should this be I64?
-    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
@@ -277,13 +279,18 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
-    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
-    ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
@@ -40,7 +40,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_seq_virt)*hparams.n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -117,8 +117,17 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
 
+        std::vector<ggml_tensor *> k_seq;
+        std::vector<ggml_tensor *> v_seq;
+
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+            v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+        }
+
         map_layer_ids[il] = layers.size();
-        layers.push_back({ il, k, v });
+
+        layers.push_back({ il, k, v, k_seq, v_seq, });
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
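To make the new k_seq/v_seq members above easier to follow, here is a small standalone sketch (not from the patch) of taking one 2D view per virtual sequence out of a single 3D buffer with ggml_view_2d, using the same stride/offset arguments (nb[1] for the row stride, s*nb[2] for the per-sequence offset). All sizes and names are invented for the example.

```cpp
#include "ggml.h"

#include <cstdio>

int main() {
    const int64_t n_embd  = 4; // embedding width per KV row (assumed)
    const int64_t kv_size = 8; // cells per virtual sequence (assumed)
    const int64_t n_seq   = 2; // number of virtual sequences (assumed)

    ggml_init_params params = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // one buffer holding all virtual sequences: [n_embd, kv_size, n_seq]
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, kv_size, n_seq);

    // per-sequence views: same row stride (k->nb[1]), offset by whole planes (s*k->nb[2])
    for (int64_t s = 0; s < n_seq; ++s) {
        ggml_tensor * k_s = ggml_view_2d(ctx, k, n_embd, kv_size, k->nb[1], s*k->nb[2]);

        printf("seq %d: view starts %zu bytes into the shared buffer\n",
                (int) s, (size_t) ((char *) k_s->data - (char *) k->data));
    }

    ggml_free(ctx);

    return 0;
}
```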
@@ -262,9 +271,35 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         is_full = false;
     }
 
-    GGML_ASSERT(is_full && "seq_cp() is only supported for full contexts");
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
 
-    GGML_ABORT("TODO: implement\n");
+    //LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
+
+    for (uint32_t il = 0; il < layers.size(); ++il) {
+        const auto & layer = layers[il];
+
+        ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
+        ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);
+
+        // TODO: do we need synchronization here?
+    }
+
+    // TODO: support this:
+    GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
+
+    v_cells[s1].reset();
+    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+        if (v_cells[s0].seq_has(i, seq_id_src)) {
+            v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
+            v_cells[s1].seq_add(i, seq_id_dst);
+        }
+    }
+
+    v_heads[s1] = v_heads[s0];
+
+    //for (uint32_t s = 0; s < n_seq_virt; ++s) {
+    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+    //}
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
@@ -929,7 +964,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
 
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-    const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
+    const uint64_t kv_size = get_size();
 
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
@@ -937,20 +972,20 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                 ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                size_virt, // v->nb[3]
-                size_virt*sinfo.s0);
+                ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+                ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)*sinfo.s0));
     }
 
     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, v->ne[1]*n_seq_virt*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]*n_seq_virt), // v->nb[2]
-            ggml_row_size(v->type, v->ne[1]), // v->nb[3]
-            ggml_row_size(v->type, v->ne[1]*sinfo.s0));
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size), // v->nb[2]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+            ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)*sinfo.s0));
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -960,10 +995,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (k_idxs && supports_set_rows) {
         k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
 
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -978,7 +1013,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -988,30 +1023,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
             v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
 
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
-
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
-
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v [1, n_kv*n_seq_virt, n_embd_v_gqa]
-        // v_cur [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -1036,7 +1060,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1056,6 +1107,58 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     }
 }
 
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    if (!v_trans) {
+        for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+            const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            }
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+        for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+            const int64_t offs = sinfo.seq_id_virt[s]*kv_size*n_embd_v_gqa;
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
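As a sanity check on the index arithmetic in set_input_v_idxs() above, here is a tiny standalone sketch (not from the patch) that recomputes one destination index for the transposed V layout. The sizes are invented, and the assumed element order [kv_size, n_embd_v_gqa, n_seq_virt] (cell fastest, then embedding row, then virtual sequence) is inferred from the surrounding code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t kv_size = 8; // cells per virtual sequence (assumed)
    const int64_t n_embd  = 4; // n_embd_v_gqa (assumed)

    const int64_t s    = 1; // virtual sequence
    const int64_t cell = 5; // destination cell within that sequence
    const int64_t j    = 2; // element within the embedding row

    // per-sequence block offset, as computed in set_input_v_idxs()
    const int64_t offs = s*kv_size*n_embd;

    // the entry written into the v_idxs tensor for this (cell, j) pair
    const int64_t idx = offs + j*kv_size + cell;

    // the same location computed directly from the assumed [kv_size, n_embd, n_seq] layout
    const int64_t expected = (s*n_embd + j)*kv_size + cell;

    assert(idx == expected);

    printf("v_idxs entry = %lld\n", (long long) idx);

    return 0;
}
```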
@@ -1137,20 +1240,6 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t s = 0; s < n_seq_virt; ++s) {
-        const auto & cells = v_cells[s];
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-        }
-    }
-}
-
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
@@ -2112,22 +2201,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
+}
+
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
-}
-
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
@@ -143,8 +143,8 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
@@ -165,12 +165,18 @@ public:
     void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
-    // set_input API
+    // input API
     //
 
-    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -184,6 +190,9 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_seq;
+        std::vector<ggml_tensor *> v_seq;
     };
 
     bool v_trans = true; // the value tensor is transposed
@@ -309,12 +318,17 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
 
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
     void set_input_k_shift(ggml_tensor * dst) const;
 
-    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 