mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-01 13:05:52 +00:00
@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return kv_swa->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_simple(n_ubatch);
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
auto ubatch = sbatch.split_simple(n_ubatch);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(ubatch);
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
@ -123,19 +126,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// if it fails, try equal split
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
auto ubatch = sbatch.split_equal(n_ubatch);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(ubatch);
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
@ -151,7 +157,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
@ -214,15 +220,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
|
||||
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
|
||||
state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
||||
}
|
||||
|
||||
@ -252,12 +256,6 @@ bool llama_kv_cache_unified_iswa_state::apply() {
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return sbatch.out_ids;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
Reference in New Issue
Block a user