Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-14 06:47:15 +00:00)
llama : reorder encode/decode in sources
@@ -1655,6 +1655,168 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
    return llama_context::graph_init();
}

int llama_context_kv_self::encode(llama_batch & inp_batch) {
    is_encoding = true;

    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;

    const auto & hparams = model.hparams;

    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return -1;
            }
        }
    }

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    n_queued_tokens += n_tokens;

    const int64_t n_embd = hparams.n_embd;

    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    };

    for (int32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    inp_embd_enc = NULL;
    n_outputs = n_tokens;

    //batch_manager->prepare(ubatch);

    // TODO: do reserve
    GGML_ASSERT(need_reserve == false);

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

    auto ctx = graph_init();
    auto res = graph_build(ctx, ubatch, false);

    auto * gf = res.gf;

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    input_set(ubatch);

    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
        case GGML_STATUS_ABORTED:
            return 2;
        case GGML_STATUS_ALLOC_FAILED:
            return -2;
        case GGML_STATUS_FAILED:
        default:
            return -3;
    }

    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;

    // extract embeddings
    if (t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        if (llama_model_has_decoder(&model)) {
            embd_enc.resize(n_tokens*n_embd);
            float * embd_out = embd_enc.data();

            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

            // remember the sequence ids used during the encoding - needed for cross attention later
            seq_ids_enc.resize(n_tokens);
            for (int32_t i = 0; i < n_tokens; i++) {
                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                    llama_seq_id seq_id = ubatch.seq_id[i][s];
                    seq_ids_enc[i].insert(seq_id);
                }
            }
        } else {
            GGML_ASSERT(embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        GGML_ASSERT(embd != nullptr);
                        float * embd_out = embd;

                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings
                        auto & embd_seq_out = embd_seq;
                        embd_seq_out.clear();

                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

                        for (int32_t i = 0; i < n_tokens; i++) {
                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(n_embd);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
                        //       wait for an encoder model that requires this pooling type in order to test it
                        //       https://github.com/ggerganov/llama.cpp/pull/9510
                        GGML_ABORT("RANK pooling not implemented yet");
                    }
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    return 0;
}

int llama_context_kv_self::decode(llama_batch & inp_batch) {
    is_encoding = false;

@@ -2020,168 +2182,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
    return 0;
}

int llama_context_kv_self::encode(llama_batch & inp_batch) {
    is_encoding = true;

    if (inp_batch.n_tokens == 0) {
        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
        return -1;
    }

    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;

    const auto & hparams = model.hparams;

    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

    if (batch.token) {
        for (int32_t i = 0; i < n_tokens; ++i) {
            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                return -1;
            }
        }
    }

    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

    if (t_compute_start_us == 0) {
        t_compute_start_us = ggml_time_us();
    }

    n_queued_tokens += n_tokens;

    const int64_t n_embd = hparams.n_embd;

    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);

    // reserve output buffer
    if (output_reserve(n_tokens) < n_tokens) {
        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
        return -2;
    };

    for (int32_t i = 0; i < n_tokens; ++i) {
        output_ids[i] = i;
    }

    inp_embd_enc = NULL;
    n_outputs = n_tokens;

    //batch_manager->prepare(ubatch);

    // TODO: do reserve
    GGML_ASSERT(need_reserve == false);

    ggml_backend_sched_reset(sched.get());
    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

    auto ctx = graph_init();
    auto res = graph_build(ctx, ubatch, false);

    auto * gf = res.gf;

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    input_set(ubatch);

    const auto compute_status = graph_compute(gf, n_tokens > 1);
    switch (compute_status) {
        case GGML_STATUS_SUCCESS:
            break;
        case GGML_STATUS_ABORTED:
            return 2;
        case GGML_STATUS_ALLOC_FAILED:
            return -2;
        case GGML_STATUS_FAILED:
        default:
            return -3;
    }

    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;

    // extract embeddings
    if (t_embd) {
        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
        GGML_ASSERT(backend_embd != nullptr);

        if (llama_model_has_decoder(&model)) {
            embd_enc.resize(n_tokens*n_embd);
            float * embd_out = embd_enc.data();

            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

            // remember the sequence ids used during the encoding - needed for cross attention later
            seq_ids_enc.resize(n_tokens);
            for (int32_t i = 0; i < n_tokens; i++) {
                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                    llama_seq_id seq_id = ubatch.seq_id[i][s];
                    seq_ids_enc[i].insert(seq_id);
                }
            }
        } else {
            GGML_ASSERT(embd != nullptr);

            switch (cparams.pooling_type) {
                case LLAMA_POOLING_TYPE_NONE:
                    {
                        // extract token embeddings
                        GGML_ASSERT(embd != nullptr);
                        float * embd_out = embd;

                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
                    } break;
                case LLAMA_POOLING_TYPE_MEAN:
                case LLAMA_POOLING_TYPE_CLS:
                case LLAMA_POOLING_TYPE_LAST:
                    {
                        // extract sequence embeddings
                        auto & embd_seq_out = embd_seq;
                        embd_seq_out.clear();

                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

                        for (int32_t i = 0; i < n_tokens; i++) {
                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                continue;
                            }
                            embd_seq_out[seq_id].resize(n_embd);
                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                        }
                    } break;
                case LLAMA_POOLING_TYPE_RANK:
                    {
                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
                        //       wait for an encoder model that requires this pooling type in order to test it
                        //       https://github.com/ggerganov/llama.cpp/pull/9510
                        GGML_ABORT("RANK pooling not implemented yet");
                    }
                case LLAMA_POOLING_TYPE_UNSPECIFIED:
                    {
                        GGML_ABORT("unknown pooling type");
                    }
            }
        }
    }

    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
    // overlap with device computation.
    ggml_backend_sched_reset(sched.get());

    return 0;
}

llama_pos llama_context_kv_self::pos_max() const {
    return kv_self.pos_max();
}

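The pooling switch above determines where encode() leaves its results: with LLAMA_POOLING_TYPE_NONE the per-token embeddings go into the context's embd output buffer, while MEAN/CLS/LAST produce one pooled vector per sequence in embd_seq, and the decoder-bearing branch stashes the whole encoder output for later cross attention. As a hedged illustration only (not part of this commit), a caller would typically read those results back through the public getters roughly as follows; the helper name and arguments are hypothetical.

// Hypothetical sketch (not from this commit): reading back what encode() stored,
// depending on the pooling type the context was created with.
#include "llama.h"

static const float * get_encoder_embedding(llama_context * ctx, llama_seq_id seq_id, int32_t i_token) {
    switch (llama_pooling_type(ctx)) {
        case LLAMA_POOLING_TYPE_NONE:
            // per-token embeddings, written to the output buffer in the
            // LLAMA_POOLING_TYPE_NONE branch above
            return llama_get_embeddings_ith(ctx, i_token);
        case LLAMA_POOLING_TYPE_MEAN:
        case LLAMA_POOLING_TYPE_CLS:
        case LLAMA_POOLING_TYPE_LAST:
            // one pooled vector per sequence (embd_seq above)
            return llama_get_embeddings_seq(ctx, seq_id);
        default:
            // RANK is not implemented in the encoder path and UNSPECIFIED aborts
            return nullptr;
    }
}

Each returned pointer refers to llama_n_embd(llama_get_model(ctx)) floats.
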
@@ -116,6 +116,17 @@ struct llama_context : public llama_graph_i {
    // TODO: maybe remove this
    virtual void output_reorder();

    // encode a batch of tokens by evaluating the encoder part of the transformer
    //
    //   - lctx:  llama context
    //   - batch: batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int encode(llama_batch & inp_batch) = 0;

    // decode a batch of tokens by evaluating the transformer
    // in case of unsuccessful decoding (error or warning),
    // the kv_cache state will be returned to its original state

@@ -130,17 +141,6 @@ struct llama_context : public llama_graph_i {
    //
    virtual int decode(llama_batch & inp_batch) = 0;

    // encode a batch of tokens by evaluating the encoder part of the transformer
    //
    //   - lctx:  llama context
    //   - batch: batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int encode(llama_batch & inp_batch) = 0;

    //
    // graph build API (generic)
    //

@@ -336,8 +336,8 @@ public:

    virtual void input_set(const llama_ubatch & ubatch) override;

    virtual int decode(llama_batch & inp_batch) override;
    virtual int encode(llama_batch & inp_batch) override;
    virtual int decode(llama_batch & inp_batch) override;

    // max token position across all sequences in the current context
    llama_pos pos_max() const;
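
For context, the return-code convention documented above (0 on success, positive int on warning, negative int on error) is what callers of the public API see from llama_encode()/llama_decode(). Below is a minimal, hedged caller-side sketch of checking it on the encoder side; it is not part of this commit, and it assumes an already tokenized prompt and a context created with n_ubatch at least as large as the prompt (see the GGML_ASSERT in encode()).

// Minimal caller-side sketch (not part of this commit); `ctx` and `tokens` are
// assumed to exist already, model loading and tokenization are omitted.
#include "llama.h"

#include <cstdio>
#include <vector>

static int32_t encode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    // non-causal encoding runs in a single shot, so the context must have been
    // created with n_ubatch >= tokens.size()
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    // same convention as decode(): 0 = success, > 0 = warning, < 0 = error
    const int32_t ret = llama_encode(ctx, batch);
    if (ret != 0) {
        fprintf(stderr, "%s: llama_encode returned %d (%s)\n",
                __func__, ret, ret > 0 ? "warning" : "error");
    }
    return ret;
}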