vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (#14485)

Commit taken from remyoudompheng's PR https://github.com/ggml-org/llama.cpp/pull/12260

Co-authored-by: Rémy Oudompheng <remyoudompheng@gmail.com>
This commit is contained in:
Eve
2025-07-06 10:29:36 +00:00
committed by GitHub
parent e592be1575
commit 6491d6e4f1
2 changed files with 79 additions and 69 deletions

View File

@ -500,10 +500,9 @@ void main() {
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7 const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = (idx % 128) / 4; const uint ib8 = idx % 32;
const int i8 = 2 * int(idx % 4);
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const uint qh = data_a[ib].qh[ib32]; const uint qh = data_a[ib].qh[ib32];
@ -512,22 +511,16 @@ void main() {
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
const ivec2 gvec = ivec2( [[unroll]] for (int k = 0; k < 8; ++k) {
bitfieldExtract(grid, 2 * (i8), 2), buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
bitfieldExtract(grid, 2 * (i8 + 1), 2) }
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ1_M) #elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 32; // 8 values per idx
const uint ib8 = (idx % 128) / 4; const uint ib8 = idx % 32;
const uint ib16 = ib8 / 2; const uint ib16 = ib8 / 2;
const int i8 = 2 * int(idx % 4);
const uint16_t[4] scales = data_a[ib].scales; const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@ -538,21 +531,17 @@ void main() {
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x); [[unroll]] for (int k = 0; k < 8; ++k) {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
}
#elif defined(DATA_A_IQ2_XXS) #elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7 const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = (idx / 4) % 4; const uint ib8 = idx % 4;
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8]; const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@ -562,63 +551,81 @@ void main() {
data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7] data_a[ib].qs[8*ib32 + 7]
)); ));
const float db = d * 0.25 * (0.5 + (signs >> 28)); const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const uint sign = sign7 | (bitCount(sign7) << 7);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uvec2 grid = iq2xxs_grid[qs];
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); const vec4 grid0 = vec4(unpack8(grid.x));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS) #elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7 const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = (idx / 4) % 4; // 0..3 const uint ib8 = idx % 4; // 0..3
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const float db = d * 0.25 * (0.5 + scale); const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9; const uint sign7 = qs >> 9;
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const uint sign = sign7 | (bitCount(sign7) << 7);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uvec2 grid = iq2xs_grid[qs & 511];
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); const vec4 grid0 = vec4(unpack8(grid.x));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S) #elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 32; // 8 values per idx
const uint ib8 = (idx % 128) / 4; // 0..31 const uint ib8 = idx % 32; // 0..31
const uint ib32 = ib8 / 4; // 0..7 const uint ib32 = ib8 / 4; // 0..7
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8]; const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32]; const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4); const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const float db = d * 0.25 * (0.5 + scale); const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; const vec4 grid0 = vec4(unpack8(grid.x));
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS) #elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 128) / 2; // 0..63 const uint iqs = idx % 64; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
@ -631,33 +638,36 @@ void main() {
)); ));
const float db = d * 0.5 * (0.5 + (signs >> 28)); const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq3xxs_grid[qs];
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); const vec4 v = db * vec4(unpack8(grid));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S) #elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 128) / 2; // 0..63 const uint iqs = idx % 64; // 0..63
const uint iqh = iqs / 8; const uint iqh = iqs / 8;
const float d = float(data_a[ib].d); const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs]; const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh]; const uint qh = data_a[ib].qh[iqh];
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
const uint scale = data_a[ib].scales[iqs / 16]; const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS) #elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

View File

@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
std::string load_vec_quant = "2"; std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1")) if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8"; load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
load_vec_quant = "4"; load_vec_quant = "4";
if (tname == "bf16") { if (tname == "bf16") {