mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
kv-cache : fix split_equal handling in unified implementation (#14130)
ggml-ci
This commit is contained in:
@ -877,6 +877,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||||||
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
||||||
|
|
||||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||||
|
// TODO: the seuqence indexing here is likely not correct in the general case
|
||||||
|
// probably works only for split_simple
|
||||||
cross.seq_ids_enc.resize(n_tokens);
|
cross.seq_ids_enc.resize(n_tokens);
|
||||||
for (int32_t i = 0; i < n_tokens; i++) {
|
for (int32_t i = 0; i < n_tokens; i++) {
|
||||||
cross.seq_ids_enc[i].clear();
|
cross.seq_ids_enc[i].clear();
|
||||||
|
@ -98,33 +98,66 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
||||||
GGML_UNUSED(embd_pooled);
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
// TODO: if we fail with split_simple, we should attempt different splitting strategies
|
// first try simple split
|
||||||
|
do {
|
||||||
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
while (sbatch.n_tokens > 0) {
|
||||||
|
auto ubatch = sbatch.split_simple(n_ubatch);
|
||||||
|
|
||||||
|
ubatches.push_back(ubatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_base = kv_base->prepare(ubatches);
|
||||||
|
if (heads_base.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_swa = kv_swa->prepare(ubatches);
|
||||||
|
if (heads_swa.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||||
|
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
|
} while (false);
|
||||||
|
|
||||||
|
// if it fails, try equal split
|
||||||
|
do {
|
||||||
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
while (sbatch.n_tokens > 0) {
|
||||||
|
auto ubatch = sbatch.split_equal(n_ubatch);
|
||||||
|
|
||||||
|
ubatches.push_back(ubatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_base = kv_base->prepare(ubatches);
|
||||||
|
if (heads_base.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto heads_swa = kv_swa->prepare(ubatches);
|
||||||
|
if (heads_swa.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||||
|
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
|
} while (false);
|
||||||
|
|
||||||
|
// TODO: if we fail again, we should attempt different splitting strategies
|
||||||
// but to do that properly, we first have to refactor the batches to be more flexible
|
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||||
|
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
|
||||||
|
|
||||||
while (sbatch.n_tokens > 0) {
|
|
||||||
auto ubatch = sbatch.split_simple(n_ubatch);
|
|
||||||
|
|
||||||
ubatches.push_back(ubatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto heads_base = kv_base->prepare(ubatches);
|
|
||||||
if (heads_base.empty()) {
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto heads_swa = kv_swa->prepare(ubatches);
|
|
||||||
if (heads_swa.empty()) {
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(heads_base.size() == heads_swa.size());
|
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
|
||||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||||
|
@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
|||||||
bool logits_all) {
|
bool logits_all) {
|
||||||
GGML_UNUSED(embd_pooled);
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
do {
|
||||||
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
while (sbatch.n_tokens > 0) {
|
while (sbatch.n_tokens > 0) {
|
||||||
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto heads = prepare(ubatches);
|
auto heads = prepare(ubatches);
|
||||||
if (heads.empty()) {
|
if (heads.empty()) {
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(
|
return std::make_unique<llama_kv_cache_unified_state>(
|
||||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||||
|
} while (false);
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
||||||
@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (debug > 0) {
|
if (debug > 0) {
|
||||||
LLAMA_LOG_CONT("\n");
|
|
||||||
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
|
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
|
||||||
|
|
||||||
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
||||||
@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|||||||
if (cells.is_empty(i)) {
|
if (cells.is_empty(i)) {
|
||||||
ss += '.';
|
ss += '.';
|
||||||
} else {
|
} else {
|
||||||
ss += std::to_string(cells.seq_get(i));
|
assert(cells.seq_count(i) >= 1);
|
||||||
|
|
||||||
|
if (cells.seq_count(i) == 1) {
|
||||||
|
ss += std::to_string(cells.seq_get(i));
|
||||||
|
} else {
|
||||||
|
ss += 'M';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (i%256 == 255) {
|
if (i%256 == 255) {
|
||||||
ss += " *";
|
ss += " *";
|
||||||
@ -636,6 +645,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
||||||
|
if (debug > 0) {
|
||||||
|
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
|
||||||
|
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
|
||||||
|
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
|
||||||
|
}
|
||||||
|
|
||||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||||
// for non-SWA cache, this would be always empty
|
// for non-SWA cache, this would be always empty
|
||||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
|
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
|
||||||
@ -643,22 +658,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|||||||
seq_pos_max_rm[s] = -1;
|
seq_pos_max_rm[s] = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||||
if (!cells.is_empty(head_cur + i)) {
|
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
|
||||||
assert(cells.seq_count(head_cur + i) == 1);
|
const uint32_t idx = s*ubatch.n_seq_tokens + j;
|
||||||
|
|
||||||
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
|
if (!cells.is_empty(head_cur + idx)) {
|
||||||
const llama_pos pos = cells.pos_get(head_cur + i);
|
assert(cells.seq_count(head_cur + idx) == 1);
|
||||||
|
|
||||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
|
||||||
|
const llama_pos pos = cells.pos_get(head_cur + idx);
|
||||||
|
|
||||||
cells.rm(head_cur + i);
|
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||||
}
|
|
||||||
|
|
||||||
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
cells.rm(head_cur + idx);
|
||||||
|
}
|
||||||
|
|
||||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
cells.pos_set(head_cur + idx, ubatch.pos[idx]);
|
||||||
cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
|
|
||||||
|
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
|
||||||
|
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|||||||
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// move the head at the end of the slot
|
// move the head at the end of the slot
|
||||||
head = head_cur + ubatch.n_tokens;
|
head = head_cur + ubatch.n_tokens;
|
||||||
}
|
}
|
||||||
@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||||
const int64_t n_tokens = ubatch->n_tokens;
|
const uint32_t n_tokens = ubatch->n_tokens;
|
||||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||||
const int64_t n_seqs = ubatch->n_seqs;
|
const uint32_t n_seqs = ubatch->n_seqs;
|
||||||
|
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||||
float * data = (float *) dst->data;
|
float * data = (float *) dst->data;
|
||||||
|
|
||||||
const auto n_kv = dst->ne[0];
|
const int64_t n_kv = dst->ne[0];
|
||||||
|
|
||||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||||
@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||||||
// xxxxx-----
|
// xxxxx-----
|
||||||
// xxxxx-----
|
// xxxxx-----
|
||||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (uint32_t h = 0; h < 1; ++h) {
|
||||||
for (int s = 0; s < n_seqs; ++s) {
|
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||||
|
|
||||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
|
||||||
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
const uint32_t idx = s*n_seq_tokens + j;
|
||||||
|
|
||||||
|
const llama_pos p1 = ubatch->pos[idx];
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
float f = 0.0f;
|
float f = 0.0f;
|
||||||
@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mask padded tokens
|
// mask padded tokens
|
||||||
if (data) {
|
if (data) {
|
||||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
|
||||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
seq_rm(dest_seq_id, -1, -1);
|
seq_rm(dest_seq_id, -1, -1);
|
||||||
|
|
||||||
llama_sbatch sbatch;
|
llama_sbatch sbatch;
|
||||||
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||||
|
|
||||||
batch.n_tokens = cell_count;
|
ubatch.n_tokens = cell_count;
|
||||||
|
ubatch.n_seq_tokens = cell_count;
|
||||||
|
ubatch.n_seqs = 1;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
llama_pos pos;
|
llama_pos pos;
|
||||||
@ -1512,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
io.read_to(&seq_id, sizeof(seq_id));
|
io.read_to(&seq_id, sizeof(seq_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.pos[i] = pos;
|
ubatch.pos[i] = pos;
|
||||||
batch.n_seq_id[i] = n_seq_id;
|
ubatch.n_seq_id[i] = n_seq_id;
|
||||||
batch.seq_id[i] = &dest_seq_id;
|
ubatch.seq_id[i] = &dest_seq_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto head_cur = find_slot(batch);
|
const auto head_cur = find_slot(ubatch);
|
||||||
if (head_cur < 0) {
|
if (head_cur < 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
apply_ubatch(head_cur, batch);
|
apply_ubatch(head_cur, ubatch);
|
||||||
|
|
||||||
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
||||||
head = head_cur;
|
head = head_cur;
|
||||||
@ -1531,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||||
// Assume that this is one contiguous block of cells
|
// Assume that this is one contiguous block of cells
|
||||||
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
||||||
GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
|
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
||||||
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
|
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
||||||
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
||||||
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
||||||
} else {
|
} else {
|
||||||
|
Reference in New Issue
Block a user