diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index ed61869a5..2be54c31b 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC { } else if constexpr(RM == 8 && RN == 4) { KERNEL_8x4(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC { const int nth; }; -template +template class tinyBLAS_Q0_PPC { public: tinyBLAS_Q0_PPC(int64_t k, const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC { private: - template - inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); @@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC { fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); } } - - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { - int64_t i, j; - TA *aoffset = NULL; - VA *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; + /* This function processes quantized data from block_q4_0 elements. + * First the we try to extract the two int4 values stored in single int8_t into two signed int8. + * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8. + * Also compute the rowsum which is required to compensate the above conversion. */ + inline void process_q4_elements(vector signed char (&c)[2], int* ca) { const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); - aoffset = const_cast(a); - vecOffset = vec; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + } + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + int64_t i, j; + TA *aoffset = NULL; + int8_t *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + aoffset = const_cast(a); + vecOffset = vec; j = (rows >> 3); if (j > 0) { do { @@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC { aoffset7 = aoffset6 + lda; aoffset8 = aoffset7 + lda; aoffset += 8 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c5[0] = vec_and(c5[1], lowMask); - c5[1] = vec_sr(c5[1], v4); - c5[0] = vec_sub(c5[0], v8); - c5[1] = vec_sub(c5[1], v8); - vsum = vec_sum4s(c5[0], vsum); - vsum2 = vec_sum4s(c5[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c6[0] = vec_and(c6[1], lowMask); - c6[1] = vec_sr(c6[1], v4); - c6[0] = vec_sub(c6[0], v8); - c6[1] = vec_sub(c6[1], v8); - vsum = vec_sum4s(c6[0], vsum); - vsum2 = vec_sum4s(c6[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c7[0] = vec_and(c7[1], lowMask); - c7[1] = vec_sr(c7[1], v4); - c7[0] = vec_sub(c7[0], v8); - c7[1] = vec_sub(c7[1], v8); - vsum = vec_sum4s(c7[0], vsum); - vsum2 = vec_sum4s(c7[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c8[0] = vec_and(c8[1], lowMask); - c8[1] = vec_sr(c8[1], v4); - c8[0] = vec_sub(c8[0], v8); - c8[1] = vec_sub(c8[1], v8); - vsum = vec_sum4s(c8[0], vsum); - vsum2 = vec_sum4s(c8[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + process_q4_elements(c5, &comparray[4]); + process_q4_elements(c6, &comparray[5]); + process_q4_elements(c7, &comparray[6]); + process_q4_elements(c8, &comparray[7]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC { aoffset3 = aoffset2 + lda; aoffset4 = aoffset3 + lda; aoffset += 4 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats( 0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC { } } } - template - void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { int64_t i, j; - TB *aoffset = NULL; + block_q8_0 *aoffset = NULL; VA *vecOffset = NULL; - TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - - aoffset = const_cast(a); + block_q8_0* aoffsets[8]; + __vector_pair arr[8]; + VB c[8][2] = {0}; + VB c1[8] = {0}; VB c2[8] = {0}; + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; - aoffset5 += lda; - aoffset6 += lda; - aoffset7 += lda; - aoffset8 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + for (int it = 0; it < 8; it++) + aoffsets[it] += lda; vecOffset += 256; i--; } while(i > 0); @@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; - + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 4; it++) { + aoffsets[it] += lda; } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; vecOffset += 128; i--; } while(i > 0); } } + if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - __builtin_vsx_disassemble_pair(c3, &C3); - case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - __builtin_vsx_disassemble_pair(c2, &C2); - case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - __builtin_vsx_disassemble_pair(c1, &C1); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], &arr[2]); + c1[2] = c[2][0]; c2[2] = c[2][1]; + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], &arr[1]); + c1[1] = c[1][0]; c2[1] = c[1][1]; + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], &arr[0]); + c1[0] = c[0][0]; c2[0] = c[0][1]; break; } - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 3; it++) + aoffsets[it] += lda; vecOffset += 128; i--; } while(i > 0); @@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 8); - int n_rem = MIN(n - n0, 8); - // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance - // issues. After resolving them, below code will be enabled. - /*if (m_rem >= 16 && n_rem >= 8) { - mc = 16; - nc = 8; - gemm<16,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 16; - gemm<8,16>(m0, m, n0, n); - }*/ + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; - gemm<4,8>(m0, m, n0, n); + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; - gemm<8,4>(m0, m, n0, n); + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; - gemm_small<4, 4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 2: - mc = 2; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 3: - mc = 3; - gemm_small<3, 4>(m0, m, n0, n); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 2: - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 3: - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - default: - return; - } + gemm_small(m0, m, n0, n, mc, nc); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small<3, 4>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small<3, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small<3, 1>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small<2, 2>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small<2, 1>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small<1, 3>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small<1, 1>(m0, m, n0, n); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; @@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC { compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii, jj+4, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii, jj+4, 4, fin_res); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); - save_res<4, 4>(ii, jj+4, 8, fin_res); - save_res<4, 4>(ii+4, jj+4, 12, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); + save_res(ii, jj+4, 8, fin_res); + save_res(ii+4, jj+4, 12, fin_res); } - template - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); if (isAblock_q4) { - packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { @@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC { fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); } } - save_res(ii, jj, 0, fin_res); + save_res(ii, jj, 0, fin_res, RM, RN); } } @@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC { } else if constexpr(RM == 8 && RN == 8) { KERNEL_8x8(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC { } const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const block_q8_0 *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC { const int nth; }; -template class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const float *A, int64_t lda, + const float *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -2707,247 +2184,139 @@ class tinyBLAS_PPC { void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - template - void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { + inline void vector_permute_store_4(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } + + inline void vector_permute_store_8(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); + } + + void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { int64_t i, j; - TA *aoffset = NULL, *boffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VA t1, t2, t3, t4, t5, t6, t7, t8; - aoffset = const_cast(a); + float * aoffsets[8]; + float *aoffset = NULL, *boffset = NULL; + __vector_pair arr[8]; + vector float c[8][2] = {0}; + vector float c1[8] = {0}; + vector float c2[8] = {0}; + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it< 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); + for (int it = 0; it< 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergeh(c5[1], c6[1]); - t4 = vec_mergeh(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+32); - vec_xst(t6, 0, boffset+36); - vec_xst(t7, 0, boffset+40); - vec_xst(t8, 0, boffset+44); - - t1 = vec_mergel(c1[1], c2[1]); - t2 = vec_mergel(c3[1], c4[1]); - t3 = vec_mergel(c5[1], c6[1]); - t4 = vec_mergel(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+48); - vec_xst(t6, 0, boffset+52); - vec_xst(t7, 0, boffset+56); - vec_xst(t8, 0, boffset+60); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + vector_permute_store_8(c1, boffset); + vector_permute_store_8(c2, boffset+32); + for (int it = 0; it < 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; boffset += 64; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - c5[0] = vec_xl(0, aoffset5); - c6[0] = vec_xl(0, aoffset6); - c7[0] = vec_xl(0, aoffset7); - c8[0] = vec_xl(0, aoffset8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); + for (int it = 0; it < 8 ; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_8(c1, boffset); } j--; } while(j > 0); } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergel(c1[0], c2[0]); - t4 = vec_mergel(c3[0], c4[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergel(c1[1], c2[1]); - t4 = vec_mergel(c3[1], c4[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } + vector_permute_store_4(c1, boffset); + vector_permute_store_4(c2, boffset+16); + for (int it = 0; it < 4; it++) + aoffsets[it] += 8*lda; boffset += 32; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 4; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++) + aoffsets[it] = aoffsets[it-1] + lda; if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 3; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } } @@ -2957,8 +2326,8 @@ class tinyBLAS_PPC { acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -2973,8 +2342,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2994,8 +2363,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -3017,8 +2386,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -3033,155 +2402,37 @@ class tinyBLAS_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 16); - int n_rem = MIN(n - n0, 16); - if (m_rem >= 16 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4,8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8,4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4,4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - mc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - mc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; @@ -3206,22 +2457,22 @@ class tinyBLAS_PPC { * matrix elements. */ if (RM == 1) { - TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + float* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - TB* b = const_cast(B+(jj)*ldb+l); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + float* b = const_cast(B+(jj)*ldb+l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -3231,7 +2482,7 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); } } } @@ -3263,11 +2514,9 @@ class tinyBLAS_PPC { } } - const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const float *const A; + const float *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #elif defined(__MMA__) if (k % 8) return false; - tinyBLAS_PPC tb{ + tinyBLAS_PPC tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, @@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc,