mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
cont : migrate to using set of indices instead of slot head
ggml-ci
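
In short: a slot used to be described by a single cell offset (the "head"), with a ubatch always occupying the contiguous range [head_cur, head_cur + n_tokens). After this commit a slot is a slot_info carrying one explicit cell index per token. The sketch below is a simplified rendering of that change, not the verbatim code (the real struct appears in the src/llama-kv-cache-unified.h hunk further down); for now find_slot() still fills it with a contiguous run, but the representation admits scattered cells later on.

    // Simplified sketch of the representational change (not the verbatim code).
    #include <cstdint>
    #include <vector>

    struct slot_info {
        std::vector<uint32_t> idxs;                     // destination cell of each token

        uint32_t head()  const { return idxs[0]; }      // first cell, the legacy view
        bool     empty() const { return idxs.empty(); } // empty == "no slot found"
    };

    int main() {
        // before: the slot was just head_cur, implying cells [head_cur, head_cur + n_tokens)
        const uint32_t head_cur = 32;
        const uint32_t n_tokens = 4;

        // after: the same slot, spelled out cell by cell
        slot_info sinfo;
        for (uint32_t i = 0; i < n_tokens; ++i) {
            sinfo.idxs.push_back(head_cur + i);
        }

        return sinfo.head() == head_cur ? 0 : 1;
    }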

src/llama-kv-cache-unified-iswa.cpp

@@ -113,20 +113,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split

@@ -144,20 +144,20 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads_base = kv_base->prepare(ubatches);
-        if (heads_base.empty()) {
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
             break;
         }
 
-        auto heads_swa = kv_swa->prepare(ubatches);
-        if (heads_swa.empty()) {
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
             break;
         }
 
-        assert(heads_base.size() == heads_swa.size());
+        assert(sinfos_base.size() == sinfos_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
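
Both hunks above keep init_batch's two-stage structure: try the simple split first and, if either cache cannot provide slots, fall through to the equal split. A minimal sketch of that do { ... } while (false) fallback idiom, with toy predicates standing in for the real splitting and prepare() calls:

    // Illustration (toy predicates, not the real splitting logic) of the
    // do { ... } while (false) fallback idiom used by init_batch: bail out
    // of a stage with `break` and move on to the next strategy.
    #include <cstdio>

    static bool try_simple_split() { return false; } // pretend the simple split fails
    static bool try_equal_split()  { return true;  }

    int main() {
        do {
            if (!try_simple_split()) {
                break; // e.g. sinfos_base.empty() || sinfos_swa.empty()
            }
            std::printf("simple split succeeded\n");
            return 0;
        } while (false);

        do {
            if (!try_equal_split()) {
                break;
            }
            std::printf("equal split succeeded\n");
            return 0;
        } while (false);

        std::printf("failed to prepare\n"); // LLAMA_MEMORY_STATUS_FAILED_PREPARE
        return 0;
    }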

@@ -220,13 +220,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 
 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 

src/llama-kv-cache-unified-iswa.h

@@ -74,6 +74,8 @@ private:
 
 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // used for errors
     llama_kv_cache_unified_iswa_context(llama_memory_status status);
 

@@ -90,8 +92,8 @@ public:
     // used to create a batch processing context from a batch
     llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
-            std::vector<uint32_t> heads_base,
-            std::vector<uint32_t> heads_swa,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_iswa_context();

src/llama-kv-cache-unified.cpp

@@ -334,13 +334,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        auto heads = prepare(ubatches);
-        if (heads.empty()) {
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
             break;
         }
 
         return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(heads), std::move(ubatches));
+                this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);

@@ -383,8 +383,8 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch

@@ -400,20 +400,25 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
 
     for (const auto & ubatch : ubatches) {
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const int32_t head_new = find_slot(ubatch);
-        if (head_new < 0) {
+        const auto sinfo_new = find_slot(ubatch);
+        if (sinfo_new.empty()) {
             success = false;
             break;
         }
 
         // remeber the position that we found
-        res.push_back(head_new);
+        res.push_back(sinfo_new);
+
+        // TODO: temporary
+        if (supports_set_rows) {
+            GGML_ASSERT(sinfo_new.is_cont());
+        }
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+        states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
 
         // now emplace the ubatch
-        apply_ubatch(head_new, ubatch);
+        apply_ubatch(sinfo_new, ubatch);
     }
 
     // iterate backwards and restore the cells to their original state
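
prepare() keeps its dry-run shape: each ubatch is tentatively placed with apply_ubatch() so that later ubatches see the updated cells, a recovery stack records what to undo, and everything is rolled back afterwards. A toy illustration of that place-then-restore pattern under the new per-slot index sets (hypothetical occupancy flags in place of the real cell bookkeeping):

    // Toy illustration (hypothetical, simplified) of prepare()'s dry-run pattern:
    // tentatively occupy the cells of each ubatch, then restore in reverse order.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<bool> cells(16, false); // occupancy only; the real code snapshots full cell state

        struct rec {
            std::vector<uint32_t> idxs; // which cells this placement touched
        };
        std::vector<rec> recovery;

        // tentatively place two "ubatches"; the second sees the first's occupancy
        const std::vector<std::vector<uint32_t>> placements = {{0, 1, 2, 3}, {4, 5}};
        for (const auto & idxs : placements) {
            for (auto idx : idxs) {
                cells[idx] = true; // stands in for apply_ubatch(sinfo, ubatch)
            }
            recovery.push_back({idxs});
        }

        // iterate backwards and restore the cells to their original state
        for (auto it = recovery.rbegin(); it != recovery.rend(); ++it) {
            for (auto idx : it->idxs) {
                cells[idx] = false;
            }
        }

        for (bool c : cells) {
            assert(!c); // the cache is untouched once prepare() returns
        }
        return 0;
    }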

@@ -520,7 +525,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;

@@ -533,7 +538,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return -1;
+        return { };
     }
 
     if (debug > 0) {

@@ -649,14 +654,21 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return -1;
+            return { };
         }
     }
 
-    return head_cur;
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        res.idxs[i] = head_cur + i;
+    }
+
+    return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
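
Failure is now signaled by an empty slot_info instead of the old -1 sentinel, and on success the returned idxs are exactly the contiguous run starting at head_cur. A self-contained sketch of that contract (toy bounds check standing in for the real cell search):

    // Sketch (toy bounds check, assumed sizes) of find_slot()'s new contract:
    // an empty index set replaces the old -1 sentinel.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct slot_info {
        std::vector<uint32_t> idxs;
        bool empty() const { return idxs.empty(); }
    };

    static slot_info toy_find_slot(uint32_t head_cur, uint32_t n_tokens, uint32_t n_cells) {
        if (head_cur + n_tokens > n_cells) {
            return { }; // previously: return -1;
        }
        slot_info res;
        res.idxs.resize(n_tokens);
        for (uint32_t i = 0; i < n_tokens; ++i) {
            res.idxs[i] = head_cur + i; // the same contiguous run as before
        }
        return res;
    }

    int main() {
        assert(!toy_find_slot( 0, 8, 16).empty()); // fits: cells [0, 8)
        assert( toy_find_slot(12, 8, 16).empty()); // does not fit: "failure"
        return 0;
    }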

@@ -664,22 +676,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    assert(ubatch.n_tokens == sinfo.idxs.size());
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        const auto idx = sinfo.idxs[i];
+
+        if (!cells.is_empty(idx)) {
+            assert(cells.seq_count(idx) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(idx);
+            const llama_pos pos = cells.pos_get(idx);
 
             seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+            cells.rm(idx);
         }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+        cells.pos_set(idx, ubatch.pos[i]);
 
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+            cells.seq_add(idx, ubatch.seq_id[i][s]);
        }
     }
 

@@ -700,7 +716,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     }
 
     // move the head at the end of the slot
-    head = head_cur + ubatch.n_tokens;
+    head = sinfo.idxs.back() + 1;
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
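
For a contiguous slot the new head update is equivalent to the old one, since idxs.back() + 1 == head_cur + ubatch.n_tokens; phrasing it via the last index is what will keep working once slots may have gaps. A quick check with assumed numbers:

    // Quick check (hypothetical numbers) that the new head update matches the
    // old one for contiguous slots: idxs.back() + 1 == head_cur + n_tokens.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        const uint32_t head_cur = 10;
        const uint32_t n_tokens = 3;

        const std::vector<uint32_t> idxs = {10, 11, 12}; // what find_slot would return

        assert(idxs.back() + 1 == head_cur + n_tokens);
        return 0;
    }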

@@ -753,7 +769,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;

@@ -772,12 +788,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
+            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;

@@ -814,19 +830,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (!v_trans) {
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
     } else {
         v_cur = ggml_transpose(ctx, v_cur);
 
         v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1])*ggml_element_size(v),
-                (head_cur)*ggml_element_size(v));
+                (v->ne[1]    )*ggml_element_size(v),
+                (sinfo.head())*ggml_element_size(v));
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }

@@ -837,7 +853,7 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = head_cur + i;
+        data[i] = sinfo.idxs[i];
     }
 }
 
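
The kv_idxs input tensor now receives the destination row of token i directly from sinfo.idxs[i] rather than recomputing head_cur + i, which is what would let a set-rows based graph scatter K/V rows to arbitrary cells. A sketch with a plain buffer standing in for dst->data (assumed index values):

    // Sketch (toy buffer in place of the ggml tensor) of how the kv_idxs input
    // is populated now: the destination row of token i is sinfo.idxs[i], verbatim.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        const std::vector<uint32_t> idxs = {7, 8, 9, 10}; // sinfo.idxs, assumed
        std::vector<int64_t> data(idxs.size());           // stands in for dst->data

        for (size_t i = 0; i < idxs.size(); ++i) {
            data[i] = idxs[i]; // previously: data[i] = head_cur + i;
        }

        assert(data[3] == 10);
        return 0;
    }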

@@ -1580,13 +1596,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         ubatch.seq_id[i] = &dest_seq_id;
     }
 
-    const auto head_cur = find_slot(ubatch);
-    if (head_cur < 0) {
+    const auto sinfo = find_slot(ubatch);
+    if (sinfo.empty()) {
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;
     }
 
-    apply_ubatch(head_cur, ubatch);
+    apply_ubatch(sinfo, ubatch);
+
+    const auto head_cur = sinfo.head();
 
     // keep the head at the old position because we will read the KV data into it in state_read_data()
     head = head_cur;

@@ -1772,7 +1790,10 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
-    head = 0;
+
+    sinfos.resize(1);
+    sinfos[0].idxs.resize(1);
+    sinfos[0].idxs[0] = 0;
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(

@@ -1787,8 +1808,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

@@ -1796,7 +1817,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    if (++i_next >= ubatches.size()) {
+    if (++i_cur >= ubatches.size()) {
         return false;
     }
 

@@ -1813,10 +1834,9 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(heads[i_next], ubatches[i_next]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
 
     n_kv = kv->get_n_kv();
-    head = heads[i_next];
 
     return true;
 }

@@ -1828,7 +1848,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
 const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return ubatches[i_next];
+    return ubatches[i_cur];
 }
 
 uint32_t llama_kv_cache_unified_context::get_n_kv() const {

@@ -1844,11 +1864,11 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {

@@ -1856,7 +1876,7 @@ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const
 }
 
 void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, head);
+    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
@ -24,8 +24,6 @@ public:
|
|||||||
// this callback is used to filter out layers that should not be included in the cache
|
// this callback is used to filter out layers that should not be included in the cache
|
||||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||||
|
|
||||||
using ubatch_heads = std::vector<uint32_t>;
|
|
||||||
|
|
||||||
struct defrag_info {
|
struct defrag_info {
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return ids.empty();
|
return ids.empty();
|
||||||

@@ -37,6 +35,40 @@ public:
         std::vector<uint32_t> ids;
     };
 
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs[0];
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        // TODO: tmp until kv cells support non-cont slots
+        bool is_cont() const {
+            bool res = true;
+
+            for (uint32_t i = 1; i < idxs.size(); ++i) {
+                if (idxs[i] != idxs[i - 1] + 1) {
+                    res = false;
+                    break;
+                }
+            }
+
+            return res;
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
     llama_kv_cache_unified(
             const llama_model & model,
             layer_filter_cb && filter,
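
A small usage example of the slot_info API declared above; the struct is mirrored locally (same fields and methods) only so the snippet compiles on its own:

    // Usage example of slot_info as declared above; mirrored here so the
    // snippet is self-contained.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct slot_info {
        using idx_vec_t = std::vector<uint32_t>;
        idx_vec_t idxs;

        uint32_t head()  const { return idxs[0]; }
        bool     empty() const { return idxs.empty(); }

        bool is_cont() const {
            for (uint32_t i = 1; i < idxs.size(); ++i) {
                if (idxs[i] != idxs[i - 1] + 1) {
                    return false;
                }
            }
            return true;
        }
    };

    int main() {
        slot_info s;
        assert(s.empty() && s.is_cont());     // an empty slot counts as contiguous

        s.idxs = {5, 6, 7};
        assert(s.head() == 5 && s.is_cont()); // what find_slot returns today

        s.idxs = {5, 7};
        assert(!s.is_cont());                 // would trip the GGML_ASSERT in prepare()
        return 0;
    }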

@@ -102,31 +134,36 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
     //
 
-    // find places for the provided ubatches in the cache, returns the head locations
+    // find places for the provided ubatches in the cache, returns the slot infos
     // return empty vector on failure
-    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
 
     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
+    // find a continuous slot of kv cells that can hold the ubatch
     // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a contiguous slot of kv cells
-    int32_t find_slot(const llama_ubatch & ubatch) const;
+    // return -1 on failure to find a slot
+    slot_info find_slot(const llama_ubatch & ubatch) const;
 
-    // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
-    void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
+    // find a set of kv cells that can hold the ubatch
+    // TODO: implement
+    //slot_info find_slot_ext(const llama_ubatch & ubatch) const;
+
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
 
     //
     // set_input API
     //
 
-    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const;
+    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

@@ -217,8 +254,8 @@ private:
 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
     using defrag_info = llama_kv_cache_unified::defrag_info;
 
     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);

@@ -237,7 +274,7 @@ public:
     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
-            ubatch_heads heads,
+            slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_context();

@@ -290,10 +327,10 @@ private:
     // batch processing context
     //
 
-    // the index of the next ubatch to process
-    size_t i_next = 0;
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
 
-    ubatch_heads heads;
+    slot_info_vec_t sinfos;
 
     std::vector<llama_ubatch> ubatches;
 

@@ -304,7 +341,4 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
-
-    // the beginning of the current slot in which the ubatch will be inserted
-    int32_t head;
 };

src/llama-memory-hybrid.cpp

@@ -195,11 +195,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
 
 llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
-        std::vector<uint32_t> heads_attn,
+        slot_info_vec_t sinfos_attn,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }

src/llama-memory-hybrid.h

@@ -92,6 +92,8 @@ private:
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
 

@@ -107,7 +109,7 @@ public:
     // init success
     llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
-            std::vector<uint32_t> heads_attn,
+            slot_info_vec_t sinfos_attn,
             std::vector<llama_ubatch> ubatches);
 
     ~llama_memory_hybrid_context() = default;