diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 8460698b1..94a93da55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -98,11 +98,11 @@ layout (constant_id = 12) const uint LOAD_VEC_B_SHIFT = 0; #ifdef COOPMAT #define SHMEM_STRIDE (BK + 8) #else -#define SHMEM_STRIDE (BK + 1) +#define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[3072]; @@ -223,8 +223,8 @@ void main() { } #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; - FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[TN]; + FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC2 cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -262,7 +262,7 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < BK; i++) { + [[unroll]] for (uint i = 0; i < BK / 2; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { @@ -278,7 +278,7 @@ void main() { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); + sums[sums_idx] += dot(ACC_TYPE_VEC2(cache_a[wsir * TM + cr]), ACC_TYPE_VEC2(cache_b[cc])); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp index bab76bd82..52c0b4b76 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp @@ -4,32 +4,31 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if defined(FLOAT16) && defined(A_TYPE_VEC8) if (LOAD_VEC_A_SHIFT == 3) { const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row; - const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT)); - buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec8[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec8[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec8[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec8[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a_vec8[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a_vec8[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a_vec8[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a_vec8[idx][1].w); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; + const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_a_vec8[idx]); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals[0].xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw); + buf_a[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy); + buf_a[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw); } else #endif if (LOAD_VEC_A_SHIFT == 2) { const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row; - const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT)); - buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec4[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec4[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec4[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec4[idx].w); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; + const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_a_vec4[idx]); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw); + } else if (idx_m < p.M && idx_k + 1 < end_k) { + buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row ], + data_a[pos_a + col * p.stride_a + row + 1]); } else if (idx_m < p.M && idx_k < end_k) { - buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]); + buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row], 0.0f); } else { - buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f); + buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f); } #elif defined(DATA_A_Q4_0) - const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row; - const uint buf_idx = col * SHMEM_STRIDE + 4 * row; + const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; + const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -39,17 +38,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q4_1) - const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row; - const uint buf_idx = col * SHMEM_STRIDE + 4 * row; + const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; + const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -60,17 +55,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q5_0) - const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; + const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -83,13 +74,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -103,13 +92,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -119,13 +106,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -140,11 +125,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -162,11 +146,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); - buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); - buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -195,11 +179,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -231,11 +215,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -250,11 +234,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); - buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib32 = (idx % 128) / 16; // 0..7 @@ -274,11 +258,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin ); const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib8 = (idx % 128) / 4; @@ -300,11 +283,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin ); const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib32 = (idx % 128) / 16; // 0..7 @@ -325,11 +307,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib32 = (idx % 128) / 16; // 0..7 @@ -345,11 +326,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib8 = (idx % 128) / 4; // 0..31 @@ -367,11 +347,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = (idx % 128) / 2; // 0..63 @@ -392,11 +371,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint iqs = (idx % 128) / 2; // 0..63 @@ -412,11 +390,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2; const uint ib = idx / 128; // 2 values per idx const uint ib32 = (idx % 128) / 16; // 0..7 @@ -431,11 +408,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = float(data_a[ib].d); const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -443,10 +419,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; - buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF] * d, + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)] * d); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)] * d, + kvalues_iq4nl[vui >> 12] * d); #endif } @@ -455,28 +431,27 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin #if defined(B_TYPE_VEC8) if (LOAD_VEC_B_SHIFT == 3) { const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT); - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2; + const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]); + buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy); + buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw); + buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy); + buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw); } else #endif if (LOAD_VEC_B_SHIFT == 2) { const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT); - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w); - } else if (idx_n < p.N && idx_k < end_k) { - buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + col * p.stride_b + row]); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2; + const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]); + buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy); + buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw); + } else if (idx_n < p.N && idx_k + 1 < end_k) { + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row ], + data_b[pos_b + col * p.stride_b + row + 1]); + } else if (idx_n < p.N && idx_k + 1 < end_k) { + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row], 0.0f); } else { - buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f); + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f); } } #else @@ -485,32 +460,34 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin if (LOAD_VEC_B_SHIFT == 3) { const u16vec2 row_idx = row_ids[ic * BN + col]; const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT); - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2; + const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]); + buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy); + buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw); + buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy); + buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw); } else #endif if (LOAD_VEC_B_SHIFT == 2) { const u16vec2 row_idx = row_ids[ic * BN + col]; const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row; - const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT); - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w); + const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2; + const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]); + buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy); + buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw); } else { - const uint row_i = ic * BN + col; - if (row_i < _ne1) { - const u16vec2 row_idx = row_ids[row_i]; - buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row]); + const uint row_i_1 = ic * BN + col; + const uint row_i_2 = ic * BN + col + 1; + if (row_i_1 < _ne1 && row_i_2 < _ne1) { + const u16vec2 row_idx_1 = row_ids[row_i_1]; + const u16vec2 row_idx_2 = row_ids[row_i_2]; + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx_1.y * p.batch_stride_b + (row_idx_1.x % p.ne11) * p.stride_b + row], + data_b[pos_b + row_idx_2.y * p.batch_stride_b + (row_idx_2.x % p.ne11) * p.stride_b + row]); + } else if (row_i_1 < _ne1) { + const u16vec2 row_idx = row_ids[row_i_1]; + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row], 0.0f); } else { - buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f); + buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 271fcbc8f..2c424ba9e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -293,10 +293,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; - std::map base_dict = { - {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}, - {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, - }; + std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float" }, + {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2" }, + {"FLOAT_TYPE_VEC4", (coopmat2 || fp16) ? "f16vec4" : "vec4" }, + {"FLOAT_TYPE_VEC8", (coopmat2 || fp16) ? "f16mat2x4" : "mat2x4"}}; std::string shader_name = "matmul"; if (matmul_id) { @@ -308,7 +308,8 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; if (coopmat) { base_dict["COOPMAT"] = "1"; @@ -386,6 +387,7 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctype_vec2 = f16acc ? "f16vec2" : "vec2"; for (const auto& tname : type_names) { if (tname == "f32") {