llama : reorder encode/decode in sources

This commit is contained in:
Georgi Gerganov
2025-02-18 14:47:53 +02:00
parent bc6f187e9c
commit befe14f06f
2 changed files with 174 additions and 174 deletions

View File

@ -1655,6 +1655,168 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
return llama_context::graph_init();
}
// Evaluate the encoder on `inp_batch` in a single shot (no micro-batching) and extract the
// resulting embeddings.
//
// Where the embeddings end up:
//   - model has a decoder: `embd_enc`, plus the per-token sequence ids in `seq_ids_enc`
//   - otherwise, pooling NONE: `embd` (one vector per token)
//   - otherwise, pooling MEAN/CLS/LAST: `embd_seq` (one vector per sequence id)
//
// return  0 on success
// return  2 if the graph computation was aborted
// return -1 on invalid input (empty batch or out-of-range token id)
// return -2 if the output buffer could not be reserved or the compute status is ALLOC_FAILED
// return -3 on any other compute failure
int llama_context_kv_self::encode(llama_batch & inp_batch) {
    is_encoding = true;

    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporarily allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;

    const auto & hparams = model.hparams;

    // exactly one of token ids / input embeddings must be provided
    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return -1;
            }
        }
    }

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    n_queued_tokens += n_tokens;

    const int64_t n_embd = hparams.n_embd;

    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        // n_tokens is a signed 32-bit value, so it is printed with %d
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_tokens);
        return -2;
    }

    // every token of the batch produces an output, in batch order
    for (int32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    inp_embd_enc = nullptr;
    n_outputs    = n_tokens;

    //batch_manager->prepare(ubatch);

    // TODO: do reserve
    GGML_ASSERT(need_reserve == false);

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

    auto ctx = graph_init();
    auto res = graph_build(ctx, ubatch, false);

    auto * gf = res.gf;

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    input_set(ubatch);

    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
        case GGML_STATUS_ABORTED:
            return 2;
        case GGML_STATUS_ALLOC_FAILED:
            return -2;
        case GGML_STATUS_FAILED:
        default:
            return -3;
    }

    // prefer the pooled embeddings when the graph produced them
    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;

    // extract embeddings
    if (t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        if (llama_model_has_decoder(&model)) {
            embd_enc.resize(n_tokens*n_embd);
            float * embd_out = embd_enc.data();

            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

            // remember the sequence ids used during the encoding - needed for cross attention later
            seq_ids_enc.resize(n_tokens);
            for (int32_t i = 0; i < n_tokens; i++) {
                // resize() keeps the entries of a previous encode call, so drop the stale ids first
                seq_ids_enc[i].clear();
                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                    llama_seq_id seq_id = ubatch.seq_id[i][s];
                    seq_ids_enc[i].insert(seq_id);
                }
            }
        } else {
            GGML_ASSERT(embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        float * embd_out = embd;

                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings
                        auto & embd_seq_out = embd_seq;
                        embd_seq_out.clear();

                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

                        for (int32_t i = 0; i < n_tokens; i++) {
                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
                            // only the first token of each sequence triggers the copy
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(n_embd);
                            // the source offset is indexed by sequence id, not by token index
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
                        //       wait for an encoder model that requires this pooling type in order to test it
                        //       https://github.com/ggerganov/llama.cpp/pull/9510
                        GGML_ABORT("RANK pooling not implemented yet");
                    }
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    return 0;
}
int llama_context_kv_self::decode(llama_batch & inp_batch) {
is_encoding = false;
@ -2020,168 +2182,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
return 0;
}
// Evaluate the encoder on `inp_batch` in a single shot (no micro-batching) and extract the
// resulting embeddings.
//
// Where the embeddings end up:
//   - model has a decoder: `embd_enc`, plus the per-token sequence ids in `seq_ids_enc`
//   - otherwise, pooling NONE: `embd` (one vector per token)
//   - otherwise, pooling MEAN/CLS/LAST: `embd_seq` (one vector per sequence id)
//
// return  0 on success
// return  2 if the graph computation was aborted
// return -1 on invalid input (empty batch or out-of-range token id)
// return -2 if the output buffer could not be reserved or the compute status is ALLOC_FAILED
// return -3 on any other compute failure
int llama_context_kv_self::encode(llama_batch & inp_batch) {
    is_encoding = true;

    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporarily allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;

    const auto & hparams = model.hparams;

    // exactly one of token ids / input embeddings must be provided
    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return -1;
            }
        }
    }

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    n_queued_tokens += n_tokens;

    const int64_t n_embd = hparams.n_embd;

    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        // n_tokens is a signed 32-bit value, so it is printed with %d
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_tokens);
        return -2;
    }

    // every token of the batch produces an output, in batch order
    for (int32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    inp_embd_enc = nullptr;
    n_outputs    = n_tokens;

    //batch_manager->prepare(ubatch);

    // TODO: do reserve
    GGML_ASSERT(need_reserve == false);

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

    auto ctx = graph_init();
    auto res = graph_build(ctx, ubatch, false);

    auto * gf = res.gf;

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    input_set(ubatch);

    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
        case GGML_STATUS_ABORTED:
            return 2;
        case GGML_STATUS_ALLOC_FAILED:
            return -2;
        case GGML_STATUS_FAILED:
        default:
            return -3;
    }

    // prefer the pooled embeddings when the graph produced them
    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;

    // extract embeddings
    if (t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        if (llama_model_has_decoder(&model)) {
            embd_enc.resize(n_tokens*n_embd);
            float * embd_out = embd_enc.data();

            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

            // remember the sequence ids used during the encoding - needed for cross attention later
            seq_ids_enc.resize(n_tokens);
            for (int32_t i = 0; i < n_tokens; i++) {
                // resize() keeps the entries of a previous encode call, so drop the stale ids first
                seq_ids_enc[i].clear();
                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                    llama_seq_id seq_id = ubatch.seq_id[i][s];
                    seq_ids_enc[i].insert(seq_id);
                }
            }
        } else {
            GGML_ASSERT(embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        float * embd_out = embd;

                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings
                        auto & embd_seq_out = embd_seq;
                        embd_seq_out.clear();

                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

                        for (int32_t i = 0; i < n_tokens; i++) {
                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
                            // only the first token of each sequence triggers the copy
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(n_embd);
                            // the source offset is indexed by sequence id, not by token index
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
                        //       wait for an encoder model that requires this pooling type in order to test it
                        //       https://github.com/ggerganov/llama.cpp/pull/9510
                        GGML_ABORT("RANK pooling not implemented yet");
                    }
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    return 0;
}
// max token position as reported by kv_self (simple forwarder to kv_self.pos_max())
llama_pos llama_context_kv_self::pos_max() const {
    const llama_pos max_pos = kv_self.pos_max();
    return max_pos;
}

View File

@ -116,6 +116,17 @@ struct llama_context : public llama_graph_i {
// TODO: maybe remove this
virtual void output_reorder();
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - inp_batch: batch to evaluate (the context is the object the method is called on)
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
@ -130,17 +141,6 @@ struct llama_context : public llama_graph_i {
//
virtual int decode(llama_batch & inp_batch) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - inp_batch: batch to evaluate (the context is the object the method is called on)
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
//
// graph build API (generic)
//
@ -336,8 +336,8 @@ public:
virtual void input_set(const llama_ubatch & ubatch) override;
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;
virtual int decode(llama_batch & inp_batch) override;
// max token position across all sequences in the current context
llama_pos pos_max() const;