context : move common inputs to base class

ggml-ci
Georgi Gerganov
2025-02-14 16:48:21 +02:00
parent d5e8e1a2ba
commit 828064564c
2 changed files with 111 additions and 111 deletions
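In outline: the five common graph-input builders (build_inp_embd, build_inp_pos, build_inp_out_ids, build_inp_mean, build_inp_cls) move verbatim from llama_context_kv_self into the base llama_context, staying virtual so derived contexts can still override them. A minimal sketch of the resulting hierarchy, assuming only what the diff below shows (the placeholder bodies and the trimmed member lists are illustrative, not the real code):

    #include <cstdint>

    struct ggml_context;
    struct ggml_tensor;
    struct llama_ubatch;

    struct llama_context /* : public llama_graph_i */ {
        virtual ~llama_context() = default;

        // common inputs, now built by the base class (placeholder bodies here;
        // the real implementations are in the first hunk below)
        virtual ggml_tensor * build_inp_embd   (ggml_context * ctx0, ggml_tensor * tok_embd, const llama_ubatch & ubatch) { return nullptr; }
        virtual ggml_tensor * build_inp_pos    (ggml_context * ctx0, int32_t n_tokens)                  { return nullptr; }
        virtual ggml_tensor * build_inp_out_ids(ggml_context * ctx0, int32_t n_tokens, bool worst_case) { return nullptr; }
        virtual ggml_tensor * build_inp_mean   (ggml_context * ctx0, int32_t n_tokens)                  { return nullptr; }
        virtual ggml_tensor * build_inp_cls    (ggml_context * ctx0, int32_t n_tokens)                  { return nullptr; }
    };

    struct llama_context_kv_self : public llama_context {
        // after this commit only the KV-cache specific inputs remain here
        // (inp_KQ_mask, inp_K_shift, ...); the build_inp_* overrides are gone
    };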


@@ -987,6 +987,95 @@ ggml_tensor * llama_context::build_rope_factors(int il) {
    return model.layers[il].rope_short;
}

ggml_tensor * llama_context::build_inp_embd(
        ggml_context * ctx0,
        ggml_tensor * tok_embd,
        const llama_ubatch & ubatch) {
    const auto & hparams = model.hparams;

    const int64_t n_embd = hparams.n_embd;

    struct ggml_tensor * inpL;

    if (ubatch.token) {
        inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        //cb(inp_tokens, "inp_tokens", -1);
        ggml_set_input(inp_tokens);

        inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : loras) {
            struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
                        ctx0, lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp_tokens)
                        ), scale);

            inpL = ggml_add(ctx0, inpL, inpL_delta);
        }
    } else {
        inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
        inpL = inp_embd;
        ggml_set_input(inp_embd);
    }

    // For Granite architecture
    if (hparams.f_embedding_scale != 0.0f) {
        inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
    }

    //cb(inpL, "inp_embd", -1);

    return inpL;
}
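For reference, the LoRA branch above amounts to the following, per selected token row t (a sketch; the alpha-over-rank convention inside get_scale() is the usual LoRA scaling and is assumed here, not shown in this diff):

    // inpL[t] = tok_embd[t] + s * (lora_b @ lora_a[t])
    // with s = (alpha / r) * adapter_scale
    //
    // ggml_get_rows(ctx0, lw->a, inp_tokens) gathers the LoRA-A rows for the
    // batch tokens, and ggml_mul_mat applies lora_b non-transposed, so the
    // delta is formed without ever materializing the full (B*A) matrix.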
ggml_tensor * llama_context::build_inp_pos(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
    ggml_set_input(inp_pos);

    return inp_pos;
}
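Note the n_tokens*n_pos_per_token() sizing: some models carry more than one position value per token. A hypothetical sketch of how that multiplier is typically resolved (based on upstream llama.cpp around this time; not part of this commit):

    // hypothetical sketch, not from this diff: M-RoPE models (e.g. Qwen2-VL)
    // use 4 position channels per token, everything else uses 1
    int32_t llama_context::n_pos_per_token() const {
        return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1;
    }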
ggml_tensor * llama_context::build_inp_out_ids(
        ggml_context * ctx0,
        int32_t n_tokens,
        bool worst_case) {
    const int32_t n_out_ids = worst_case ? n_tokens : n_outputs;

    inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
    ggml_set_input(inp_out_ids);

    return inp_out_ids;
}

ggml_tensor * llama_context::build_inp_mean(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
    ggml_set_input(inp_mean);

    return inp_mean;
}

ggml_tensor * llama_context::build_inp_cls(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_input(inp_cls);

    return inp_cls;
}
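inp_mean and inp_cls are only allocated here; they are filled at set-input time and consumed by the pooling stage. A sketch of the usual llama.cpp pooling consumers, assumed for context rather than shown in this diff (given ggml_context * ctx0 and ggml_tensor * inp holding the final embeddings [n_embd, n_tokens]):

    // mean pooling: one matmul averages each sequence's token embeddings;
    // row s of inp_mean holds 1/n_tokens_of_seq_s in that sequence's columns
    ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);

    // CLS/last pooling: inp_cls holds one row index per sequence
    cur = ggml_get_rows(ctx0, inp, inp_cls);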
//
// state
//
@@ -2682,95 +2771,6 @@ ggml_tensor * llama_context_kv_self::build_soft_max_ext(
    return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
}
ggml_tensor * llama_context_kv_self::build_inp_embd(
        ggml_context * ctx0,
        ggml_tensor * tok_embd,
        const llama_ubatch & ubatch) {
    const auto & hparams = model.hparams;

    const int64_t n_embd = hparams.n_embd;

    struct ggml_tensor * inpL;

    if (ubatch.token) {
        inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        //cb(inp_tokens, "inp_tokens", -1);
        ggml_set_input(inp_tokens);

        inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens);

        // apply lora for embedding tokens if needed
        for (const auto & lora : loras) {
            struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
            if (lw == nullptr) {
                continue;
            }

            const float adapter_scale = lora.second;
            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);

            struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
                        ctx0, lw->b, // non-transposed lora_b
                        ggml_get_rows(ctx0, lw->a, inp_tokens)
                        ), scale);

            inpL = ggml_add(ctx0, inpL, inpL_delta);
        }
    } else {
        inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
        inpL = inp_embd;
        ggml_set_input(inp_embd);
    }

    // For Granite architecture
    if (hparams.f_embedding_scale != 0.0f) {
        inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
    }

    //cb(inpL, "inp_embd", -1);

    return inpL;
}

ggml_tensor * llama_context_kv_self::build_inp_pos(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
    ggml_set_input(inp_pos);

    return inp_pos;
}

ggml_tensor * llama_context_kv_self::build_inp_out_ids(
        ggml_context * ctx0,
        int32_t n_tokens,
        bool worst_case) {
    const int32_t n_out_ids = worst_case ? n_tokens : n_outputs;

    inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids);
    ggml_set_input(inp_out_ids);

    return inp_out_ids;
}

ggml_tensor * llama_context_kv_self::build_inp_mean(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
    ggml_set_input(inp_mean);

    return inp_mean;
}

ggml_tensor * llama_context_kv_self::build_inp_cls(
        ggml_context * ctx0,
        int32_t n_tokens) {
    inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_input(inp_cls);

    return inp_cls;
}

void llama_context_kv_self::build_k_shift(
        ggml_context * ctx0,
        ggml_cgraph * graph) {


@@ -169,6 +169,28 @@ struct llama_context : public llama_graph_i {
    virtual ggml_tensor * build_rope_factors(int il);

    virtual ggml_tensor * build_inp_embd(
            ggml_context * ctx0,
            ggml_tensor * tok_embd,
            const llama_ubatch & ubatch);

    virtual ggml_tensor * build_inp_pos(
            ggml_context * ctx0,
            int32_t n_tokens);

    virtual ggml_tensor * build_inp_out_ids(
            ggml_context * ctx0,
            int32_t n_tokens,
            bool worst_case);

    virtual ggml_tensor * build_inp_mean(
            ggml_context * ctx0,
            int32_t n_tokens);

    virtual ggml_tensor * build_inp_cls(
            ggml_context * ctx0,
            int32_t n_tokens);

    // state save/load

    virtual size_t state_get_size();
@@ -330,28 +352,6 @@ public:
    struct ggml_tensor * inp_KQ_mask_swa_cnv; // [kv_size, n_batch]
    struct ggml_tensor * inp_K_shift;         // I32 [kv_size]

    virtual ggml_tensor * build_inp_embd(
            ggml_context * ctx0,
            ggml_tensor * tok_embd,
            const llama_ubatch & ubatch) override;

    virtual ggml_tensor * build_inp_pos(
            ggml_context * ctx0,
            int32_t n_tokens) override;

    virtual ggml_tensor * build_inp_out_ids(
            ggml_context * ctx0,
            int32_t n_tokens,
            bool worst_case) override;

    virtual ggml_tensor * build_inp_mean(
            ggml_context * ctx0,
            int32_t n_tokens) override;

    virtual ggml_tensor * build_inp_cls(
            ggml_context * ctx0,
            int32_t n_tokens) override;

    virtual void build_attn_inp(
            ggml_context * ctx0,
            int32_t n_tokens,