diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 3992dad01..a0a650e92 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2705,9 +2705,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                                 " : converting src1 to fp16");
 
         // iterate tensor dims and find the slowest moving dim and stride
-        int64_t last_dim=0;
-        int64_t last_str=0;
-        int64_t largest_str=0;
+        int last_dim=0;
+        int last_str=0;
+        size_t largest_str=0;
         for(int i = 0; i< 4; i++){
             // last stride is always the largest
             if(src1->nb[i] == largest_str){
@@ -2783,7 +2783,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
                                                  const sycl::half *src1, float *dst,
                                                  int64_t a0, int64_t a1, int64_t batcha,
-                                                 int64_t b0, int64_t b1, int64_t batchb,
+                                                 int64_t /*b0*/, int64_t b1, int64_t batchb,
                                                  int64_t sa0, int64_t sa1, int64_t sa2,
                                                  int64_t sb0, int64_t sb1, int64_t sb2,
                                                  int64_t sd2) {
@@ -2832,14 +2832,26 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         }
     };
 
-    bool cont_batches_a = nb02 * ne02 == nb03;
-    bool cont_batches_b = nb12 * ne12 == nb13;
-    if (cont_batches_a && cont_batches_b) {
+    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+    if (cont_batches_dim2_a && cont_batches_dim2_b) {
+        // A batch is considered contiguous if the dimension 2 is not strided
         int64_t batches0 = ne02 * ne03;
         int64_t batches1 = ne12 * ne13;
         launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
                                 ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
                                 str_b2, nb2 / sizeof(float));
+    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+        // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+        int64_t batches0 = ne02 * ne03;
+        int64_t batches1 = ne12 * ne13;
+        int64_t str_a3 = nb03 / type_size_src0;
+        int64_t str_b3 = nb13 / type_size_src1;
+        launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                                ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+                                str_b3, nb2 / sizeof(float));
     } else {
         for (int64_t b_a = 0; b_a < ne03; b_a++) {
             const sycl::half *src0_f16_shifted
@@ -4215,6 +4227,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                     return false;
                 }
+                // TODO: The configuration below needs more work to be supported with oneDNN
+                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                    return false;
+                }
+                // TODO: This specific configuration can fail with oneDNN and needs more debugging
+                if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+                    a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
+                    return false;
+                }
                 return true;
             }
         case GGML_OP_OUT_PROD:
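
Note (not part of the patch): the dim-2/dim-3 contiguity predicates introduced above decide whether the batched GEMM can be issued as one fused launch or must fall back to the per-batch loop. The standalone C++ sketch below is a hypothetical illustration of those two checks under ggml's convention that ne[i] is the element count of dimension i and nb[i] its byte stride; toy_tensor, batches_contiguous_dim2 and batches_contiguous_dim3 are illustrative names, not symbols from the codebase.

#include <cstdint>
#include <cstdio>

// Minimal stand-in for the ne/nb bookkeeping of a ggml tensor.
struct toy_tensor {
    int64_t ne[4]; // elements per dimension
    int64_t nb[4]; // stride of each dimension, in bytes
};

// Mirrors cont_batches_dim2_*: stepping across all of dimension 2 lands exactly
// at the start of the next dimension-3 slice, so dims 2 and 3 can be fused into
// a single batch count for one strided GEMM launch.
static bool batches_contiguous_dim2(const toy_tensor & t) {
    return t.nb[2] * t.ne[2] == t.nb[3];
}

// Mirrors cont_batches_dim3_*: dimension 2 is a singleton and the dim-3 slices
// are spaced by ne[1] times the dim-2 stride, so dimension 3 alone carries the
// batches and nb[3] can be passed to the GEMM as the batch stride.
static bool batches_contiguous_dim3(const toy_tensor & t) {
    return t.ne[2] == 1 && t.nb[2] * t.ne[1] == t.nb[3];
}

int main() {
    // Fully contiguous F16 tensor, ne = {64, 32, 1, 8}: the dim-2 predicate
    // holds, so the first branch of the patch handles it.
    toy_tensor packed   = { {64, 32, 1, 8}, {2, 128, 4096, 4096} };
    // Same shape seen as a permuted view (dims 1 and 2 swapped on a contiguous
    // {64, 1, 32, 8} tensor): the dim-2 check fails, but the new dim-3 branch
    // still allows a single fused launch instead of the per-batch fallback.
    toy_tensor permuted = { {64, 32, 1, 8}, {2, 128, 128, 4096} };

    std::printf("packed:   dim2=%d dim3=%d\n",
                (int) batches_contiguous_dim2(packed),   (int) batches_contiguous_dim3(packed));
    std::printf("permuted: dim2=%d dim3=%d\n",
                (int) batches_contiguous_dim2(permuted), (int) batches_contiguous_dim3(permuted));
    return 0;
}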