mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-21 18:28:31 +00:00
context : pass embeddings tensor from encoder to decoder
ggml-ci
This commit is contained in:
@ -4540,6 +4540,7 @@ size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_
|
||||
// llama_context_enc
|
||||
//
|
||||
|
||||
// TODO: avoid copy-paste of the entire encode() function
|
||||
int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
@ -4671,8 +4672,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
cross->n_outputs = n_tokens;
|
||||
cross->embd_enc = embd;
|
||||
cross->t_embd = t_embd;
|
||||
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
cross->seq_ids_enc.resize(n_tokens);
|
||||
@ -4692,9 +4692,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
|
||||
void llama_context_dec::reserve() {
|
||||
// simulate full KV cache
|
||||
cross->n_outputs = cparams.n_ubatch;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs);
|
||||
cross->t_embd = nullptr;
|
||||
|
||||
llama_context_kv_self::reserve();
|
||||
}
|
||||
@ -4703,15 +4701,15 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
|
||||
// call base functionality
|
||||
llama_context_kv_self::input_set(ubatch);
|
||||
|
||||
if (inp.cross_embd) {
|
||||
assert(inp.cross_embd->type == GGML_TYPE_F32);
|
||||
assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
|
||||
//if (inp.cross_embd && inp.cross_embd->op != GGML_OP_NONE) {
|
||||
// assert(inp.cross_embd->type == GGML_TYPE_F32);
|
||||
// assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd);
|
||||
|
||||
ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
|
||||
}
|
||||
// ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd));
|
||||
//}
|
||||
|
||||
if (inp.cross_kq_mask) {
|
||||
const int64_t n_output_enc = cross->n_outputs;
|
||||
const int64_t n_enc = inp.cross_kq_mask->ne[0];
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
|
||||
@ -4721,7 +4719,7 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_output_enc; ++i) {
|
||||
for (int i = 0; i < n_enc; ++i) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[j][s];
|
||||
@ -4729,13 +4727,13 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
|
||||
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_output_enc; ++j) {
|
||||
data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4750,12 +4748,19 @@ ggml_cgraph * llama_context_dec::graph_init() {
|
||||
|
||||
ggml_tensor * llama_context_dec::build_inp_cross_embd(
|
||||
ggml_context * ctx0) {
|
||||
// if we have the output embeddings from the encoder, use them directly
|
||||
if (cross->t_embd) {
|
||||
inp.cross_embd = ggml_view_tensor(ctx0, cross->t_embd);
|
||||
|
||||
return inp.cross_embd;
|
||||
}
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const int32_t n_outputs_enc = cross->n_outputs;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_enc = hparams.n_ctx_train;
|
||||
|
||||
inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
|
||||
inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
||||
ggml_set_input(inp.cross_embd);
|
||||
|
||||
return inp.cross_embd;
|
||||
@ -4768,9 +4773,9 @@ void llama_context_dec::build_attn_inp(
|
||||
bool swa) {
|
||||
llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
|
||||
|
||||
const int32_t n_outputs_enc = cross->n_outputs;
|
||||
const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
|
||||
|
||||
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
ggml_set_input(inp.cross_kq_mask);
|
||||
|
||||
inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
|
||||
|
@ -748,11 +748,12 @@ private:
|
||||
llama_kv_cache_recurrent kv_self;
|
||||
};
|
||||
|
||||
// TODO: tmp - need something better
|
||||
// TODO: tmp - need something better to pass the data from the encoder to the decoder
|
||||
struct llama_cross {
|
||||
int32_t n_outputs;
|
||||
float * embd_enc;
|
||||
// the output embeddings from the encoder
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
|
||||
// needed to construct the cross-attention mask in the decoder
|
||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user