context : pass embeddings tensor from encoder to decoder

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-25 16:11:17 +02:00
parent e2b3294f2c
commit 4efe989886
2 changed files with 29 additions and 23 deletions

View File

@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
// llama_context_enc
//
// TODO: avoid copy-paste of the entire encode() function
int llama_context_enc::encode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
cross->n_outputs = n_tokens;
cross->embd_enc = embd;
cross->t_embd = t_embd;
// remember the sequence ids used during the encoding - needed for cross attention later
cross->seq_ids_enc.resize(n_tokens);
@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
void llama_context_dec::reserve() {
// simulate full KV cache
cross->n_outputs = cparams.n_ubatch;
LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
cross->t_embd = nullptr;
llama_context_kv_self::reserve();
}
@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
// call base functionality
llama_context_kv_self::input_set(ubatch);
if (inp.cross_embd) {
assert(inp.cross_embd->type == GGML_TYPE_F32);
assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
//if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
// assert(inp.cross_embd->type == GGML_TYPE_F32);
// assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
}
// ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
//}
if (inp.cross_kq_mask) {
const int64_t n_output_enc = cross->n_outputs;
const int64_t n_enc = inp.cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@ -4721,7 +4719,7 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_output_enc; ++i) {
for (int i = 0; i < n_enc; ++i) {
float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[j][s];
@ -4729,13 +4727,13 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
f = 0.0f;
}
}
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_output_enc; ++j) {
data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
}
}
@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
ggml_tensor * llama_context_dec::build_inp_cross_embd(
ggml_context * ctx0) {
// if we have the output embeddings from the encoder, use them directly
if (cross->t_embd) {
inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
return inp.cross_embd;
}
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
const int32_t n_outputs_enc = cross->n_outputs;
const auto n_embd = hparams.n_embd;
const auto n_enc = hparams.n_ctx_train;
inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
ggml_set_input(inp.cross_embd);
return inp.cross_embd;
@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
bool swa) {
llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
const int32_t n_outputs_enc = cross->n_outputs;
const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp.cross_kq_mask);
inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;

View File

@ -748,11 +748,12 @@ private:
llama_kv_cache_recurrent kv_self;
};
// TODO: tmp - need something better
// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
int32_t n_outputs;
float * embd_enc;
// the output embeddings from the encoder
ggml_tensor * t_embd = nullptr;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};