mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00
@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
|
|||||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (logits_all) {
|
if (batch->logits) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
|
||||||
ubatch.output[ubatch.n_tokens + i] = 1;
|
|
||||||
out_ids.push_back(ids[seq.offset + i]);
|
|
||||||
}
|
|
||||||
} else if (batch->logits) {
|
|
||||||
if (ubatch.equal_seqs) {
|
if (ubatch.equal_seqs) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
size_t id = ids[seq.offset + i];
|
size_t id = ids[seq.offset + i];
|
||||||
@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
|||||||
return ubatch;
|
return ubatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
|
||||||
GGML_ASSERT(batch.n_tokens >= 0);
|
GGML_ASSERT(batch.n_tokens >= 0);
|
||||||
this->batch = &batch;
|
this->batch = &batch;
|
||||||
this->n_embd = n_embd;
|
this->n_embd = n_embd;
|
||||||
this->logits_all = logits_all;
|
|
||||||
|
|
||||||
n_tokens = batch.n_tokens;
|
n_tokens = batch.n_tokens;
|
||||||
ids.resize(n_tokens);
|
ids.resize(n_tokens);
|
||||||
|
@ -39,8 +39,6 @@ struct llama_sbatch {
|
|||||||
|
|
||||||
size_t n_embd;
|
size_t n_embd;
|
||||||
|
|
||||||
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
|
||||||
|
|
||||||
// sorted indices into the batch
|
// sorted indices into the batch
|
||||||
std::vector<int64_t> ids;
|
std::vector<int64_t> ids;
|
||||||
// batch indices of the output
|
// batch indices of the output
|
||||||
@ -76,7 +74,7 @@ struct llama_sbatch {
|
|||||||
llama_ubatch split_seq(size_t n_ubatch);
|
llama_ubatch split_seq(size_t n_ubatch);
|
||||||
|
|
||||||
llama_sbatch() = default;
|
llama_sbatch() = default;
|
||||||
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
|
||||||
};
|
};
|
||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
// temporary allocate memory for the input batch if needed
|
||||||
|
@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
|
||||||
|
|
||||||
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
||||||
|
|
||||||
@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
llama_memory_state_ptr mstate;
|
llama_memory_state_ptr mstate;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
|
||||||
if (!mstate) {
|
if (!mstate) {
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
|
|||||||
|
|
||||||
int64_t n_outputs_all = n_tokens_all;
|
int64_t n_outputs_all = n_tokens_all;
|
||||||
|
|
||||||
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
|
||||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||||
break;
|
break;
|
||||||
|
@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
||||||
GGML_UNUSED(embd_pooled);
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
@ -32,8 +32,7 @@ public:
|
|||||||
llama_memory_state_ptr init_batch(
|
llama_memory_state_ptr init_batch(
|
||||||
const llama_batch & batch,
|
const llama_batch & batch,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_pooled,
|
bool embd_pooled) override;
|
||||||
bool logits_all) override;
|
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
return kv_swa->seq_pos_max(seq_id);
|
return kv_swa->seq_pos_max(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
||||||
GGML_UNUSED(embd_pooled);
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
// first try simple split
|
// first try simple split
|
||||||
do {
|
do {
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
|||||||
|
|
||||||
// if it fails, try equal split
|
// if it fails, try equal split
|
||||||
do {
|
do {
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
@ -34,8 +34,7 @@ public:
|
|||||||
llama_memory_state_ptr init_batch(
|
llama_memory_state_ptr init_batch(
|
||||||
const llama_batch & batch,
|
const llama_batch & batch,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_pooled,
|
bool embd_pooled) override;
|
||||||
bool logits_all) override;
|
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
||||||
const llama_batch & batch,
|
const llama_batch & batch,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_pooled,
|
bool embd_pooled) {
|
||||||
bool logits_all) {
|
|
||||||
GGML_UNUSED(embd_pooled);
|
GGML_UNUSED(embd_pooled);
|
||||||
|
|
||||||
do {
|
do {
|
||||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
while (sbatch.n_tokens > 0) {
|
while (sbatch.n_tokens > 0) {
|
||||||
|
@ -59,8 +59,7 @@ public:
|
|||||||
llama_memory_state_ptr init_batch(
|
llama_memory_state_ptr init_batch(
|
||||||
const llama_batch & batch,
|
const llama_batch & batch,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_pooled,
|
bool embd_pooled) override;
|
||||||
bool logits_all) override;
|
|
||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
|
@ -73,8 +73,7 @@ struct llama_memory_i {
|
|||||||
virtual llama_memory_state_ptr init_batch(
|
virtual llama_memory_state_ptr init_batch(
|
||||||
const llama_batch & batch,
|
const llama_batch & batch,
|
||||||
uint32_t n_ubatch,
|
uint32_t n_ubatch,
|
||||||
bool embd_pooled,
|
bool embd_pooled) = 0;
|
||||||
bool logits_all) = 0;
|
|
||||||
|
|
||||||
// simulate full cache, used for allocating worst-case compute buffers
|
// simulate full cache, used for allocating worst-case compute buffers
|
||||||
virtual llama_memory_state_ptr init_full() = 0;
|
virtual llama_memory_state_ptr init_full() = 0;
|
||||||
|
Reference in New Issue
Block a user