diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f7c83e886..4341c571e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -4673,6 +4673,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
 
     cross->t_embd = t_embd;
+    cross->v_embd = embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
@@ -4701,12 +4702,11 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
-    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
-    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    if (inp.cross_embd && cross->t_embd) {
+        assert(inp.cross_embd->type == GGML_TYPE_F32);
 
-    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    //}
+        ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
+    }
 
     if (inp.cross_kq_mask) {
         const int64_t n_enc = inp.cross_kq_mask->ne[0];
@@ -4749,16 +4749,17 @@ ggml_cgraph * llama_context_dec::graph_init() {
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
     // if we have the output embeddings from the encoder, use them directly
-    if (cross->t_embd) {
-        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    //if (cross->t_embd) {
+    //    inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
 
-        return inp.cross_embd;
-    }
+    //    return inp.cross_embd;
+    //}
 
     const auto & hparams = model.hparams;
 
-    const auto n_embd = hparams.n_embd;
-    const auto n_enc  = hparams.n_ctx_train;
+    const auto n_embd = cross->t_embd ? cross->t_embd->ne[0] : hparams.n_embd;
+    const auto n_enc  = cross->t_embd ? cross->t_embd->ne[1] : hparams.n_ctx_train;
 
     inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
diff --git a/src/llama-context.h b/src/llama-context.h
index af35b577b..1b807ccf8 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -750,9 +750,14 @@ private:
 
 // TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    // the output embeddings from the encoder
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
     ggml_tensor * t_embd = nullptr;
 
+    // embeddings data copied to host memory (tmp)
+    float * v_embd = nullptr;
+
     // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
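
For context, below is a minimal standalone sketch of the encoder -> decoder hand-off that this patch implements: the encoder keeps a host-memory copy of its output embeddings (cross->v_embd), the decoder sizes its cross-embeddings input from the shape of the encoder tensor (cross->t_embd), and the data is uploaded with ggml_backend_tensor_set() at input-set time. The cross_state struct and the two free functions are hypothetical stand-ins for the llama_cross / llama_context_dec members in the patch; only the ggml calls themselves are taken from the diff.

#include <cassert>

#include "ggml.h"
#include "ggml-backend.h"

// hypothetical stand-in for the llama_cross struct in the patch
struct cross_state {
    ggml_tensor * t_embd = nullptr; // encoder output tensor (only its shape is consulted)
    float       * v_embd = nullptr; // host-memory copy of the embeddings data (tmp)
};

// decoder side, graph build: allocate the cross-embeddings input using the
// encoder tensor's shape when available, falling back to the training hparams
static ggml_tensor * build_inp_cross_embd(ggml_context * ctx0, const cross_state & cross,
                                          int64_t n_embd_train, int64_t n_ctx_train) {
    const int64_t n_embd = cross.t_embd ? cross.t_embd->ne[0] : n_embd_train;
    const int64_t n_enc  = cross.t_embd ? cross.t_embd->ne[1] : n_ctx_train;

    ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
    ggml_set_input(inp);

    return inp;
}

// decoder side, input set: upload the host copy into the backend-allocated
// input tensor before the decode graph runs
static void set_inp_cross_embd(ggml_tensor * inp, const cross_state & cross) {
    if (inp && cross.t_embd) {
        assert(inp->type == GGML_TYPE_F32);
        ggml_backend_tensor_set(inp, cross.v_embd, 0, ggml_nbytes(inp));
    }
}

The host copy exists, presumably, because the encoder's backend buffer may be invalidated (e.g. by ggml_backend_sched_reset()) before the decoder runs, so viewing cross->t_embd directly via ggml_view_tensor() (the commented-out path, see the linked discussion) is not safe yet; the shape is still taken from t_embd so that models whose cross-attention dimension differs from hparams.n_embd are sized correctly.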