From 4efe9898862ccea908176a6801c643382f2e27f7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 25 Feb 2025 16:11:17 +0200
Subject: [PATCH] context : pass embeddings tensor from encoder to decoder

ggml-ci
---
 src/llama-context.cpp | 45 +++++++++++++++++++++++++--------------------
 src/llama-context.h   |  7 ++++---
 2 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index dacf80908..f7c83e886 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
 // llama_context_enc
 //
 
+// TODO: avoid copy-paste of the entire encode() function
 int llama_context_enc::encode(llama_batch & inp_batch) {
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
 
-    cross->n_outputs = n_tokens;
-    cross->embd_enc = embd;
+    cross->t_embd = t_embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
@@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
 
 void llama_context_dec::reserve() {
     // simulate full KV cache
-    cross->n_outputs = cparams.n_ubatch;
-
-    LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
+    cross->t_embd = nullptr;
 
     llama_context_kv_self::reserve();
 }
@@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    if (inp.cross_embd) {
-        assert(inp.cross_embd->type == GGML_TYPE_F32);
-        assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
+    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
+    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
 
-        ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    }
+    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
+    //}
 
     if (inp.cross_kq_mask) {
-        const int64_t n_output_enc = cross->n_outputs;
+        const int64_t n_enc = inp.cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
@@ -4721,7 +4719,7 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
+                for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[j][s];
@@ -4729,13 +4727,13 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
                 }
             }
         }
@@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
 
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
+    // if we have the output embeddings from the encoder, use them directly
+    if (cross->t_embd) {
+        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+
+        return inp.cross_embd;
+    }
+
     const auto & hparams = model.hparams;
 
-    const int64_t n_embd = hparams.n_embd;
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const auto n_embd = hparams.n_embd;
+    const auto n_enc = hparams.n_ctx_train;
 
-    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+    inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
 
     return inp.cross_embd;
@@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
         bool swa) {
     llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
 
-    const int32_t n_outputs_enc = cross->n_outputs;
+    const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
 
-    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     ggml_set_input(inp.cross_kq_mask);
 
     inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
diff --git a/src/llama-context.h b/src/llama-context.h
index 3165865a7..af35b577b 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -748,11 +748,12 @@ private:
     llama_kv_cache_recurrent kv_self;
 };
 
-// TODO: tmp - need something better
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    int32_t n_outputs;
-    float * embd_enc;
+    // the output embeddings from the encoder
+    ggml_tensor * t_embd = nullptr;
 
+    // needed to construct the cross-attention mask in the decoder
    std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
 