mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-16 07:38:28 +00:00
163 lines
4.8 KiB
Plaintext
163 lines
4.8 KiB
Plaintext
|
|
// One-dimensional workgroup whose width is specialization constant 0
// (WorkGroupSize below).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Specialization constants. The values written here are defaults, used only
// when the pipeline does not supply its own value for that constant_id.
layout(constant_id = 0) const uint32_t WorkGroupSize = 128; // invocations per workgroup (local_size_x)
layout(constant_id = 1) const uint32_t Br            = 1;   // tile height over N: init_indices uses Tr = CEIL_DIV(N, Br)
layout(constant_id = 2) const uint32_t Bc            = 32;  // tile width over KV: start_j/end_j are counted in units of Bc
layout(constant_id = 3) const uint32_t D             = 32;  // presumably the head dimension - TODO confirm
layout(constant_id = 4) const uint32_t Clamp         = 0;   // not referenced in this section; meaning not visible here
layout(constant_id = 5) const uint32_t D_split       = 16;  // not referenced in this section; meaning not visible here
|
|
|
|
|
|
// Push constants shared by every invocation. The member order and types here
// define the block layout, which must match the data the host pushes. Do not
// reorder, insert or retype members without changing the host side as well.
layout(push_constant) uniform parameter {
    // Problem sizes: N is tiled by Br and KV by Bc (see init_indices).
    uint32_t N;
    uint32_t KV;

    uint32_t ne1, ne2, ne3;

    // Extents of Q, K and V in dims 2 and 3. init_indices derives the K/V
    // broadcast ratios from these (neq / nek and neq / nev).
    uint32_t neq2, neq3;
    uint32_t nek2, nek3;
    uint32_t nev2, nev3;
    uint32_t nem1, nem2;

    // Strides of Q (nb0?), K (nb1?) and V (nb2?). Per init_indices, the nb?1
    // values are already in elements while nb02 is in bytes.
    uint32_t nb01, nb02, nb03;
    uint32_t nb11, nb12, nb13;
    uint32_t nb21, nb22, nb23;

    float scale;
    float max_bias;
    float logit_softcap;

    uint32_t mask;
    // n_head_log2, m0 and m1 select the base and exponent of the per-head
    // slope computed in perElemOpComputeSlope.
    uint32_t n_head_log2;
    float m0;
    float m1;

    uint32_t gqa_ratio; // consecutive iq2 values handled by one workgroup (1 = no GQA)
    uint32_t split_kv;  // KV span covered by each split_k slice
    uint32_t k_num;     // number of split_k slices; > 1 switches init_indices to split_k indexing
} p;
|
|
|
|
// Output buffer. perElemOpStoreCol0 writes per-row values into it.
layout(binding = 4) writeonly buffer O {
    D_TYPE data_o[];
};

#if defined(A_TYPE_PACKED16)
// K and V viewed as arrays of A_TYPE_PACKED16 blocks. The dequantize4
// variants read each block's .qs and .d fields. Because kv_packed is an
// array of two blocks declared at binding 1, element 0 uses binding 1 and
// element 1 uses binding 2. Pick the element with one of these indices:
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout(binding = 1) readonly buffer KV_PACKED16 {
    A_TYPE_PACKED16 data_packed16[];
} kv_packed[2];
#endif
|
|
|
|
#if defined(DATA_A_Q4_0)
// One Q4_0 block occupies 18 bytes.
#define BLOCK_BYTE_SIZE 18

// Dequantize four consecutive Q4_0 values of block (a_offset + ib) from the
// K or V buffer selected by binding_idx, starting at element iqs.
//
// The four 4-bit quants come from two consecutive 16-bit words of qs. The
// low four bits of iqs select the word pair. Bit 4 of iqs picks the high
// nibble of each byte instead of the low one.
// NOTE(review): assumes iqs is a multiple of 4 below 32 - confirm with callers.
// Each quant is offset by -8 and scaled by the block's d.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk       = a_offset + ib;
    const uint word      = (iqs & 0xF) >> 1;  // first of the two 16-bit words
    const uint nib_shift = (iqs & 0x10) >> 2; // 0 -> low nibbles, 4 -> high nibbles

    const uint lo = uint(kv_packed[binding_idx].data_packed16[blk].qs[word    ]) >> nib_shift;
    const uint hi = uint(kv_packed[binding_idx].data_packed16[blk].qs[word + 1]) >> nib_shift;

    // Take one nibble from each byte: byte 0 / byte 1 of lo, then of hi.
    const vec4 quants = vec4(lo & 0xF, (lo >> 8) & 0xF, hi & 0xF, (hi >> 8) & 0xF);
    return float(kv_packed[binding_idx].data_packed16[blk].d) * (quants - 8.0f);
}
#endif
|
|
|
|
#if defined(DATA_A_Q8_0)
// One Q8_0 block occupies 34 bytes.
#define BLOCK_BYTE_SIZE 34

