From 65a3ebb0aa56d6c501466e0f950ad15105fd32d8 Mon Sep 17 00:00:00 2001
From: Anton Mitkov
Date: Mon, 14 Jul 2025 10:37:35 +0100
Subject: [PATCH] sycl: Batched mulmat rework for oneDNN dispatch (#14617)

---
 ggml/src/ggml-sycl/gemm.hpp      |  42 +++-----
 ggml/src/ggml-sycl/ggml-sycl.cpp | 163 ++++++++++++++++++++++---------
 2 files changed, 133 insertions(+), 72 deletions(-)

diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index 5efe03d36..dcf6c7aee 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -32,39 +32,28 @@ public:
         else static_assert(0);
     }
 
-    // matrix A has m rows, k columns
-    // matrix B has k rows, n columns
-    // nra - number of elements to skip when moving into next row in A
-    // nrb - number of elements to skip when moving into next row in B
-    // nca - number of elements to skip when moving into next column in A
-    // ncb - number of elements to skip when moving into next column in B
-    // stride_a - number of elements to skip when moving to next A matrix
-    // stride_b - number of elements to skip when moving to next B matrix
-    // batches_a - number of A matrices
-    // batches_b - number of B matrices
     static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
-        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
-        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
         void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
 
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
 
-        // { # strides, # rows, # columns }
-        dnnl::memory::dims a_dims = { batches_a, m, k };
-        dnnl::memory::dims b_dims = { batches_b, k, n };
-        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
-
-        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
-        dnnl::memory::dims a_strides = { stride_a, nra, nca };
-        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
-
+        dnnl::memory::dims a_dims    = {batches_a, m, k };
+        dnnl::memory::dims a_strides = {stra2, stra1, stra0};
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
+        dnnl::memory::dims b_dims    = {batches_b, k, n };
+        dnnl::memory::dims b_strides = {strb2, strb0, strb1};
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+
+        dnnl::memory::dims c_dims    = { std::max(batches_a, batches_b), m, n};
+        dnnl::memory::dims c_strides = {m*n, 1, m };
+        const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
 #ifdef GGML_SYCL_F16
         primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
 #endif
@@ -76,24 +65,23 @@ public:
         auto scratchpad_md = matmul_pd.scratchpad_desc();
         auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
         matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
 
-    // matrices A and B are column major, both having k rows
-    // matrix A has m column, matrix B has n columns
-    // output: column major matrix C = A transposed * B
     static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
         const void * a, dt at, const void * b, dt bt, void * c, dt ct,
         const queue_ptr & q) {
 
-        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
     }
 };
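The reworked gemm() above describes both operands to oneDNN as batched 3-D memory descriptors with explicit element strides, so arbitrarily strided (non-packed) inputs can be fed to a single dnnl::matmul call. A minimal self-contained sketch of that pattern, using the CPU engine and made-up 2x3x4x5 sizes rather than the SYCL engine/stream owned by ggml_backend_sycl_context (everything here is illustrative, not part of the patch):

// Build and link against oneDNN v3.x (-ldnnl). Row-major strides are used for
// all three tensors; the patch instead derives them from ggml's nb[] fields.
#include <dnnl.hpp>
#include <vector>

int main() {
    using data_t = dnnl::memory::data_type;
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
    dnnl::stream strm(eng);

    const dnnl::memory::dim batch = 2, m = 3, k = 4, n = 5;

    // dims are {batch, rows, cols}; strides are element steps for each dim,
    // playing the role of stra2/stra1/stra0 and strb2/strb0/strb1 above.
    dnnl::memory::desc a_md({batch, m, k}, data_t::f32, dnnl::memory::dims{m * k, k, 1});
    dnnl::memory::desc b_md({batch, k, n}, data_t::f32, dnnl::memory::dims{k * n, n, 1});
    dnnl::memory::desc c_md({batch, m, n}, data_t::f32, dnnl::memory::dims{m * n, n, 1});

    std::vector<float> a(batch * m * k, 1.0f), b(batch * k * n, 1.0f), c(batch * m * n);
    dnnl::memory a_mem(a_md, eng, a.data());
    dnnl::memory b_mem(b_md, eng, b.data());
    dnnl::memory c_mem(c_md, eng, c.data());

    dnnl::matmul::primitive_desc pd(eng, a_md, b_md, c_md);
    dnnl::matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                                    {DNNL_ARG_WEIGHTS, b_mem},
                                    {DNNL_ARG_DST, c_mem}});
    strm.wait();
    return 0;  // every element of c is k (= 4.0f): a batched product of all-ones matrices
}

The patch additionally forwards a user-managed scratchpad and, under GGML_SYCL_F16, sets fpmath_mode::f16; the sketch leaves both at their defaults.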
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 7f74fbfe5..cf46012be 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
 
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
@@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
     }
@@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
-                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                                       dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                                       dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
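Read together with the gemm.hpp change, the two row_gemm() call sites above now pass src0 as operand A (m = row_diff, k = ne10) and src1 as operand B (n = src1_ncols), instead of the previous swapped order. A plain reference loop for the stride contract the new gemm() signature encodes may help; this is an illustrative paraphrase only (gemm_reference is not part of the patch), written for float instead of sycl::half:

#include <cstdint>

// Reference semantics of DnnlGemmWrapper::gemm() after this patch:
//   C[b][i][j] = sum_k A[b][i][k] * B[b][k][j]
//   A element (b, i, k) lives at a[b*stra2 + i*stra1 + k*stra0]
//   B element (b, k, j) lives at b[b*strb2 + k*strb0 + j*strb1]
//   C is written column-major per batch: c[b*m*n + j*m + i]
// row_gemm() passes stra0=1, stra1=k, stra2=k*m and strb0=1, strb1=k, strb2=n*k,
// i.e. A is a row-major m x k matrix and B is a row-major n x k matrix read as its transpose.
static void gemm_reference(int m, int n, int k,
                           const float * a, int64_t stra0, int64_t stra1, int64_t stra2,
                           const float * b, int64_t strb0, int64_t strb1, int64_t strb2,
                           float * c, int64_t batches) {
    for (int64_t bat = 0; bat < batches; ++bat) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int kk = 0; kk < k; ++kk) {
                    acc += a[bat*stra2 + i*stra1 + kk*stra0] *
                           b[bat*strb2 + kk*strb0 + j*strb1];
                }
                c[bat*m*n + j*m + i] = acc;
            }
        }
    }
}

int main() {
    // Same parameter pattern row_gemm() uses: (1, k, k*m) for A and (1, k, n*k) for B.
    const int m = 2, n = 3, k = 4;
    float a[m * k], b[n * k], c[m * n];
    for (int i = 0; i < m * k; ++i) a[i] = 1.0f;
    for (int i = 0; i < n * k; ++i) b[i] = 1.0f;
    gemm_reference(m, n, k, a, 1, k, k * m, b, 1, k, n * k, c, 1);
    return c[0] == float(k) ? 0 : 1;  // every element is k
}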
@@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-    GGML_ASSERT(nb10 == type_size_src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2855,32 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, " : converting src1 to fp16");
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
-        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
-        const int64_t ne_src1 = ggml_nelements(src1);
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int64_t last_dim=0;
+        int64_t last_str=0;
+        int64_t largest_str=0;
+        for(int i = 0; i< 4; i++){
+            // last stride is always the largest
+            if(src1->nb[i] == largest_str){
+                if(src1->ne[last_dim] == 1){
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if(src1->nb[i] > largest_str){
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+
+        }
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
-        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
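The new buffer-sizing loop above walks src1->nb[]/ne[] to find the slowest-moving (largest-stride) dimension, so the fp16 staging buffer spans the whole tensor even when src1 is permuted. A simplified standalone sketch of the same idea, with plain arrays standing in for ggml_tensor::ne/nb and an illustrative helper name (it skips the patch's extra handling of equal strides on size-1 dims):

#include <cstdint>
#include <cstdio>

// Elements needed to cover a possibly permuted 4-D tensor: extent of the
// slowest-moving dim times its byte stride, divided by the element size,
// i.e. the quantity the patch computes as nb[last_str] * ne[last_dim] / type_size.
static int64_t spanning_elements(const int64_t ne[4], const size_t nb[4], size_t type_size) {
    int slowest = 0;
    for (int i = 1; i < 4; ++i) {
        if (nb[i] > nb[slowest]) {
            slowest = i;
        }
    }
    return int64_t(nb[slowest] * ne[slowest] / type_size);
}

int main() {
    // f32 tensor with ne = {8, 4, 2, 1}, stored contiguously: nb = {4, 32, 128, 256} bytes.
    const int64_t ne[4] = { 8, 4, 2, 1 };
    const size_t  nb[4] = { 4, 32, 128, 256 };
    std::printf("%lld\n", (long long) spanning_elements(ne, nb, sizeof(float))); // prints 64
    return 0;
}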
@@ -2892,38 +2914,89 @@
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
-                (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+            int64_t str_a0 = nb00 / type_size_src0;
+            int64_t str_a1 = nb01 / type_size_src0;
+            int64_t str_a2 = nb02 / type_size_src0;
 
-                DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
-                    src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
-                    src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
-                    dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
-            };
+            int64_t str_b0 = nb10 / type_size_src1;
+            int64_t str_b1 = nb11 / type_size_src1;
+            int64_t str_b2 = nb12 / type_size_src1;
 
-            if (r2 == 1 && r3 == 1) {
-                if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-                    dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
-                }
-                else {
-                    for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                        const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
-                        const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
-                        float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
-                        dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+            auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                                                         const sycl::half *src1, float *dst,
+                                                         int64_t a0, int64_t a1, int64_t batcha,
+                                                         int64_t b0, int64_t b1, int64_t batchb,
+                                                         int64_t sa0, int64_t sa1, int64_t sa2,
+                                                         int64_t sb0, int64_t sb1, int64_t sb2,
+                                                         int64_t sd2) {
+                bool supported_broadcast = batchb == batcha ? true
+                                           : batchb == 1 || batcha == 1 ? true
+                                                                        : false;
+                if (supported_broadcast) {
+                    DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                                          DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                                          DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                                          DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+                } else {
+                    // iterate over batches from smaller set of matrices (matrix 0)
+                    int64_t batches0 = batcha;
+                    int64_t batches1 = batchb;
+
+                    if (batches0 > batches1) {
+                        int64_t num_mul_mats = batches1;
+                        int64_t sub_batch = batches0 / num_mul_mats;
+                        // src0 is batched and bigger, shift and multiply with src1
+                        for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                            const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                            const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                            float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                                                  DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                                                  src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                                                  sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                                                  queue, sub_batch, 1);
+                        }
+                    } else {
+                        int64_t num_mul_mats = batches0;
+                        int64_t sub_batch = batches1 / num_mul_mats;
+                        // src1 is batched and bigger, shift and multiply with src0
+                        for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                            const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                            const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                            float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                            DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                                                  DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                                                  src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                                                  sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                                                  queue, 1, sub_batch);
+                        }
+                    }
+                }
+            };
+
+            bool cont_batches_a = nb02 * ne02 == nb03;
+            bool cont_batches_b = nb12 * ne12 == nb13;
+            if (cont_batches_a && cont_batches_b) {
+                int64_t batches0 = ne02 * ne03;
+                int64_t batches1 = ne12 * ne13;
+                launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                                        ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                                        str_b2, nb2 / sizeof(float));
+            } else {
+                for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                    const sycl::half *src0_f16_shifted
+                        = src0_f16 + (nb03 * b_a / type_size_src0);
+                    const sycl::half *src1_f16_shifted
+                        = src1_f16 + (nb13 * b_a / type_size_src1);
+                    float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                    int64_t batches0 = ne02;
+                    int64_t batches1 = ne12;
+                    launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                                            ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                                            str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
                 }
             }
-            } else {
-                // iterate over batches from smaller set of matrices (matrix 0)
-                for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
-                    for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                        const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
-                        const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
-                        float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
-                        dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
-                    }
-                }
-            }
+        } else
 #endif
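launch_gemm_for_batches above hands oneDNN a single strided batched matmul when the batch counts match or one of them is 1, and otherwise falls back to a loop of sub-batched calls that steps the larger operand by its batch stride. A compact sketch of just that decision logic, with DnnlGemmWrapper::gemm() replaced by a stub callback (dispatch_batched, GemmFn and the offsets are illustrative names, not the patch's API):

#include <cstdint>
#include <cstdio>
#include <functional>

// gemm(a_off, b_off, c_off, batches_a, batches_b): stand-in for one oneDNN matmul launch.
using GemmFn = std::function<void(int64_t, int64_t, int64_t, int64_t, int64_t)>;

// Mirrors the control flow of launch_gemm_for_batches: one broadcastable call when
// possible, otherwise split the larger batch into equal sub-batches per index of the
// smaller batch, advancing each pointer by its batch stride (sa2/sb2/sd2).
static void dispatch_batched(int64_t batcha, int64_t batchb,
                             int64_t sa2, int64_t sb2, int64_t sd2, const GemmFn & gemm) {
    const bool supported_broadcast = batcha == batchb || batcha == 1 || batchb == 1;
    if (supported_broadcast) {
        gemm(0, 0, 0, batcha, batchb);
    } else if (batcha > batchb) {
        const int64_t sub_batch = batcha / batchb;
        for (int64_t i = 0; i < batchb; ++i) {
            gemm(sa2 * i * sub_batch, sb2 * i, sd2 * i * sub_batch, sub_batch, 1);
        }
    } else {
        const int64_t sub_batch = batchb / batcha;
        for (int64_t i = 0; i < batcha; ++i) {
            gemm(sa2 * i, sb2 * i * sub_batch, sd2 * i * sub_batch, 1, sub_batch);
        }
    }
}

int main() {
    // e.g. 6 src0 matrices against 2 src1 matrices -> two calls of 3 batches each.
    dispatch_batched(6, 2, /*sa2=*/16, /*sb2=*/16, /*sd2=*/16,
                     [](int64_t a, int64_t b, int64_t c, int64_t ba, int64_t bb) {
                         std::printf("gemm a+%lld b+%lld c+%lld (%lld x %lld batches)\n",
                                     (long long) a, (long long) b, (long long) c,
                                     (long long) ba, (long long) bb);
                     });
    return 0;
}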
@@ -3263,10 +3336,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
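The final hunk loosens the routing guards in ggml_sycl_mul_mat: the single-batch non-contiguous path no longer requires a contiguous src1, since the kernel now takes an explicit channel_stride_y, and the multi-batch path keys purely off src1->ne[2]*ne[3]. Paraphrased as standalone predicates for readability (TensorView and the helper names are illustrative stand-ins for ggml_tensor and its accessors, not ggml API):

#include <cstdint>

// Minimal stand-in for the ggml_tensor fields the routing conditions read.
struct TensorView {
    int64_t ne[4];        // element counts per dimension
    bool    contiguous;   // ggml_is_contiguous()
    bool    transposed;   // ggml_is_transposed()
    bool    f16;          // type == GGML_TYPE_F16
};

// After this patch: the KQV single-batch kernel no longer demands a contiguous src1.
static bool use_mul_mat_vec_nc(const TensorView & src0, const TensorView & src1, bool split) {
    return !split && src0.f16 && !src0.contiguous && !src1.transposed && src1.ne[1] == 1;
}

// Multi-batch path triggers on the batch dimensions of src1 only.
static bool use_mul_mat_batched(const TensorView & src0, const TensorView & src1, bool split) {
    return !split && src0.f16 && !src0.transposed && !src1.transposed &&
           src1.ne[2] * src1.ne[3] > 1;
}

int main() {
    TensorView src0 = { {128, 64, 8, 1}, /*contiguous=*/false, /*transposed=*/false, /*f16=*/true  };
    TensorView src1 = { {128, 1, 8, 1},  /*contiguous=*/false, /*transposed=*/false, /*f16=*/false };
    return use_mul_mat_vec_nc(src0, src1, /*split=*/false) ? 0 : 1;
}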