mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-05 22:23:46 +00:00
MPT : support GQA for replit-code-v1.5 (#3627)
This commit is contained in:
@ -2839,8 +2839,8 @@ static void llm_load_tensors(
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
||||
|
||||
@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
|
Reference in New Issue
Block a user