llama : add "virtual sequences"

ggml-ci
Georgi Gerganov
2025-06-23 16:29:02 +03:00
parent 36f8e20d08
commit 52b9007176
15 changed files with 504 additions and 216 deletions
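In short: with "virtual sequences" enabled, each sequence id gets its own ring of KV cells and its own search head, instead of all sequences sharing one cell array and one head. A minimal sketch of the bookkeeping this commit introduces (illustrative only; the names n_seq_virt, seq_virt_idx, v_heads and v_cells mirror the diff below, everything else is simplified):

    #include <cstdint>
    #include <vector>

    // Simplified picture of the new per-virtual-sequence bookkeeping.
    struct kv_cache_sketch {
        uint32_t              n_seq_virt;   // 1 = old behaviour (shared cells), n_seq_max = one ring per sequence
        std::vector<uint32_t> v_heads;      // find_slot() search head, one per virtual sequence
        std::vector<uint32_t> seq_virt_idx; // maps a sequence id to its virtual sequence index

        kv_cache_sketch(uint32_t n_seq_max, uint32_t n_virt) : n_seq_virt(n_virt) {
            v_heads.assign(n_seq_virt, 0);
            seq_virt_idx.assign(n_seq_max, 0);      // by default every seq_id maps to virtual sequence 0
            if (n_seq_virt > 1) {
                for (uint32_t s = 0; s < n_seq_virt; ++s) {
                    seq_virt_idx[s] = s;            // identity mapping: one virtual sequence per seq_id
                }
            }
        }
    };

As the llama-context hunk below shows, the feature is currently gated behind the LLAMA_HT environment variable, which sets n_seq_virt = n_seq_max.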

View File

@ -230,6 +230,7 @@ typedef struct {
uint64_t nb22;
uint64_t nb23;
uint64_t nb31;
uint64_t nb32;
int32_t ne1;
int32_t ne2;
float scale;

View File

@ -4882,6 +4882,7 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,

View File

@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + iq3*args.nb32);
const float m = pm[ic + tiisg];
@ -4131,7 +4131,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31);
device const half * pm = (device const half *) (mask + iq1*args.nb31 + iq3*args.nb32);
float slope = 1.0f;
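The new nb32 stride selects the mask slice of the current (virtual) sequence, indexed by iq3. A hedged C++ restatement of the same addressing outside of Metal (pm points at a row of n_kv half values; nb31/nb32 are the row/slice strides in bytes):

    #include <cstdint>

    // The mask is now addressed as a 3D tensor [n_kv, n_q, n_seq]:
    // row (iq1 + j) of slice iq3 starts iq3*nb32 + (iq1 + j)*nb31 bytes into the buffer.
    static inline const uint16_t * mask_row(const void * mask, int64_t iq1, int64_t j, int64_t iq3,
                                            uint64_t nb31, uint64_t nb32) {
        return (const uint16_t *) ((const char *) mask + (iq1 + j)*nb31 + iq3*nb32); // fp16 stored as uint16_t here
    }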

View File

@ -3526,7 +3526,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
@ -4504,7 +4504,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");

View File

@ -460,6 +460,8 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@ -476,9 +478,14 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}
// accept only increasing sequence ids
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}
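The equal split produces a ubatch whose tokens are grouped into contiguous, equally sized per-sequence chunks; the rest of the patch relies on that layout. A tiny helper stating the invariant (illustration, not llama.cpp code):

    #include <cstdint>

    // After split_equal(), token ii of (virtual) sequence chunk s sits at this ubatch index,
    // with every chunk holding the same n_tokens_per_seq tokens.
    static inline uint32_t ubatch_token_index(uint32_t s, uint32_t ii, uint32_t n_tokens_per_seq) {
        return s*n_tokens_per_seq + ii;
    }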

View File

@ -33,6 +33,9 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
const char * LLAMA_HT = getenv("LLAMA_HT");
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;

View File

@ -11,8 +11,9 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
uint32_t n_seq_virt;
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
float rope_freq_base;
float rope_freq_scale;

View File

@ -1031,6 +1031,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@ -1079,7 +1083,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
#endif
}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
} else {
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@ -1124,7 +1128,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
@ -1203,11 +1207,12 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
ggml_set_input(inp->self_kv_idxs);
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
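For the multi-sequence case the graph splits the flattened batch back into per-sequence slices before the attention call, matching the 3D KQ mask created above. A shape-only sketch of the Q reshape (assumes the tensor layout from the diff; not the exact graph code):

    #include "ggml.h"

    // q: [n_embd_head, n_head, n_tokens] with n_tokens == n_tokens_per_seq * n_seqs
    // -> [n_embd_head, n_head, n_tokens_per_seq, n_seqs]
    // -> [n_embd_head, n_tokens_per_seq, n_head, n_seqs] after the permute
    static ggml_tensor * split_q_per_seq(ggml_context * ctx0, ggml_tensor * q, int64_t n_seqs) {
        q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
        return ggml_permute(ctx0, q, 0, 2, 1, 3);
    }

The attention output is then flattened back to 2D with n_tokens*n_seqs rows, as the two cur-reshape hunks above show.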

View File

@ -254,8 +254,8 @@ public:
// TODO: should this be I64?
ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -285,10 +285,10 @@ public:
ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
const llama_hparams & hparams;
const llama_cparams & cparams;

View File

@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad,
v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad,
v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
hparams.n_swa, hparams.swa_type);
}
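Since each virtual sequence now owns a full ring of cells, the SWA cache only has to hold the window of n_seq_max/n_seq_virt sequences per ring. A standalone restatement of the size formula with illustrative numbers (not taken from the diff):

    #include <algorithm>
    #include <cstdint>

    // same rounding as GGML_PAD for a power-of-two n_pad
    static inline uint32_t pad_to(uint32_t x, uint32_t n) { return (x + n - 1) / n * n; }

    // size_swa = min(size_base, PAD(n_swa * (n_seq_max / n_seq_virt) + n_ubatch, n_pad))
    static uint32_t swa_cache_size(uint32_t size_base, uint32_t n_swa, uint32_t n_seq_max,
                                   uint32_t n_seq_virt, uint32_t n_ubatch, uint32_t n_pad) {
        return std::min(size_base, pad_to(n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
    }
    // e.g. n_swa = 4096, n_seq_max = n_seq_virt = 4, n_ubatch = 512, n_pad = 256:
    //      pad_to(4096*1 + 512, 256) = 4608 cells per virtual sequence (capped at size_base)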
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split
do {
if (n_seq_virt > 1) {
// requires equal splits
break;
}
balloc.split_reset();
std::vector<llama_ubatch> ubatches;

View File

@ -22,6 +22,7 @@ public:
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch,
uint32_t n_pad);
@ -68,6 +69,8 @@ public:
private:
const llama_hparams & hparams;
const uint32_t n_seq_virt = 1;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

View File

@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
bool offload,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % n_pad == 0);
@ -58,9 +59,27 @@ llama_kv_cache_unified::llama_kv_cache_unified(
return it->second;
};
head = 0;
GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
cells.resize(kv_size);
v_heads.resize(n_seq_virt);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_heads[s] = 0;
}
v_cells.resize(n_seq_virt);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_cells[s].resize(kv_size);
}
// by default, all sequence ids are mapped to the 0th virtual sequence
seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
if (n_seq_virt > 1) {
seq_virt_idx.resize(n_seq_virt, 0);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
seq_virt_idx[s] = s;
}
}
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
@ -92,8 +111,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
ggml_tensor * k;
ggml_tensor * v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);
ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il);
@ -122,8 +141,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
@ -140,9 +159,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
}
void llama_kv_cache_unified::clear(bool data) {
cells.reset();
head = 0;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_cells[s].reset();
v_heads[s] = 0;
}
if (data) {
for (auto & buf : bufs) {
@ -152,6 +172,9 @@ void llama_kv_cache_unified::clear(bool data) {
}
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
uint32_t new_head = cells.size();
if (p0 < 0) {
@ -198,30 +221,56 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
}
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (seq_id_src == seq_id_dst) {
const auto s0 = seq_virt_idx[seq_id_src];
const auto s1 = seq_virt_idx[seq_id_dst];
if (s0 == s1) {
auto & cells = v_cells[s0];
if (seq_id_src == seq_id_dst) {
return;
}
if (p0 < 0) {
p0 = 0;
}
if (p1 < 0) {
p1 = std::numeric_limits<llama_pos>::max();
}
for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) {
continue;
}
if (cells.seq_has(i, seq_id_src)) {
cells.seq_add(i, seq_id_dst);
}
}
return;
}
if (p0 < 0) {
p0 = 0;
bool is_full = true;
if (p0 > 0 && p0 + 1 < (int) get_size()) {
is_full = false;
}
if (p1 < 0) {
p1 = std::numeric_limits<llama_pos>::max();
if (p1 > 0 && p1 + 1 < (int) get_size()) {
is_full = false;
}
for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) {
continue;
}
GGML_ASSERT(is_full && "seq_cp() is only supported for full contexts");
if (cells.seq_has(i, seq_id_src)) {
cells.seq_add(i, seq_id_dst);
}
}
GGML_ABORT("TODO: implement\n");
}
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
uint32_t new_head = cells.size();
for (uint32_t i = 0; i < cells.size(); ++i) {
@ -239,6 +288,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
if (shift == 0) {
return;
}
@ -278,6 +330,8 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
}
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
if (d == 1) {
return;
}
@ -307,10 +361,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
const auto & cells = v_cells[seq_virt_idx[seq_id]];
return cells.seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
const auto & cells = v_cells[seq_virt_idx[seq_id]];
return cells.seq_pos_max(seq_id);
}
@ -325,7 +383,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_simple(n_ubatch);
auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
@ -356,7 +414,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
defrag_info dinfo;
// see if we need to defrag
{
if (n_seq_virt == 1) {
// note : for now do not consider defrag for n_seq_virt > 1
const auto & cells = v_cells[seq_virt_idx[0]];
bool do_defrag = optimize;
const auto thold = lctx->get_cparams().defrag_thold;
@ -386,16 +447,16 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::slot_info_vec_t res;
struct state {
uint32_t head_old; // old position of the head, before placing the ubatch
struct state_t {
slot_info sinfo; // slot info for the ubatch
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
};
// remember the old state of the cells so we can restore it in the end
std::vector<state> states;
std::vector<state_t> states;
bool success = true;
@ -414,16 +475,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
res.push_back(sinfo_new);
// store the old state of the cells in the recovery stack
states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
{
state_t state = { sinfo_new, v_heads, {} };
for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
auto & cells = v_cells[sinfo_new.seq_id_virt[s]];
state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
}
states.push_back(std::move(state));
}
// now emplace the ubatch
apply_ubatch(sinfo_new, ubatch);
}
GGML_ASSERT(!states.empty());
// iterate backwards and restore the cells to their original state
for (auto it = states.rbegin(); it != states.rend(); ++it) {
cells.set(it->sinfo.idxs, it->cells);
head = it->head_old;
const auto & sinfo = it->sinfo;
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
auto & cells = v_cells[sinfo.seq_id_virt[s]];
auto & head = v_heads[sinfo.seq_id_virt[s]];
cells.set(sinfo.idxs[s], it->v_cells[s]);
head = it->v_heads_old[s];
}
}
if (!success) {
@ -472,12 +552,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
updated = true;
}
cells.reset_shift();
for (uint32_t s = 0; s < n_seq_virt; ++s) {
auto & cells = v_cells[s];
cells.reset_shift();
}
}
if (!dinfo.empty()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
// note: for now do not consider defrag for n_seq_virt > 1
auto & cells = v_cells[seq_virt_idx[0]];
auto & head = v_heads[seq_virt_idx[0]];
// apply moves:
{
const auto n_kv = dinfo.ids.size();
@ -525,23 +613,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
}
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
head_cur = 0;
}
if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return { };
}
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
const auto & cells = v_cells[seq_virt_idx[1]];
const uint32_t head_cur = v_heads[1];
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
if ((debug == 2 && n_swa > 0) || debug > 2) {
std::string ss;
@ -598,86 +676,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
}
}
uint32_t n_found = 0;
uint32_t n_tested = 0;
uint32_t n_tokens = ubatch.n_tokens;
uint32_t n_seqs = 1;
const uint32_t n_test = cont ? n_tokens : 1;
if (n_seq_virt > 1) {
GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
slot_info res;
n_seqs = ubatch.n_seqs_unq;
n_tokens = n_tokens / n_seqs;
}
res.idxs.resize(n_tokens);
slot_info res = {
/*.s0 =*/ LLAMA_MAX_SEQ,
/*.s1 =*/ 0,
/*.seq_id_virt =*/ { },
/*.idxs =*/ { },
};
while (true) {
if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur;
head_cur = 0;
continue;
res.resize(n_seqs);
for (uint32_t s = 0; s < n_seqs; ++s) {
const auto seq_id = ubatch.seq_id_unq[s];
if (n_seq_virt > 1) {
GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
}
for (uint32_t i = 0; i < n_test; i++) {
const auto idx = head_cur;
res.s0 = std::min<llama_seq_id>(res.s0, seq_virt_idx[seq_id]);
res.s1 = std::max<llama_seq_id>(res.s1, seq_virt_idx[seq_id]);
//const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
res.seq_id_virt[s] = seq_virt_idx[seq_id];
res.idxs[s].resize(n_tokens);
// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(idx);
const auto & cells = v_cells[seq_virt_idx[seq_id]];
if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(idx);
uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];
// (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos;
//}
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head_cur > cells.get_used() + 2*n_tokens) {
head_cur = 0;
}
if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(idx);
if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return { };
}
// SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
uint32_t n_found = 0;
uint32_t n_tested = 0;
const uint32_t n_test = cont ? n_tokens : 1;
while (true) {
if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur;
head_cur = 0;
continue;
}
for (uint32_t i = 0; i < n_test; i++) {
const auto idx = head_cur;
head_cur++;
n_tested++;
//const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(idx);
if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(idx);
// (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos;
//}
if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
}
if (can_use) {
res.idxs[s][n_found] = idx;
n_found++;
} else {
if (cont) {
break;
}
}
}
head_cur++;
n_tested++;
if (can_use) {
res.idxs[n_found] = idx;
n_found++;
} else {
if (n_found == n_tokens) {
break;
}
if (cont) {
n_found = 0;
}
if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return { };
}
}
if (n_found == n_tokens) {
break;
}
if (cont) {
n_found = 0;
}
if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
// we didn't find a suitable slot - return empty result
if (n_found < n_tokens) {
return { };
}
}
// we didn't find a suitable slot - return empty result
if (n_found < n_tokens) {
res.clear();
}
return res;
}
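find_slot() now runs one search per unique sequence in the ubatch, each against that sequence's own cells and head. A simplified, non-contiguous (cont == false) variant of the inner scan, where "usable" stands for the empty-or-SWA-masked check above (illustration only):

    #include <cstdint>
    #include <vector>

    // Scan one virtual sequence's ring of cells starting at its own head and collect the
    // first n_tokens usable cell indices. Returns an empty vector if there is not enough room.
    static std::vector<uint32_t> find_slot_for_seq(const std::vector<bool> & cell_usable,
                                                   uint32_t head, uint32_t n_tokens) {
        const uint32_t size = (uint32_t) cell_usable.size();
        std::vector<uint32_t> idxs;
        for (uint32_t tested = 0; tested < size && idxs.size() < n_tokens; ++tested) {
            const uint32_t idx = (head + tested) % size;
            if (cell_usable[idx]) {
                idxs.push_back(idx);
            }
        }
        if (idxs.size() < n_tokens) {
            idxs.clear();
        }
        return idxs;
    }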
@ -685,41 +810,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos_max_rm[s] = -1;
}
assert(ubatch.n_tokens == sinfo.idxs.size());
assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto idx = sinfo.idxs[i];
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
const uint32_t i = s*sinfo.size() + ii;
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
auto & cells = v_cells[sinfo.seq_id_virt[s]];
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
const auto idx = sinfo.idxs[s][ii];
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
cells.rm(idx);
}
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
cells.pos_set(idx, ubatch.pos[i]);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
cells.rm(idx);
}
cells.pos_set(idx, ubatch.pos[i]);
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
}
}
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos_max_rm[s] == -1) {
continue;
}
GGML_ASSERT(s < seq_virt_idx.size());
auto & cells = v_cells[seq_virt_idx[s]];
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@ -729,7 +864,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
}
// move the head at the end of the slot
head = sinfo.idxs.back() + 1;
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
auto & head = v_heads[sinfo.seq_id_virt[s]];
head = sinfo.idxs[s].back() + 1;
}
}
bool llama_kv_cache_unified::get_can_shift() const {
@ -737,49 +876,82 @@ bool llama_kv_cache_unified::get_can_shift() const {
}
uint32_t llama_kv_cache_unified::get_size() const {
const auto & cells = v_cells[seq_virt_idx[0]];
return cells.size();
}
bool llama_kv_cache_unified::get_has_shift() const {
return cells.get_has_shift();
bool result = false;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
result |= v_cells[s].get_has_shift();
}
return result;
}
uint32_t llama_kv_cache_unified::get_n_kv() const {
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
uint32_t result = 0;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
const auto & cells = v_cells[s];
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
}
return result;
}
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
return ggml_view_3d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
const auto ns = sinfo.s1 - sinfo.s0 + 1;
assert(ns > 0);
assert(ns <= (int) n_seq_virt);
const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k),
ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
0);
size_virt,
size_virt*sinfo.s0);
}
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v;
const auto ns = sinfo.s1 - sinfo.s0 + 1;
assert(ns > 0);
assert(ns <= n_seq_virt);
const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
if (!v_trans) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_3d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
0);
size_virt,
size_virt*sinfo.s0);
}
// note: v->nb[1] > v->nb[2]
return ggml_view_3d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
0);
size_virt,
size_virt*sinfo.s0);
}
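Each layer's K/V tensor now has an extra n_seq_virt dimension, and get_k()/get_v() return a 4D view covering only the virtual sequences touched by the slot info (planes s0..s1), with the per-sequence byte size as the outer stride. A sketch mirroring the K case (assumes k is laid out as [n_embd_k_gqa, kv_size, n_seq_virt], as in this commit):

    #include "ggml.h"

    // View ns = s1 - s0 + 1 consecutive virtual sequence planes of the K cache, starting at plane s0.
    static ggml_tensor * view_k_seqs(ggml_context * ctx, ggml_tensor * k,
                                     int64_t n_embd_head_k, int64_t n_head_kv, int64_t n_kv,
                                     int64_t s0, int64_t ns) {
        const size_t size_virt = ggml_row_size(k->type, k->ne[0]*k->ne[1]); // bytes per sequence plane
        return ggml_view_4d(ctx, k,
                n_embd_head_k, n_head_kv, n_kv, ns,
                ggml_row_size(k->type, n_embd_head_k),
                ggml_row_size(k->type, k->ne[0]),
                size_virt,
                size_virt*s0);
    }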
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
@ -793,12 +965,16 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
if (kv_idxs && supports_set_rows) {
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
return ggml_set_rows(ctx, k, k_cur, kv_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@ -818,11 +994,13 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
if (kv_idxs && supports_set_rows) {
if (!v_trans) {
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
}
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
// note: the V cache is transposed when not using flash attention
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
@ -842,6 +1020,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
ggml_tensor * v_view = nullptr;
if (!v_trans) {
@ -865,12 +1045,17 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs[i];
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
const int64_t offs = sinfo.seq_id_virt[s]*get_size();
for (uint32_t i = 0; i < sinfo.size(); ++i) {
data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
}
}
}
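Because the K/V tensors are reshaped to 2D before ggml_set_rows(), the destination row index of a token has to include the offset of its virtual sequence plane. A worked restatement of the index computation (illustration):

    #include <cstdint>

    // Each sequence plane holds kv_size rows, so the global destination row of a token is
    // its plane offset plus its cell index within that plane.
    static inline int64_t kv_row_index(uint32_t seq_id_virt, uint32_t kv_size, uint32_t cell_idx) {
        return (int64_t) seq_id_virt*kv_size + cell_idx;
    }
    // e.g. kv_size = 8192, virtual sequence 2, cell 17 -> row 16401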
@ -880,7 +1065,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
const int64_t n_kv = dst->ne[0];
const int64_t n_kv = dst->ne[0];
const int64_t n_seq_virt = dst->ne[2]; // num virtual sequences in the current ubatch
GGML_ASSERT(n_tokens%n_seq_virt == 0);
const int64_t n_tokens_per_seq = n_tokens/n_seq_virt;
const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@ -895,48 +1086,54 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[i][0];
for (uint32_t s = 0; s < n_seq_virt; ++s) {
for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
const uint32_t i = s*n_tokens_per_seq + ii;
const llama_pos p1 = ubatch->pos[i];
const llama_seq_id seq_id = ubatch->seq_id[i][0];
for (uint32_t j = 0; j < n_kv; ++j) {
float f = 0.0f;
const auto & cells = v_cells[seq_virt_idx[seq_id]];
bool masked = false;
const llama_pos p1 = ubatch->pos[i];
if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);
// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));
// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
}
if (masked) {
f = -INFINITY;
}
data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
}
}
// mask padded tokens
if (data) {
for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
float f = 0.0f;
bool masked = false;
if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);
// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));
// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
}
if (masked) {
f = -INFINITY;
}
data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
}
// mask padded tokens
if (data) {
for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
}
}
}
}
}
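The KQ mask buffer is now indexed as [n_seq_virt][n_tokens_per_seq_pad][n_kv] (with the leading h dimension kept at 1, as before). A standalone restatement of the flattened offset used in the writes above:

    #include <cstdint>

    // Flattened offset of mask element (h, s, ii, j) in a buffer of shape
    // [h][n_seq_virt][n_tokens_per_seq_pad][n_kv].
    static inline int64_t kq_mask_index(int64_t h, int64_t s, int64_t ii, int64_t j,
                                        int64_t n_seq_virt, int64_t n_tokens_per_seq_pad, int64_t n_kv) {
        return h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j;
    }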
@ -948,14 +1145,21 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
int32_t * data = (int32_t *) dst->data;
for (uint32_t i = 0; i < cells.size(); ++i) {
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
const auto & cells = v_cells[s];
for (uint32_t i = 0; i < cells.size(); ++i) {
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
}
}
}
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
const auto & cells = v_cells[0];
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
@ -1062,7 +1266,7 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * k_shift; // I32 [kv_size]
ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]
const llama_kv_cache_unified * kv_self;
};
@ -1086,7 +1290,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
ggml_set_input(inp->k_shift);
for (const auto & layer : layers) {
@ -1102,7 +1306,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
n_embd_head_k, n_head_kv, cells.size(),
n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
0);
@ -1124,6 +1328,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
const defrag_info & dinfo) const {
auto res = std::make_unique<llm_graph_result>();
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
const auto & cells = v_cells[0];
const auto & ids = dinfo.ids;
#if 0
@ -1266,6 +1474,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
}
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
const auto & cells = v_cells[0];
const uint32_t n_layer = layers.size();
const uint32_t n_kv = cells.used_max_p1();
@ -1414,6 +1626,10 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
uint32_t cell_range_begin = cells.size();
@ -1468,6 +1684,10 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
}
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
std::vector<llama_seq_id> seq_ids;
@ -1494,6 +1714,10 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
}
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];
const uint32_t v_trans = this->v_trans ? 1 : 0;
const uint32_t n_layer = layers.size();
@ -1581,6 +1805,11 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
}
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
auto & cells = v_cells[0];
auto & head = v_heads[0];
if (dest_seq_id != -1) {
// single sequence
@ -1672,6 +1901,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
}
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
auto & cells = v_cells[0];
auto & head = v_heads[0];
uint32_t v_trans;
uint32_t n_layer;
@ -1809,8 +2043,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
n_kv = kv->get_size();
sinfos.resize(1);
sinfos[0].seq_id_virt.resize(1, 0);
sinfos[0].idxs.resize(1);
sinfos[0].idxs[0] = 0;
sinfos[0].idxs[0].resize(1, 0);
}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@ -1873,11 +2108,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
}
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv);
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv);
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {

View File

@ -39,10 +39,31 @@ public:
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs;
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> seq_id_virt;
std::vector<idx_vec_t> idxs;
uint32_t head() const {
return idxs[0];
GGML_ASSERT(idxs.size() == 1);
return idxs[0][0];
}
void resize(size_t n) {
seq_id_virt.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == seq_id_virt.size());
return idxs[0].size();
}
size_t n_seq_virt() const {
return seq_id_virt.size();
}
bool empty() const {
@ -52,9 +73,6 @@ public:
void clear() {
idxs.clear();
}
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
};
using slot_info_vec_t = std::vector<slot_info>;
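For reference, a slot info for a multi-sequence ubatch now carries one index list per virtual sequence chunk, plus the [s0, s1] plane range used to build the K/V views. A hypothetical example of a resolved slot info for 2 virtual sequences with 3 tokens each (simplified mirror of the struct above, not the real type; the values are made up):

    #include <cstdint>
    #include <vector>

    struct slot_info_sketch {
        int32_t s0, s1;                                 // range of virtual sequence planes covered
        std::vector<int32_t>               seq_id_virt; // virtual sequence of each chunk
        std::vector<std::vector<uint32_t>> idxs;        // cell indices, one list per chunk
    };

    // Tokens of virtual sequence 1 go to cells 40..42 of its ring,
    // tokens of virtual sequence 2 go to cells 7..9 of its own ring.
    static const slot_info_sketch example = {
        /*.s0          =*/ 1,
        /*.s1          =*/ 2,
        /*.seq_id_virt =*/ {1, 2},
        /*.idxs        =*/ {{40, 41, 42}, {7, 8, 9}},
    };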
@ -68,6 +86,7 @@ public:
bool offload,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
@ -120,8 +139,8 @@ public:
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
@ -169,11 +188,8 @@ private:
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
const uint32_t n_seq_max = 1;
const uint32_t n_seq_virt = 1;
// required padding
const uint32_t n_pad = 1;
@ -193,7 +209,14 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a virtual sequence id
std::vector<uint32_t> seq_virt_idx;
std::vector<kv_layer> layers;

View File

@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload,
kv_size,
n_seq_max,
1,
n_pad,
n_swa,
swa_type

View File

@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
params.swa_full,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_seq_virt,
cparams.n_ubatch,
padding);
} else {
@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
cparams.n_ctx,
cparams.n_seq_max,
cparams.n_seq_virt,
padding,
hparams.n_swa,
hparams.swa_type);