context : disable encoder embd tensor for now
ggml-ci
@@ -4673,6 +4673,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
 
     cross->t_embd = t_embd;
+    cross->v_embd = embd;
 
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross->seq_ids_enc.resize(n_tokens);
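
Note: the new cross->v_embd complements cross->t_embd, which after this commit is consulted only for its shape. The host copy itself (the embd buffer) is produced earlier in encode() and is not shown in this diff. The following is a minimal, self-contained sketch of the mechanism the commit switches to: read a computed tensor back into host memory, then upload that copy into another graph's input tensor. The CPU backend, the shared ggml context, and the toy 4x2 shape are assumptions of the sketch, not llama.cpp code.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"

    #include <cstdio>
    #include <vector>

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 4,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true, // tensor data goes into a backend buffer
        };
        ggml_context * ctx = ggml_init(params);

        // t_embd stands in for the encoder output: ne[0] = n_embd, ne[1] = n_enc
        ggml_tensor * t_embd         = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
        ggml_tensor * inp_cross_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // pretend the encoder graph produced some values
        std::vector<float> enc_out(4*2, 1.0f);
        ggml_backend_tensor_set(t_embd, enc_out.data(), 0, ggml_nbytes(t_embd));

        // what encode() now does, conceptually: keep a host copy (cross->v_embd)
        std::vector<float> v_embd(4*2);
        ggml_backend_tensor_get(t_embd, v_embd.data(), 0, ggml_nbytes(t_embd));

        // what input_set() later does with it: upload into the decoder input
        ggml_backend_tensor_set(inp_cross_embd, v_embd.data(), 0, ggml_nbytes(inp_cross_embd));

        printf("round-tripped %zu bytes via host memory\n", ggml_nbytes(inp_cross_embd));

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }

The extra device-to-host copy is the "tmp" part of the workaround: it decouples the decoder from the encoder's compute buffers, which ggml_backend_sched_reset() (the context line at the top of this hunk) lets the scheduler reclaim; see the PR discussion linked in the header change below.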
@@ -4701,12 +4702,11 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
     // call base functionality
     llama_context_kv_self::input_set(ubatch);
 
-    //if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
-    //    assert(inp.cross_embd->type == GGML_TYPE_F32);
-    //    assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
+    if (inp.cross_embd && cross->t_embd) {
+        assert(inp.cross_embd->type == GGML_TYPE_F32);
 
-    //    ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
-    //}
+        ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
+    }
 
     if (inp.cross_kq_mask) {
         const int64_t n_enc = inp.cross_kq_mask->ne[0];
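
Note: the trailing context lines show that input_set() also fills inp.cross_kq_mask, using the sequence ids that encode() remembered in cross->seq_ids_enc (first hunk above). A hedged sketch of how such a mask can be filled follows; the actual llama.cpp logic is not part of this diff, and the one-sequence-id-per-decoder-token signature is a simplifying assumption.

    #include <cmath>
    #include <cstdint>
    #include <set>
    #include <vector>

    using llama_seq_id = int32_t;

    // decoder token j may attend to encoder position i only if the two share
    // a sequence id; blocked entries get -INFINITY so softmax zeroes them out
    static void fill_cross_kq_mask(
                  float * data, // row-major n_tokens x n_enc mask
            const std::vector<std::set<llama_seq_id>> & seq_ids_enc, // per encoder position
            const std::vector<llama_seq_id>           & seq_id_dec)  // per decoder token
    {
        const size_t n_enc    = seq_ids_enc.size();
        const size_t n_tokens = seq_id_dec.size();

        for (size_t j = 0; j < n_tokens; ++j) {
            for (size_t i = 0; i < n_enc; ++i) {
                const bool allowed = seq_ids_enc[i].count(seq_id_dec[j]) > 0;
                data[j*n_enc + i] = allowed ? 0.0f : -INFINITY;
            }
        }
    }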
@@ -4749,16 +4749,17 @@ ggml_cgraph * llama_context_dec::graph_init() {
 ggml_tensor * llama_context_dec::build_inp_cross_embd(
         ggml_context * ctx0) {
     // if we have the output embeddings from the encoder, use them directly
-    if (cross->t_embd) {
-        inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    //if (cross->t_embd) {
+    //    inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
 
-        return inp.cross_embd;
-    }
+    //    return inp.cross_embd;
+    //}
 
     const auto & hparams = model.hparams;
 
-    const auto n_embd = hparams.n_embd;
-    const auto n_enc  = hparams.n_ctx_train;
+    const auto n_embd = cross->t_embd ? cross->t_embd->ne[0] : hparams.n_embd;
+    const auto n_enc  = cross->t_embd ? cross->t_embd->ne[1] : hparams.n_ctx_train;
 
     inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
     ggml_set_input(inp.cross_embd);
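
Note: with the ggml_view_tensor() fast path disabled, build_inp_cross_embd() always allocates a fresh input tensor, and cross->t_embd survives only to size it; when no encoder output exists yet, the shape falls back pessimistically to the model hyperparameters (n_embd x n_ctx_train). The new ternaries rely on ggml ordering dimensions innermost-first, which this small self-contained check illustrates (the 512x64 sizes are toy assumptions):

    #include "ggml.h"
    #include <cassert>

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 2,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true, // metadata only, no data needed here
        };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * t_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, /*n_embd*/ 512, /*n_enc*/ 64);
        assert(t_embd->ne[0] == 512); // embedding width
        assert(t_embd->ne[1] == 64);  // number of encoder positions

        ggml_free(ctx);
        return 0;
    }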
@@ -750,9 +750,14 @@ private:
 
 // TODO: tmp - need something better to pass the data from the encoder to the decoder
 struct llama_cross {
-    // the output embeddings from the encoder
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
     ggml_tensor * t_embd = nullptr;
 
+    // embeddings data copied to host memory (tmp)
+    float * v_embd = nullptr;
+
     // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
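
Note: the diff leaves the lifetime of v_embd open. As a raw float * the host copy has no owner inside the struct, so encode() presumably allocates or reuses the buffer, and something must release it when the context is destroyed, otherwise each encoder pass would leak the previous copy. A hedged sketch of the management this implies, using hypothetical names that are not in llama.cpp:

    #include <cstdlib>

    // hypothetical wrapper showing one way to keep a raw-pointer v_embd safe
    struct llama_cross_tmp {
        float * v_embd = nullptr;

        void set_embd(float * host_copy) {
            free(v_embd);       // drop the copy from the previous encoder pass
            v_embd = host_copy; // take ownership of the new host copy
        }

        ~llama_cross_tmp() {
            free(v_embd);
        }
    };

A std::vector<float> member would make the same guarantee implicit; the raw pointer is consistent with the "tmp" framing in the struct's comments.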