context : wrap input tensors in struct
ggml-ci
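In short: the per-graph input tensors that used to live as individual `inp_*` members of `llama_context` (and of the kv-self and recurrent contexts) are grouped into a single struct member `inp`, so the base `graph_init()` can reset them with one `inp = {};` instead of nulling each field. A minimal sketch of the pattern, with hypothetical stand-in names (the real fields are in the header diff at the end, not this sketch):

// Illustrative only: stand-in types and names, not the actual llama.cpp declarations.
struct ggml_tensor; // opaque tensor type, forward-declared as in ggml

struct context_sketch { // hypothetical stand-in for llama_context
    struct {
        ggml_tensor * tokens  = nullptr; // I32 [n_batch]
        ggml_tensor * embd    = nullptr; // F32 [n_embd, n_batch]
        ggml_tensor * pos     = nullptr; // I32 [n_batch]
        ggml_tensor * out_ids = nullptr; // I32 [n_outputs]
        ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
    } inp;

    void graph_init() {
        inp = {}; // one aggregate reset instead of "inp_tokens = nullptr; inp_embd = nullptr; ..."
    }
};

The build_inp_* helpers then create the tensors into `inp.*`, and the input_set() overrides read them back from there, as the diff below shows.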
@@ -966,32 +966,32 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 if (ubatch.token) {
 const int64_t n_tokens = ubatch.n_tokens;

-ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens));
+ggml_backend_tensor_set(inp.tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp.tokens));
 }

 if (ubatch.embd) {
 const int64_t n_embd = hparams.n_embd;
 const int64_t n_tokens = ubatch.n_tokens;

-ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd));
+ggml_backend_tensor_set(inp.embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp.embd));
 }

-if (ubatch.pos && inp_pos) {
+if (ubatch.pos && inp.pos) {
 const int64_t n_tokens = ubatch.n_tokens;

-ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos));
+ggml_backend_tensor_set(inp.pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp.pos));
 }

 if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-//GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs");
+//GGML_ASSERT(inp.out_ids && "every model that can must skip unused outputs");

-if (!inp_out_ids) {
-LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__);
+if (!inp.out_ids) {
+LLAMA_LOG_WARN("%s: 'inp.out_ids' is not created\n", __func__);
 } else {
 const int64_t n_tokens = ubatch.n_tokens;

-GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer));
-int32_t * data = (int32_t *) inp_out_ids->data;
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.out_ids->buffer));
+int32_t * data = (int32_t *) inp.out_ids->data;

 if (n_outputs == n_tokens) {
 for (int i = 0; i < n_tokens; ++i) {
@@ -1020,11 +1020,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 const int64_t n_seqs = ubatch.n_seqs;

-GGML_ASSERT(inp_mean);
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer));
+GGML_ASSERT(inp.mean);
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.mean->buffer));

-float * data = (float *) inp_mean->data;
-memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean));
+float * data = (float *) inp.mean->data;
+memset(inp.mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp.mean));

 std::vector<uint64_t> sum(n_tokens, 0);

@@ -1061,11 +1061,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 const int64_t n_seqs = ubatch.n_seqs;

-GGML_ASSERT(inp_cls);
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
+GGML_ASSERT(inp.cls);
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.cls->buffer));

-uint32_t * data = (uint32_t *) inp_cls->data;
-memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
+uint32_t * data = (uint32_t *) inp.cls->data;
+memset(inp.cls->data, 0, n_tokens * ggml_element_size(inp.cls));

 for (int s = 0; s < n_seqs; ++s) {
 const llama_seq_id seq_id = ubatch.seq_id[s][0];
@@ -1088,11 +1088,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 const int64_t n_seqs = ubatch.n_seqs;

-GGML_ASSERT(inp_cls);
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
+GGML_ASSERT(inp.cls);
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.cls->buffer));

-uint32_t * data = (uint32_t *) inp_cls->data;
-memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
+uint32_t * data = (uint32_t *) inp.cls->data;
+memset(inp.cls->data, 0, n_tokens * ggml_element_size(inp.cls));

 std::vector<int> last_pos(n_tokens, -1);
 std::vector<int> last_row(n_tokens, -1);
@@ -1120,15 +1120,15 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 }
 }

-if (inp_kq_mask) {
+if (inp.kq_mask) {
 if (cparams.causal_attn) {
 const int64_t n_kv = ubatch.n_tokens;
 const int64_t n_tokens = ubatch.n_tokens;
 const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 const int64_t n_seqs = ubatch.n_seqs;

-GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
-float * data = (float *) inp_kq_mask->data;
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));
+float * data = (float *) inp.kq_mask->data;

 for (int h = 0; h < 1; ++h) {
 for (int s1 = 0; s1 < n_seqs; ++s1) {
@@ -1165,9 +1165,9 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
 const int64_t n_seqs = ubatch.n_seqs;
 const int64_t n_stride = ubatch.n_tokens;

-GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));

-float * data = (float *) inp_kq_mask->data;
+float * data = (float *) inp.kq_mask->data;

 for (int h = 0; h < 1; ++h) {
 for (int s1 = 0; s1 < n_seqs; ++s1) {
@@ -1329,15 +1329,7 @@ void llama_context::output_reorder() {
 //

 ggml_cgraph * llama_context::graph_init() {
-inp_tokens = nullptr;
-inp_embd = nullptr;
-inp_pos = nullptr;
-inp_out_ids = nullptr;
-inp_mean = nullptr;
-inp_cls = nullptr;
-
-inp_kq_mask = nullptr;
-inp_kq_mask_cnv = nullptr;
+inp = {};

 struct ggml_init_params params = {
 /*.mem_size =*/ buf_compute_meta.size(),
@@ -1563,11 +1555,11 @@ ggml_tensor * llama_context::build_inp_embd(
 struct ggml_tensor * inpL;

 if (ubatch.token) {
-inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
-//cb(inp_tokens, "inp_tokens", -1);
-ggml_set_input(inp_tokens);
+inp.tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+//cb(inp.tokens, "inp_tokens", -1);
+ggml_set_input(inp.tokens);

-inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens);
+inpL = ggml_get_rows(ctx0, tok_embd, inp.tokens);

 // apply lora for embedding tokens if needed
 for (const auto & lora : loras) {
@@ -1581,15 +1573,15 @@ ggml_tensor * llama_context::build_inp_embd(

 struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
 ctx0, lw->b, // non-transposed lora_b
-ggml_get_rows(ctx0, lw->a, inp_tokens)
+ggml_get_rows(ctx0, lw->a, inp.tokens)
 ), scale);

 inpL = ggml_add(ctx0, inpL, inpL_delta);
 }
 } else {
-inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
-inpL = inp_embd;
-ggml_set_input(inp_embd);
+inp.embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
+inpL = inp.embd;
+ggml_set_input(inp.embd);
 }

 // For Granite architecture
@@ -1605,38 +1597,38 @@ ggml_tensor * llama_context::build_inp_embd(
 ggml_tensor * llama_context::build_inp_pos(
 ggml_context * ctx0,
 int32_t n_tokens) {
-inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
-ggml_set_input(inp_pos);
+inp.pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+ggml_set_input(inp.pos);

-return inp_pos;
+return inp.pos;
 }

 ggml_tensor * llama_context::build_inp_out_ids(
 ggml_context * ctx0) {
 const int32_t n_out_ids = n_outputs;

-inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
-ggml_set_input(inp_out_ids);
+inp.out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
+ggml_set_input(inp.out_ids);

-return inp_out_ids;
+return inp.out_ids;
 }

 ggml_tensor * llama_context::build_inp_mean(
 ggml_context * ctx0,
 int32_t n_tokens) {
-inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
-ggml_set_input(inp_mean);
+inp.mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+ggml_set_input(inp.mean);

-return inp_mean;
+return inp.mean;
 }

 ggml_tensor * llama_context::build_inp_cls(
 ggml_context * ctx0,
 int32_t n_tokens) {
-inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-ggml_set_input(inp_cls);
+inp.cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ggml_set_input(inp.cls);

-return inp_cls;
+return inp.cls;
 }

 void llama_context::build_attn_inp(
@@ -1648,11 +1640,11 @@ void llama_context::build_attn_inp(
 GGML_UNUSED(causal);
 GGML_UNUSED(swa);

-inp_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+inp.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
 //cb(inp_kq_mask, "KQ_mask", -1);
-ggml_set_input(inp_kq_mask);
+ggml_set_input(inp.kq_mask);

-inp_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_kq_mask, GGML_TYPE_F16) : inp_kq_mask;
+inp.kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.kq_mask, GGML_TYPE_F16) : inp.kq_mask;
 }

 ggml_tensor * llama_context::build_attn(
@@ -1673,7 +1665,7 @@ ggml_tensor * llama_context::build_attn(
 //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

-const auto & kq_mask = inp_kq_mask_cnv;
+const auto & kq_mask = inp.kq_mask_cnv;

 const int64_t n_head = hparams.n_head(il);
 const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -2923,10 +2915,10 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
 const llama_hparams & hparams = model.hparams;

-if (inp_self_k_shift) {
-assert(ggml_backend_buffer_is_host(inp_self_k_shift->buffer));
+if (inp.self_k_shift) {
+assert(ggml_backend_buffer_is_host(inp.self_k_shift->buffer));

-int32_t * data = (int32_t *) inp_self_k_shift->data;
+int32_t * data = (int32_t *) inp.self_k_shift->data;

 for (uint32_t i = 0; i < kv_self.size; ++i) {
 data[i] = kv_self.cells[i].delta;
@@ -2939,7 +2931,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
 // call base functionality
 llama_context::input_set(ubatch);

-if (inp_self_kq_mask || inp_self_kq_mask_swa) {
+if (inp.self_kq_mask || inp.self_kq_mask_swa) {
 // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
 if (cparams.causal_attn && !is_encoding) {
 const int64_t n_kv = kv_self.n;
@@ -2950,14 +2942,14 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
 float * data = nullptr;
 float * data_swa = nullptr;

-if (inp_self_kq_mask) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask->buffer));
-data = (float *) inp_self_kq_mask->data;
+if (inp.self_kq_mask) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));
+data = (float *) inp.self_kq_mask->data;
 }

-if (inp_self_kq_mask_swa) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask_swa->buffer));
-data_swa = (float *) inp_self_kq_mask_swa->data;
+if (inp.self_kq_mask_swa) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask_swa->buffer));
+data_swa = (float *) inp.self_kq_mask_swa->data;
 }

 // For causal attention, use only the previous KV cells
@@ -3020,9 +3012,9 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
 // when using kv cache, the mask needs to match the kv cache size
 const int64_t n_stride = hparams.causal_attn && !is_encoding ? kv_self.n : n_tokens;

-GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask->buffer));
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));

-float * data = (float *) inp_self_kq_mask->data;
+float * data = (float *) inp.self_kq_mask->data;

 for (int h = 0; h < 1; ++h) {
 for (int s1 = 0; s1 < n_seqs; ++s1) {
@@ -3156,20 +3148,16 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
 inp_pos_bucket = nullptr;
 inp_kq_mask_cross = nullptr;

-inp_self_kq_mask = nullptr;
-inp_self_kq_mask_cnv = nullptr;
-inp_self_kq_mask_swa = nullptr;
-inp_self_kq_mask_swa_cnv = nullptr;
-inp_self_k_shift = nullptr;
+inp = {};

 return llama_context::graph_init();
 }

 ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) {
-inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
-ggml_set_input(inp_self_k_shift);
+inp.self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
+ggml_set_input(inp.self_k_shift);

-return inp_self_k_shift;
+return inp.self_k_shift;
 }

 void llama_context_kv_self::build_attn_inp(
@@ -3179,26 +3167,26 @@ void llama_context_kv_self::build_attn_inp(
 bool swa) {
 const auto n_kv = kv_self.n;

-inp_self_kq_mask = causal
+inp.self_kq_mask = causal
 ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
 : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp_self_kq_mask, "KQ_mask", -1);
-ggml_set_input(inp_self_kq_mask);
+//cb(inp.self_kq_mask, "KQ_mask", -1);
+ggml_set_input(inp.self_kq_mask);

-inp_self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_self_kq_mask, GGML_TYPE_F16) : inp_self_kq_mask;
+inp.self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask, GGML_TYPE_F16) : inp.self_kq_mask;

 if (swa) {
 const auto & hparams = model.hparams;

 GGML_ASSERT(hparams.n_swa > 0);

-inp_self_kq_mask_swa = causal
+inp.self_kq_mask_swa = causal
 ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
 : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp_self_kq_mask_swa, "KQ_mask_swa", -1);
-ggml_set_input(inp_self_kq_mask_swa);
+//cb(inp.self_kq_mask_swa, "KQ_mask_swa", -1);
+ggml_set_input(inp.self_kq_mask_swa);

-inp_self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_self_kq_mask_swa, GGML_TYPE_F16) : inp_self_kq_mask_swa;
+inp.self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask_swa, GGML_TYPE_F16) : inp.self_kq_mask_swa;
 }
 }

@@ -3277,7 +3265,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
 }
 };

-const auto & kq_mask = is_sliding ? inp_self_kq_mask_swa_cnv : inp_self_kq_mask_cnv;
+const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;

 const auto n_kv = kv_self.n;

@@ -4145,9 +4133,9 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {

 const int64_t n_kv = kv_self.n;

-if (inp_s_mask) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-float * data = (float *) inp_s_mask->data;
+if (inp.s_mask) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer));
+float * data = (float *) inp.s_mask->data;

 // clear unused states
 for (int i = 0; i < n_kv; ++i) {
@@ -4164,9 +4152,9 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
 }
 }

-if (inp_s_copy) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-int32_t * data = (int32_t *) inp_s_copy->data;
+if (inp.s_copy) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_copy->buffer));
+int32_t * data = (int32_t *) inp.s_copy->data;

 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
 for (uint32_t i = 0; i < n_kv; ++i) {
@@ -4190,8 +4178,8 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
 }

 ggml_cgraph * llama_context_recurrent::graph_init() {
-inp_s_copy = nullptr;
-inp_s_mask = nullptr;
+inp.s_copy = nullptr;
+inp.s_mask = nullptr;

 return llama_context::graph_init();
 }
@@ -4200,22 +4188,22 @@ ggml_tensor * llama_context_recurrent::build_inp_s_copy(
 ggml_context * ctx0) {
 const auto n_kv = kv_self.n;

-inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-//cb(inp_s_copy, "inp_s_copy", -1);
-ggml_set_input(inp_s_copy);
+inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+//cb(inp.s_copy, "inp_s_copy", -1);
+ggml_set_input(inp.s_copy);

-return inp_s_copy;
+return inp.s_copy;
 }

 ggml_tensor * llama_context_recurrent::build_inp_s_mask(
 ggml_context * ctx0) {
 const auto n_kv = kv_self.n;

-inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-//cb(inp_s_mask, "inp_s_mask", -1);
-ggml_set_input(inp_s_mask);
+inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+//cb(inp.s_mask, "inp_s_mask", -1);
+ggml_set_input(inp.s_mask);

-return inp_s_mask;
+return inp.s_mask;
 }

 ggml_tensor * llama_context_recurrent::build_copy_mask_state(
@@ -139,17 +139,19 @@ protected:

 virtual void input_set(const llama_ubatch & ubatch);

+struct {
 // base input tensors
-ggml_tensor * inp_tokens; // I32 [n_batch]
-ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
-ggml_tensor * inp_pos; // I32 [n_batch]
-ggml_tensor * inp_out_ids; // I32 [n_outputs]
-ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
-ggml_tensor * inp_cls; // I32 [n_batch]
+ggml_tensor * tokens; // I32 [n_batch]
+ggml_tensor * embd; // F32 [n_embd, n_batch]
+ggml_tensor * pos; // I32 [n_batch]
+ggml_tensor * out_ids; // I32 [n_outputs]
+ggml_tensor * mean; // F32 [n_batch, n_batch]
+ggml_tensor * cls; // I32 [n_batch]

 // KQ mask input tensors
-ggml_tensor * inp_kq_mask; // F32 [n_tokens, n_batch]
-ggml_tensor * inp_kq_mask_cnv; // [n_tokens, n_batch]
+ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
+ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
+} inp;

 //
 // output
@@ -409,11 +411,13 @@ protected:

 virtual void input_set(const llama_ubatch & ubatch) override;

-ggml_tensor * inp_self_kq_mask; // F32 [kv_size, n_batch]
-ggml_tensor * inp_self_kq_mask_cnv; // [kv_size, n_batch]
-ggml_tensor * inp_self_kq_mask_swa; // F32 [kv_size, n_batch]
-ggml_tensor * inp_self_kq_mask_swa_cnv; // [kv_size, n_batch]
-ggml_tensor * inp_self_k_shift; // I32 [kv_size]
+struct {
+ggml_tensor * self_kq_mask; // F32 [kv_size, n_batch]
+ggml_tensor * self_kq_mask_cnv; // [kv_size, n_batch]
+ggml_tensor * self_kq_mask_swa; // F32 [kv_size, n_batch]
+ggml_tensor * self_kq_mask_swa_cnv; // [kv_size, n_batch]
+ggml_tensor * self_k_shift; // I32 [kv_size]
+} inp;

 //
 // graph
@@ -519,8 +523,10 @@ protected:

 virtual void input_set(const llama_ubatch & ubatch) override;

-struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+struct {
+ggml_tensor * s_copy; // I32 [kv_size]
+ggml_tensor * s_mask; // F32 [1, n_kv]
+} inp;

 //
 // graph