mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-05 11:33:31 +00:00
batch : optional requirement for sequential sequence ids
ggml-ci
This commit is contained in:
@ -457,7 +457,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
|||||||
return ubatch_add(idxs, idxs.size(), false);
|
return ubatch_add(idxs, idxs.size(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
|
||||||
std::vector<seq_set_t> cur_seq_set;
|
std::vector<seq_set_t> cur_seq_set;
|
||||||
|
|
||||||
llama_seq_id last_seq_id = -1;
|
llama_seq_id last_seq_id = -1;
|
||||||
@ -479,7 +479,9 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// accept only increasing sequence ids
|
// accept only increasing sequence ids
|
||||||
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
if (sequential) {
|
||||||
|
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
|
||||||
|
}
|
||||||
|
|
||||||
if (add) {
|
if (add) {
|
||||||
cur_seq_set.push_back(seq_set[i]);
|
cur_seq_set.push_back(seq_set[i]);
|
||||||
|
@ -69,7 +69,8 @@ public:
|
|||||||
llama_ubatch split_simple(uint32_t n_ubatch);
|
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||||
|
|
||||||
// make ubatches of equal-length sequences sets
|
// make ubatches of equal-length sequences sets
|
||||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
|
||||||
|
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
|
||||||
|
|
||||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||||
|
@ -102,7 +102,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||||||
// first try simple split
|
// first try simple split
|
||||||
do {
|
do {
|
||||||
if (n_seq_virt > 1) {
|
if (n_seq_virt > 1) {
|
||||||
// requires equal splits
|
// requires equal splits, so we skip the simple split
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
while (true) {
|
while (true) {
|
||||||
auto ubatch = balloc.split_equal(n_ubatch);
|
auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (ubatch.n_tokens == 0) {
|
||||||
break;
|
break;
|
||||||
|
@ -418,7 +418,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
while (true) {
|
while (true) {
|
||||||
auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);
|
auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (ubatch.n_tokens == 0) {
|
||||||
break;
|
break;
|
||||||
|
@ -71,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||||||
// if all tokens are output, split by sequence
|
// if all tokens are output, split by sequence
|
||||||
ubatch = balloc.split_seq(n_ubatch);
|
ubatch = balloc.split_seq(n_ubatch);
|
||||||
} else {
|
} else {
|
||||||
ubatch = balloc.split_equal(n_ubatch);
|
ubatch = balloc.split_equal(n_ubatch, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (ubatch.n_tokens == 0) {
|
||||||
|
@ -372,7 +372,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||||||
// if all tokens are output, split by sequence
|
// if all tokens are output, split by sequence
|
||||||
ubatch = balloc.split_seq(n_ubatch);
|
ubatch = balloc.split_seq(n_ubatch);
|
||||||
} else {
|
} else {
|
||||||
ubatch = balloc.split_equal(n_ubatch);
|
ubatch = balloc.split_equal(n_ubatch, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (ubatch.n_tokens == 0) {
|
||||||
|
Reference in New Issue
Block a user