mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-18 08:37:43 +00:00
vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (#14485)
Commit taken from remyoudompheng's PR https://github.com/ggml-org/llama.cpp/pull/12260

Co-authored-by: Rémy Oudompheng <remyoudompheng@gmail.com>
This commit is contained in:
@ -500,10 +500,9 @@ void main() {
|
|||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||||
const uint ib8 = (idx % 128) / 4;
|
const uint ib8 = idx % 32;
|
||||||
const int i8 = 2 * int(idx % 4);
|
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
const uint qh = data_a[ib].qh[ib32];
|
const uint qh = data_a[ib].qh[ib32];
|
||||||
@ -512,22 +511,16 @@ void main() {
|
|||||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||||
|
|
||||||
const ivec2 gvec = ivec2(
|
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||||
bitfieldExtract(grid, 2 * (i8), 2),
|
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
}
|
||||||
);
|
|
||||||
const vec2 v = dl * (vec2(gvec) + delta);
|
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
|
||||||
#elif defined(DATA_A_IQ1_M)
|
#elif defined(DATA_A_IQ1_M)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
const uint ib8 = (idx % 128) / 4;
|
const uint ib8 = idx % 32;
|
||||||
const uint ib16 = ib8 / 2;
|
const uint ib16 = ib8 / 2;
|
||||||
const int i8 = 2 * int(idx % 4);
|
|
||||||
|
|
||||||
const uint16_t[4] scales = data_a[ib].scales;
|
const uint16_t[4] scales = data_a[ib].scales;
|
||||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||||
@ -538,21 +531,17 @@ void main() {
|
|||||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||||
const ivec2 gvec = ivec2(
|
|
||||||
bitfieldExtract(grid, 2 * (i8), 2),
|
|
||||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
|
||||||
);
|
|
||||||
const vec2 v = dl * (vec2(gvec) + delta);
|
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||||
|
}
|
||||||
#elif defined(DATA_A_IQ2_XXS)
|
#elif defined(DATA_A_IQ2_XXS)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||||
const uint ib8 = (idx / 4) % 4;
|
const uint ib8 = idx % 4;
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||||
@ -562,63 +551,81 @@ void main() {
|
|||||||
data_a[ib].qs[8*ib32 + 6],
|
data_a[ib].qs[8*ib32 + 6],
|
||||||
data_a[ib].qs[8*ib32 + 7]
|
data_a[ib].qs[8*ib32 + 7]
|
||||||
));
|
));
|
||||||
const float db = d * 0.25 * (0.5 + (signs >> 28));
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
||||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
const uvec2 grid = iq2xxs_grid[qs];
|
||||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
#elif defined(DATA_A_IQ2_XS)
|
#elif defined(DATA_A_IQ2_XS)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||||
const uint ib8 = (idx / 4) % 4; // 0..3
|
const uint ib8 = idx % 4; // 0..3
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||||
const float db = d * 0.25 * (0.5 + scale);
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||||
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||||
const uint sign7 = qs >> 9;
|
const uint sign7 = qs >> 9;
|
||||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
#elif defined(DATA_A_IQ2_S)
|
#elif defined(DATA_A_IQ2_S)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
const uint ib8 = idx % 32; // 0..31
|
||||||
const uint ib32 = ib8 / 4; // 0..7
|
const uint ib32 = ib8 / 4; // 0..7
|
||||||
|
|
||||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||||
const uint qs = data_a[ib].qs[ib8];
|
const uint qs = data_a[ib].qs[ib8];
|
||||||
const uint qh = data_a[ib].qh[ib32];
|
const uint qh = data_a[ib].qh[ib32];
|
||||||
const uint qhshift = 2 * (ib8 % 4);
|
const uint qhshift = 2 * (ib8 % 4);
|
||||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
const float db = d * 0.25 * (0.5 + scale);
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
||||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
#elif defined(DATA_A_IQ3_XXS)
|
#elif defined(DATA_A_IQ3_XXS)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 64; // 4 values per idx
|
||||||
const uint iqs = (idx % 128) / 2; // 0..63
|
const uint iqs = idx % 64; // 0..63
|
||||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
@ -631,33 +638,36 @@ void main() {
|
|||||||
));
|
));
|
||||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
||||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
const uint grid = iq3xxs_grid[qs];
|
||||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
const vec4 v = db * vec4(unpack8(grid));
|
||||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||||
|
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||||
|
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||||
#elif defined(DATA_A_IQ3_S)
|
#elif defined(DATA_A_IQ3_S)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
|
||||||
const uint ib = idx / 128; // 2 values per idx
|
const uint ib = idx / 64; // 4 values per idx
|
||||||
const uint iqs = (idx % 128) / 2; // 0..63
|
const uint iqs = idx % 64; // 0..63
|
||||||
const uint iqh = iqs / 8;
|
const uint iqh = iqs / 8;
|
||||||
|
|
||||||
const float d = float(data_a[ib].d);
|
const float d = float(data_a[ib].d);
|
||||||
const uint qs = data_a[ib].qs[iqs];
|
const uint qs = data_a[ib].qs[iqs];
|
||||||
const uint qh = data_a[ib].qh[iqh];
|
const uint qh = data_a[ib].qh[iqh];
|
||||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
|
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
||||||
const uint scale = data_a[ib].scales[iqs / 16];
|
const uint scale = data_a[ib].scales[iqs / 16];
|
||||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
const vec4 v = db * vec4(unpack8(grid));
|
||||||
|
|
||||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||||
|
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||||
|
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||||
#elif defined(DATA_A_IQ4_XS)
|
#elif defined(DATA_A_IQ4_XS)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|||||||
|
|
||||||
for (const auto& tname : type_names) {
|
for (const auto& tname : type_names) {
|
||||||
std::string load_vec_quant = "2";
|
std::string load_vec_quant = "2";
|
||||||
if ((tname == "q4_0") || (tname == "q4_1"))
|
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||||
load_vec_quant = "8";
|
load_vec_quant = "8";
|
||||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
|
||||||
load_vec_quant = "4";
|
load_vec_quant = "4";
|
||||||
|
|
||||||
if (tname == "bf16") {
|
if (tname == "bf16") {
|
||||||
|
Reference in New Issue
Block a user