// Dequantize four consecutive Q8_0 values of block (a_offset + ib) from the
// K or V buffer selected by binding_idx, starting at element iqs.
// Each 16-bit word of qs holds two signed 8-bit quants. Two consecutive words
// therefore yield four values, which are scaled by the block's d.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const uint blk  = a_offset + ib;
    const uint word = iqs / 2;

    // Each word is widened to int32_t and unpacked to an i8vec4, keeping only
    // .xy (its two bytes). The 16-bit word is deliberately not unpacked
    // directly; the upstream note ties this form to llama.cpp #12147
    // ("vec4 used due to #12147").
    const i8vec2 first  = unpack8(int32_t(kv_packed[binding_idx].data_packed16[blk].qs[word    ])).xy;
    const i8vec2 second = unpack8(int32_t(kv_packed[binding_idx].data_packed16[blk].qs[word + 1])).xy;

    const float d = float(kv_packed[binding_idx].data_packed16[blk].d);
    return d * vec4(first.x, first.y, second.x, second.y);
}
#endif
|
|
|
|
// Integer quotient a / b rounded up (for non-negative a and positive b).
// Note that b is evaluated twice.
#define CEIL_DIV(a, b) (((a) + ((b) - 1)) / (b))
|
|
|
|
|
|
// Per-element op that writes column 0 of each row r < N to
// data_o[o_offset + iq2 + r]. It is used to save the per-row m and L values
// for split_k. The element is always returned unchanged, so the op does not
// modify the matrix it is applied to.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Skip every column except 0, and any row past N.
    if (c != 0 || r >= N) {
        return elem;
    }

    data_o[o_offset + (iq2 + r)] = D_TYPE(elem);
    return elem;
}
|
|
|
|
// Per-element op producing the slope for row r, indexed by Q's dimension 2.
// Row r belongs to head iq2 + (r % gqa_ratio). Heads below n_head_log2 use
// m0^(h + 1); the remaining heads use m1^(2*(h - n_head_log2) + 1).
// c and elem are unused; the signature presumably mirrors the other per-element
// ops (e.g. perElemOpStoreCol0) - confirm at the call site.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t head = iq2 + (r % p.gqa_ratio);

    ACC_TYPE base;
    int exponent;
    if (head < p.n_head_log2) {
        base     = ACC_TYPE(p.m0);
        exponent = int(head + 1);
    } else {
        base     = ACC_TYPE(p.m1);
        exponent = int(2*(head - p.n_head_log2) + 1);
    }

    return ACC_TYPE(pow(base, ACC_TYPE(exponent)));
}
|
|
|
|
// Per-invocation indexing state, all assigned by init_indices().
uint32_t i, N, KV;                    // row tile index and copies of p.N / p.KV
uint32_t split_k_index, Tr;           // split_k slice index; number of Br-row tiles over N
uint32_t start_j, end_j;              // [start_j, end_j): Bc-wide KV tiles of this slice
uint32_t iq2, iq3;                    // Q indices in dims 2 and 3 (from workgroup y and z)
uint32_t rk2, rk3, rv2, rv3;          // Q:K and Q:V broadcast ratios in dims 2 and 3
uint32_t ik2, ik3, iv2, iv3;          // K and V indices derived from iq2/iq3 and those ratios
uint32_t q_stride, k_stride, v_stride, m_stride;
|
|
|
|
// Fill the per-invocation indexing globals from the push constants and the
// workgroup ID. Workgroup x is the row tile (or the split_k slice), y is the
// Q dim-2 group, and z is the Q dim-3 index.
void init_indices()
{
    N  = p.N;
    KV = p.KV;

    // Without split_k, workgroup x selects the row tile i. With split_k
    // (k_num > 1), workgroup x selects the KV slice instead and i is 0.
    const bool use_split_k = p.k_num > 1;
    i             = use_split_k ? 0u : gl_WorkGroupID.x;
    split_k_index = use_split_k ? gl_WorkGroupID.x : 0u;

    // Number of Br-row tiles needed to cover N.
    Tr = CEIL_DIV(N, Br);

    // Bc-wide KV tiles [start_j, end_j) belonging to this slice, clipped to KV.
    start_j = split_k_index * p.split_kv / Bc;
    end_j   = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // Without grouped query attention (gqa_ratio == 1) every row shares
    // iq2 == gl_WorkGroupID.y. With GQA, one workgroup covers gqa_ratio
    // consecutive iq2 values starting here.
    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    iq3 = gl_WorkGroupID.z;

    // Broadcasting: each K index in dims 2/3 is shared by rk2/rk3 Q indices.
    rk2 = p.neq2 / p.nek2;
    rk3 = p.neq3 / p.nek3;
    ik2 = iq2 / rk2;
    ik3 = iq3 / rk3;

    // Same for V.
    rv2 = p.neq2 / p.nev2;
    rv3 = p.neq3 / p.nev3;
    iv2 = iq2 / rv2;
    iv3 = iq3 / rv3;

    // The nb?1 strides are already in elements. Under GQA, Q rows are
    // indexed by iq2, so the stride comes from nb02, which is in bytes.
    // NOTE(review): the / 4 assumes 4-byte (presumably f32) Q elements - confirm.
    q_stride = (p.gqa_ratio > 1) ? (p.nb02 / 4) : p.nb01;
    k_stride = p.nb11;
    v_stride = p.nb21;

    // Under GQA all rows share one mask row (stride 0). "p.gqa_ratio >> 16"
    // is deliberately a roundabout zero: it stops the compiler from folding
    // the "&" through the select and breaking the alignment detection. Keep
    // this expression exactly as written.
    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}
|