mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 12:25:03 +00:00
batch : require non-coupled batch with sequential split_equal
ggml-ci
This commit is contained in:
@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
|
|||||||
|
|
||||||
// note: tracking the other way around is not necessary for now
|
// note: tracking the other way around is not necessary for now
|
||||||
//seq_cpl[s0][s1] = true;
|
//seq_cpl[s0][s1] = true;
|
||||||
|
|
||||||
|
has_cpl = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
|
|||||||
return n_outputs;
|
return n_outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t llama_batch_allocr::get_n_used() const {
|
||||||
|
return n_used;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
||||||
return out_ids;
|
return out_ids;
|
||||||
}
|
}
|
||||||
@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
void llama_batch_allocr::split_reset() {
|
void llama_batch_allocr::split_reset() {
|
||||||
out_ids.clear();
|
out_ids.clear();
|
||||||
|
|
||||||
|
n_used = 0;
|
||||||
|
|
||||||
used.clear();
|
used.clear();
|
||||||
used.resize(get_n_tokens(), false);
|
used.resize(get_n_tokens(), false);
|
||||||
|
|
||||||
@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|||||||
idxs.push_back(cur_idx);
|
idxs.push_back(cur_idx);
|
||||||
|
|
||||||
used[cur_idx] = true;
|
used[cur_idx] = true;
|
||||||
|
++n_used;
|
||||||
|
|
||||||
++cur_idx;
|
++cur_idx;
|
||||||
|
|
||||||
@ -458,6 +467,12 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
||||||
|
if (sequential && has_cpl) {
|
||||||
|
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<seq_set_t> cur_seq_set;
|
std::vector<seq_set_t> cur_seq_set;
|
||||||
|
|
||||||
llama_seq_id last_seq_id = -1;
|
llama_seq_id last_seq_id = -1;
|
||||||
@ -536,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
|
|||||||
idxs_per_seq[s].push_back(idx);
|
idxs_per_seq[s].push_back(idx);
|
||||||
|
|
||||||
used[idx] = true;
|
used[idx] = true;
|
||||||
|
++n_used;
|
||||||
|
|
||||||
++cur_idx[s];
|
++cur_idx[s];
|
||||||
}
|
}
|
||||||
@ -577,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
|||||||
idxs.push_back(cur_idx);
|
idxs.push_back(cur_idx);
|
||||||
|
|
||||||
used[cur_idx] = true;
|
used[cur_idx] = true;
|
||||||
|
++n_used;
|
||||||
|
|
||||||
if (idxs.size() >= n_ubatch) {
|
if (idxs.size() >= n_ubatch) {
|
||||||
break;
|
break;
|
||||||
|
@ -54,6 +54,7 @@ public:
|
|||||||
|
|
||||||
uint32_t get_n_tokens() const;
|
uint32_t get_n_tokens() const;
|
||||||
uint32_t get_n_outputs() const;
|
uint32_t get_n_outputs() const;
|
||||||
|
uint32_t get_n_used() const;
|
||||||
|
|
||||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||||
std::vector<int32_t> & get_out_ids();
|
std::vector<int32_t> & get_out_ids();
|
||||||
@ -113,6 +114,9 @@ private:
|
|||||||
using pos_set_t = std::set<llama_pos>;
|
using pos_set_t = std::set<llama_pos>;
|
||||||
using seq_cpl_t = std::vector<bool>;
|
using seq_cpl_t = std::vector<bool>;
|
||||||
|
|
||||||
|
// helper flag to quickly determine if there are any coupled sequences in the batch
|
||||||
|
bool has_cpl;
|
||||||
|
|
||||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||||
|
|
||||||
@ -126,6 +130,8 @@ private:
|
|||||||
// batch indices of the output
|
// batch indices of the output
|
||||||
std::vector<int32_t> out_ids;
|
std::vector<int32_t> out_ids;
|
||||||
|
|
||||||
|
uint32_t n_used;
|
||||||
|
|
||||||
// used[i] indicates if token i has already been used in a previous ubatch
|
// used[i] indicates if token i has already been used in a previous ubatch
|
||||||
std::vector<bool> used;
|
std::vector<bool> used;
|
||||||
|
|
||||||
|
@ -119,6 +119,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
auto sinfos_base = kv_base->prepare(ubatches);
|
auto sinfos_base = kv_base->prepare(ubatches);
|
||||||
if (sinfos_base.empty()) {
|
if (sinfos_base.empty()) {
|
||||||
break;
|
break;
|
||||||
@ -150,6 +155,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
auto sinfos_base = kv_base->prepare(ubatches);
|
auto sinfos_base = kv_base->prepare(ubatches);
|
||||||
if (sinfos_base.empty()) {
|
if (sinfos_base.empty()) {
|
||||||
break;
|
break;
|
||||||
|
@ -427,6 +427,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
auto sinfos = prepare(ubatches);
|
auto sinfos = prepare(ubatches);
|
||||||
if (sinfos.empty()) {
|
if (sinfos.empty()) {
|
||||||
break;
|
break;
|
||||||
|
@ -81,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// prepare the recurrent batches first
|
// prepare the recurrent batches first
|
||||||
if (!mem_recr->prepare(ubatches)) {
|
if (!mem_recr->prepare(ubatches)) {
|
||||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||||
|
@ -365,6 +365,9 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
do {
|
||||||
|
balloc.split_reset();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
llama_ubatch ubatch;
|
llama_ubatch ubatch;
|
||||||
|
|
||||||
@ -382,9 +385,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (!prepare(ubatches)) {
|
if (!prepare(ubatches)) {
|
||||||
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
} while (false);
|
||||||
|
|
||||||
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user