Use dot instruction to multiply two values at once, to enable dual issue instructions

This commit is contained in:
0cc4m
2025-04-06 10:32:35 +00:00
parent 039cc0745a
commit ab27292b26
3 changed files with 123 additions and 144 deletions

View File

@@ -4,32 +4,31 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if defined(FLOAT16) && defined(A_TYPE_VEC8)
if (LOAD_VEC_A_SHIFT == 3) {
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT));
buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec8[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec8[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec8[idx][0].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec8[idx][0].w);
buf_a[buf_idx + 4] = FLOAT_TYPE(data_a_vec8[idx][1].x);
buf_a[buf_idx + 5] = FLOAT_TYPE(data_a_vec8[idx][1].y);
buf_a[buf_idx + 6] = FLOAT_TYPE(data_a_vec8[idx][1].z);
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a_vec8[idx][1].w);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_a_vec8[idx]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals[0].xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
buf_a[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
buf_a[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
} else
#endif
if (LOAD_VEC_A_SHIFT == 2) {
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
const uint buf_idx = col * SHMEM_STRIDE + ((row << LOAD_VEC_A_SHIFT));
buf_a[buf_idx ] = FLOAT_TYPE(data_a_vec4[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a_vec4[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a_vec4[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a_vec4[idx].w);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_a_vec4[idx]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(vals.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
} else if (idx_m < p.M && idx_k + 1 < end_k) {
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row ],
data_a[pos_a + col * p.stride_a + row + 1]);
} else if (idx_m < p.M && idx_k < end_k) {
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_a[pos_a + col * p.stride_a + row], 0.0f);
} else {
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
}
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@@ -39,17 +38,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (4 * row) / 2;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@@ -60,17 +55,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * ((p.stride_a >> LOAD_VEC_A_SHIFT)) + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -83,13 +74,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -103,13 +92,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -119,13 +106,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -140,11 +125,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -162,11 +146,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -195,11 +179,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -231,11 +215,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
@@ -250,11 +234,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ1_S)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
@@ -274,11 +258,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4;
@@ -300,11 +283,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
@@ -325,11 +307,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
@@ -345,11 +326,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4; // 0..31
@@ -367,11 +347,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
@@ -392,11 +371,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
@@ -412,11 +390,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_A_SHIFT) / 2;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
@@ -431,11 +408,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * (p.stride_a >> LOAD_VEC_A_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + (2 * row) / 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@@ -443,10 +419,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF] * d,
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)] * d);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)] * d,
kvalues_iq4nl[vui >> 12] * d);
#endif
}
@@ -455,28 +431,27 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
#if defined(B_TYPE_VEC8)
if (LOAD_VEC_B_SHIFT == 3) {
const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w);
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x);
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y);
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z);
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]);
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy);
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
} else
#endif
if (LOAD_VEC_B_SHIFT == 2) {
const uint idx = pos_b + col * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w);
} else if (idx_n < p.N && idx_k < end_k) {
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + col * p.stride_b + row]);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]);
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy);
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
} else if (idx_n < p.N && idx_k + 1 < end_k) {
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row ],
data_b[pos_b + col * p.stride_b + row + 1]);
} else if (idx_n < p.N && idx_k + 1 < end_k) {
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + col * p.stride_b + row], 0.0f);
} else {
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
}
}
#else
@@ -485,32 +460,34 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
if (LOAD_VEC_B_SHIFT == 3) {
const u16vec2 row_idx = row_ids[ic * BN + col];
const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec8[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec8[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec8[idx][0].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec8[idx][0].w);
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b_vec8[idx][1].x);
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b_vec8[idx][1].y);
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b_vec8[idx][1].z);
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b_vec8[idx][1].w);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
const FLOAT_TYPE_VEC8 vals = FLOAT_TYPE_VEC8(data_b_vec8[idx]);
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals[0].xy);
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals[0].zw);
buf_b[buf_idx + 2] = FLOAT_TYPE_VEC2(vals[1].xy);
buf_b[buf_idx + 3] = FLOAT_TYPE_VEC2(vals[1].zw);
} else
#endif
if (LOAD_VEC_B_SHIFT == 2) {
const u16vec2 row_idx = row_ids[ic * BN + col];
const uint idx = pos_b + row_idx.y * (p.batch_stride_b >> LOAD_VEC_B_SHIFT) + (row_idx.x % p.ne11) * (p.stride_b >> LOAD_VEC_B_SHIFT) + row;
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT);
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b_vec4[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b_vec4[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b_vec4[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b_vec4[idx].w);
const uint buf_idx = col * SHMEM_STRIDE + (row << LOAD_VEC_B_SHIFT) / 2;
const FLOAT_TYPE_VEC4 vals = FLOAT_TYPE_VEC4(data_b_vec4[idx]);
buf_b[buf_idx + 0] = FLOAT_TYPE_VEC2(vals.xy);
buf_b[buf_idx + 1] = FLOAT_TYPE_VEC2(vals.zw);
} else {
const uint row_i = ic * BN + col;
if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i];
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row]);
const uint row_i_1 = ic * BN + col;
const uint row_i_2 = ic * BN + col + 1;
if (row_i_1 < _ne1 && row_i_2 < _ne1) {
const u16vec2 row_idx_1 = row_ids[row_i_1];
const u16vec2 row_idx_2 = row_ids[row_i_2];
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx_1.y * p.batch_stride_b + (row_idx_1.x % p.ne11) * p.stride_b + row],
data_b[pos_b + row_idx_2.y * p.batch_stride_b + (row_idx_2.x % p.ne11) * p.stride_b + row]);
} else if (row_i_1 < _ne1) {
const u16vec2 row_idx = row_ids[row_i_1];
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row], 0.0f);
} else {
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE_VEC2(0.0f);
}
}
}