mirror of https://github.com/ggml-org/llama.cpp.git
CUDA: fix q_nope_absorbed prec for DS 2 Lite f16 (#13137)
@@ -393,8 +393,8 @@ extern "C" {

     // precision
     enum ggml_prec {
-        GGML_PREC_DEFAULT,
-        GGML_PREC_F32,
+        GGML_PREC_DEFAULT =  0, // stored as ggml_tensor.op_params, 0 by default
+        GGML_PREC_F32     = 10,
     };

     // model file types
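The header hunk pins explicit values on the precision enum: GGML_PREC_DEFAULT, previously the implicit first enumerator, is now spelled out as 0, and GGML_PREC_F32 moves from the implicit 1 to 10. The new comment records why the 0 matters: the precision is stored in ggml_tensor.op_params, which is zero-initialized, so a node whose precision was never set naturally reads back as default. A minimal sketch of that round trip, using hypothetical stand-in types rather than the real ggml structs:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Simplified stand-ins for the real ggml types (illustration only).
enum ggml_prec_sketch {
    PREC_DEFAULT =  0, // stored in op_params, which is zero-initialized
    PREC_F32     = 10,
};

struct tensor_sketch {
    int32_t op_params[16] = {}; // zero-initialized, like ggml_tensor.op_params
};

// In the spirit of ggml_mul_mat_set_prec(): write the precision into op_params[0].
static void set_prec(tensor_sketch & t, ggml_prec_sketch prec) {
    t.op_params[0] = static_cast<int32_t>(prec);
}

int main() {
    tensor_sketch t;

    // Because PREC_DEFAULT == 0, a node whose op_params were never touched
    // already reads back as "default precision" -- the reason the enum value
    // is pinned explicitly rather than left implicit.
    assert(t.op_params[0] == PREC_DEFAULT);

    set_prec(t, PREC_F32);
    assert(t.op_params[0] == PREC_F32);
    std::printf("prec after set: %d\n", t.op_params[0]);
}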
@@ -1935,8 +1935,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor

         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
-        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
+        dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {
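The CUDA hunk adds one clause to the dispatch condition: the batched FP16 cuBLAS path (ggml_cuda_mul_mat_batched_cublas) is now taken only when dst->op_params[0] holds GGML_PREC_DEFAULT, so a mul_mat node marked GGML_PREC_F32 falls through to a code path that computes in F32. A condensed, hypothetical rendering of the predicate, with stand-in types rather than the real dispatcher:

#include <cstdint>
#include <cstdio>

// Simplified stand-ins for the types the real check reads (illustration only).
enum prec_sketch { PREC_DEFAULT = 0, PREC_F32 = 10 };
enum type_sketch { TYPE_F32, TYPE_F16 };

struct tensor_sketch {
    type_sketch type;
    int64_t     ne[4];          // dimensions; ne[2]*ne[3] is the batch size
    int32_t     op_params[16];
    bool        transposed;
};

// Condensed form of the condition guarding the batched FP16 cuBLAS path:
// the new dst.op_params[0] check keeps F32-precision matmuls off it.
bool use_batched_fp16_cublas(const tensor_sketch & src0,
                             const tensor_sketch & src1,
                             const tensor_sketch & dst,
                             bool split, bool any_gpus_with_slow_fp16) {
    return !split
        && src0.type == TYPE_F16
        && (src1.type == TYPE_F16 || !any_gpus_with_slow_fp16)
        && dst.op_params[0] == PREC_DEFAULT          // <- added by this commit
        && !src0.transposed && !src1.transposed
        && src1.ne[2] * src1.ne[3] > 1;              // multi-batch only
}

int main() {
    tensor_sketch a{TYPE_F16, {64, 64, 8, 1}, {}, false};
    tensor_sketch b{TYPE_F32, {64, 32, 8, 1}, {}, false};
    tensor_sketch d{TYPE_F32, {64, 32, 8, 1}, {PREC_F32}, false};

    // With PREC_F32 requested on dst, the FP16 path is now rejected:
    std::printf("fp16 path taken: %d\n",
                use_batched_fp16_cublas(a, b, d, false, false));
}

Reading the flag off dst works because ggml_mul_mat_set_prec stores the precision on the result tensor of the matmul, as the model-side hunk below shows.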
@@ -10149,6 +10149,7 @@ struct llm_build_deepseek2 : public llm_graph_context {

                 // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
                 ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
+                ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32);
                 cb(q_nope_absorbed, "q_nope_absorbed", il);

                 // {kv_lora_rank, n_head, n_tokens}
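The model hunk requests F32 precision for the wk_b x q_nope multiplication in DeepSeek 2's attention, which otherwise qualifies for the FP16 batched path above. The likely failure mode being guarded against: a half-precision accumulator has an 11-bit significand, so once a running sum passes 2048 it can no longer absorb unit-scale addends. A toy demonstration, assuming a compiler with _Float16 support (recent GCC/Clang on x86-64); the loop length and values are illustrative, not the actual DS 2 Lite numerics:

#include <cstdio>

int main() {
    const int n = 4096;          // toy reduction length
    float     acc32 = 0.0f;
    _Float16  acc16 = 0.0f;      // requires _Float16 support (GCC/Clang)

    for (int i = 0; i < n; ++i) {
        const float x = 0.9997f; // near-unit partial products (illustrative)
        acc32 += x;
        acc16 += static_cast<_Float16>(x); // each partial sum rounded to f16
    }

    // With an 11-bit significand, f16 values beyond 2048 are spaced 2 apart,
    // so adding ~1 stops changing the sum: it stalls near 2048 while the
    // f32 accumulator reaches ~4094.8.
    std::printf("f32 sum: %f\n", static_cast<double>(acc32));
    std::printf("f16 sum: %f\n", static_cast<double>(acc16));
}

Marking only this node keeps the rest of the graph on the fast f16 path while restoring accuracy where this model needs it.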