mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-14 22:58:10 +00:00
llama : reorder encode/decode in sources
This commit is contained in:
@ -1655,6 +1655,168 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
|
|||||||
return llama_context::graph_init();
|
return llama_context::graph_init();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||||
|
is_encoding = true;
|
||||||
|
|
||||||
|
if (inp_batch.n_tokens == 0) {
|
||||||
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// temporary allocate memory for the input batch if needed
|
||||||
|
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||||
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
|
||||||
|
|
||||||
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
|
if (batch.token) {
|
||||||
|
for (int32_t i = 0; i < n_tokens; ++i) {
|
||||||
|
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
||||||
|
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
|
||||||
|
|
||||||
|
if (t_compute_start_us == 0) {
|
||||||
|
t_compute_start_us = ggml_time_us();
|
||||||
|
}
|
||||||
|
|
||||||
|
n_queued_tokens += n_tokens;
|
||||||
|
|
||||||
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
|
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||||
|
|
||||||
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
||||||
|
|
||||||
|
// reserve output buffer
|
||||||
|
if (output_reserve(n_tokens) < n_tokens) {
|
||||||
|
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
||||||
|
return -2;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_tokens; ++i) {
|
||||||
|
output_ids[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
inp_embd_enc = NULL;
|
||||||
|
n_outputs = n_tokens;
|
||||||
|
|
||||||
|
//batch_manager->prepare(ubatch);
|
||||||
|
|
||||||
|
// TODO: do reserve
|
||||||
|
GGML_ASSERT(need_reserve == false);
|
||||||
|
|
||||||
|
ggml_backend_sched_reset(sched.get());
|
||||||
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
|
|
||||||
|
auto ctx = graph_init();
|
||||||
|
auto res = graph_build(ctx, ubatch, false);
|
||||||
|
|
||||||
|
auto * gf = res.gf;
|
||||||
|
|
||||||
|
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||||
|
|
||||||
|
input_set(ubatch);
|
||||||
|
|
||||||
|
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||||
|
switch (compute_status) {
|
||||||
|
case GGML_STATUS_SUCCESS:
|
||||||
|
break;
|
||||||
|
case GGML_STATUS_ABORTED:
|
||||||
|
return 2;
|
||||||
|
case GGML_STATUS_ALLOC_FAILED:
|
||||||
|
return -2;
|
||||||
|
case GGML_STATUS_FAILED:
|
||||||
|
default:
|
||||||
|
return -3;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
||||||
|
|
||||||
|
// extract embeddings
|
||||||
|
if (t_embd) {
|
||||||
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||||
|
GGML_ASSERT(backend_embd != nullptr);
|
||||||
|
|
||||||
|
if (llama_model_has_decoder(&model)) {
|
||||||
|
embd_enc.resize(n_tokens*n_embd);
|
||||||
|
float * embd_out = embd_enc.data();
|
||||||
|
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||||
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
||||||
|
|
||||||
|
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||||
|
seq_ids_enc.resize(n_tokens);
|
||||||
|
for (int32_t i = 0; i < n_tokens; i++) {
|
||||||
|
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||||
|
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
||||||
|
seq_ids_enc[i].insert(seq_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(embd != nullptr);
|
||||||
|
|
||||||
|
switch (cparams.pooling_type) {
|
||||||
|
case LLAMA_POOLING_TYPE_NONE:
|
||||||
|
{
|
||||||
|
// extract token embeddings
|
||||||
|
GGML_ASSERT(embd != nullptr);
|
||||||
|
float * embd_out = embd;
|
||||||
|
|
||||||
|
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
||||||
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_MEAN:
|
||||||
|
case LLAMA_POOLING_TYPE_CLS:
|
||||||
|
case LLAMA_POOLING_TYPE_LAST:
|
||||||
|
{
|
||||||
|
// extract sequence embeddings
|
||||||
|
auto & embd_seq_out = embd_seq;
|
||||||
|
embd_seq_out.clear();
|
||||||
|
|
||||||
|
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_tokens; i++) {
|
||||||
|
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||||
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
embd_seq_out[seq_id].resize(n_embd);
|
||||||
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLAMA_POOLING_TYPE_RANK:
|
||||||
|
{
|
||||||
|
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
||||||
|
// wait for an encoder model that requires this pooling type in order to test it
|
||||||
|
// https://github.com/ggerganov/llama.cpp/pull/9510
|
||||||
|
GGML_ABORT("RANK pooling not implemented yet");
|
||||||
|
}
|
||||||
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||||
|
{
|
||||||
|
GGML_ABORT("unknown pooling type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||||
|
// overlap with device computation.
|
||||||
|
ggml_backend_sched_reset(sched.get());
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||||
is_encoding = false;
|
is_encoding = false;
|
||||||
|
|
||||||
@ -2020,168 +2182,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
|
||||||
is_encoding = true;
|
|
||||||
|
|
||||||
if (inp_batch.n_tokens == 0) {
|
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
|
||||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
|
||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
|
|
||||||
|
|
||||||
const llama_batch & batch = batch_allocr.batch;
|
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
|
||||||
|
|
||||||
const auto & hparams = model.hparams;
|
|
||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
||||||
|
|
||||||
if (batch.token) {
|
|
||||||
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
||||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
|
||||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
||||||
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
|
|
||||||
|
|
||||||
if (t_compute_start_us == 0) {
|
|
||||||
t_compute_start_us = ggml_time_us();
|
|
||||||
}
|
|
||||||
|
|
||||||
n_queued_tokens += n_tokens;
|
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
|
||||||
|
|
||||||
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
||||||
|
|
||||||
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
|
||||||
|
|
||||||
// reserve output buffer
|
|
||||||
if (output_reserve(n_tokens) < n_tokens) {
|
|
||||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
|
||||||
return -2;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < n_tokens; ++i) {
|
|
||||||
output_ids[i] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
inp_embd_enc = NULL;
|
|
||||||
n_outputs = n_tokens;
|
|
||||||
|
|
||||||
//batch_manager->prepare(ubatch);
|
|
||||||
|
|
||||||
// TODO: do reserve
|
|
||||||
GGML_ASSERT(need_reserve == false);
|
|
||||||
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
|
||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
||||||
|
|
||||||
auto ctx = graph_init();
|
|
||||||
auto res = graph_build(ctx, ubatch, false);
|
|
||||||
|
|
||||||
auto * gf = res.gf;
|
|
||||||
|
|
||||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
||||||
|
|
||||||
input_set(ubatch);
|
|
||||||
|
|
||||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
|
||||||
switch (compute_status) {
|
|
||||||
case GGML_STATUS_SUCCESS:
|
|
||||||
break;
|
|
||||||
case GGML_STATUS_ABORTED:
|
|
||||||
return 2;
|
|
||||||
case GGML_STATUS_ALLOC_FAILED:
|
|
||||||
return -2;
|
|
||||||
case GGML_STATUS_FAILED:
|
|
||||||
default:
|
|
||||||
return -3;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
|
||||||
|
|
||||||
// extract embeddings
|
|
||||||
if (t_embd) {
|
|
||||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
||||||
GGML_ASSERT(backend_embd != nullptr);
|
|
||||||
|
|
||||||
if (llama_model_has_decoder(&model)) {
|
|
||||||
embd_enc.resize(n_tokens*n_embd);
|
|
||||||
float * embd_out = embd_enc.data();
|
|
||||||
|
|
||||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
|
||||||
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
||||||
|
|
||||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
|
||||||
seq_ids_enc.resize(n_tokens);
|
|
||||||
for (int32_t i = 0; i < n_tokens; i++) {
|
|
||||||
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
|
|
||||||
llama_seq_id seq_id = ubatch.seq_id[i][s];
|
|
||||||
seq_ids_enc[i].insert(seq_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(embd != nullptr);
|
|
||||||
|
|
||||||
switch (cparams.pooling_type) {
|
|
||||||
case LLAMA_POOLING_TYPE_NONE:
|
|
||||||
{
|
|
||||||
// extract token embeddings
|
|
||||||
GGML_ASSERT(embd != nullptr);
|
|
||||||
float * embd_out = embd;
|
|
||||||
|
|
||||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
|
||||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
|
|
||||||
} break;
|
|
||||||
case LLAMA_POOLING_TYPE_MEAN:
|
|
||||||
case LLAMA_POOLING_TYPE_CLS:
|
|
||||||
case LLAMA_POOLING_TYPE_LAST:
|
|
||||||
{
|
|
||||||
// extract sequence embeddings
|
|
||||||
auto & embd_seq_out = embd_seq;
|
|
||||||
embd_seq_out.clear();
|
|
||||||
|
|
||||||
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
|
||||||
|
|
||||||
for (int32_t i = 0; i < n_tokens; i++) {
|
|
||||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
|
||||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
embd_seq_out[seq_id].resize(n_embd);
|
|
||||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case LLAMA_POOLING_TYPE_RANK:
|
|
||||||
{
|
|
||||||
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
|
|
||||||
// wait for an encoder model that requires this pooling type in order to test it
|
|
||||||
// https://github.com/ggerganov/llama.cpp/pull/9510
|
|
||||||
GGML_ABORT("RANK pooling not implemented yet");
|
|
||||||
}
|
|
||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
||||||
{
|
|
||||||
GGML_ABORT("unknown pooling type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
||||||
// overlap with device computation.
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_pos llama_context_kv_self::pos_max() const {
|
llama_pos llama_context_kv_self::pos_max() const {
|
||||||
return kv_self.pos_max();
|
return kv_self.pos_max();
|
||||||
}
|
}
|
||||||
|
@ -116,6 +116,17 @@ struct llama_context : public llama_graph_i {
|
|||||||
// TODO: maybe remove this
|
// TODO: maybe remove this
|
||||||
virtual void output_reorder();
|
virtual void output_reorder();
|
||||||
|
|
||||||
|
// encode a batch of tokens by evaluating the encoder part of the transformer
|
||||||
|
//
|
||||||
|
// - lctx: llama context
|
||||||
|
// - batch: batch to evaluate
|
||||||
|
//
|
||||||
|
// return 0 on success
|
||||||
|
// return positive int on warning
|
||||||
|
// return negative int on error
|
||||||
|
//
|
||||||
|
virtual int encode(llama_batch & inp_batch) = 0;
|
||||||
|
|
||||||
// decode a batch of tokens by evaluating the transformer
|
// decode a batch of tokens by evaluating the transformer
|
||||||
// in case of unsuccessful decoding (error or warning),
|
// in case of unsuccessful decoding (error or warning),
|
||||||
// the kv_cache state will be returned to its original state
|
// the kv_cache state will be returned to its original state
|
||||||
@ -130,17 +141,6 @@ struct llama_context : public llama_graph_i {
|
|||||||
//
|
//
|
||||||
virtual int decode(llama_batch & inp_batch) = 0;
|
virtual int decode(llama_batch & inp_batch) = 0;
|
||||||
|
|
||||||
// encode a batch of tokens by evaluating the encoder part of the transformer
|
|
||||||
//
|
|
||||||
// - lctx: llama context
|
|
||||||
// - batch: batch to evaluate
|
|
||||||
//
|
|
||||||
// return 0 on success
|
|
||||||
// return positive int on warning
|
|
||||||
// return negative int on error
|
|
||||||
//
|
|
||||||
virtual int encode(llama_batch & inp_batch) = 0;
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// graph build API (generic)
|
// graph build API (generic)
|
||||||
//
|
//
|
||||||
@ -336,8 +336,8 @@ public:
|
|||||||
|
|
||||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||||
|
|
||||||
virtual int decode(llama_batch & inp_batch) override;
|
|
||||||
virtual int encode(llama_batch & inp_batch) override;
|
virtual int encode(llama_batch & inp_batch) override;
|
||||||
|
virtual int decode(llama_batch & inp_batch) override;
|
||||||
|
|
||||||
// max token position across all sequences in the current context
|
// max token position across all sequences in the current context
|
||||||
llama_pos pos_max() const;
|
llama_pos pos_max() const;
|
||||||
|
Reference in New Issue
Block a user