diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index f78a36ddf..f839a42bc 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -30,6 +30,7 @@ #include "outprod.hpp" #include "quants.hpp" #include "rope.hpp" +#include "set_rows.hpp" #include "softmax.hpp" #include "tsembd.hpp" #include "wkv.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index cd15bbdb2..65b26fd02 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -41,6 +41,7 @@ #include "ggml-sycl/element_wise.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml.h" @@ -3605,6 +3606,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_sycl_get_rows(ctx, dst); break; + case GGML_OP_SET_ROWS: + ggml_sycl_op_set_rows(ctx, dst); + break; case GGML_OP_DUP: ggml_sycl_dup(ctx, dst); break; @@ -4299,7 +4303,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 - return false; + return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64)); } break; case GGML_OP_CPY: { diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp new file mode 100644 index 000000000..4a76a63d3 --- /dev/null +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -0,0 +1,131 @@ +#include "set_rows.hpp" + +namespace utils { +template +static constexpr bool is_arithmetic_v() { + return std::is_arithmetic_v || std::is_same_v || std::is_same_v; +} +} +template +static inline std::enable_if_t() && utils::is_arithmetic_v(), void> +convert (const char* src, char* dst) { + auto src_val = *reinterpret_cast(src); + auto dst_val = sycl::vec(src_val).template convert()[0]; + *reinterpret_cast(dst) = dst_val;; +} + +template +static void k_set_rows( + const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + const size_t src_type_size, const size_t dst_type_size, + const sycl::nd_item<3> & item_ct1) { + + const int i03 = item_ct1.get_group(0); + const int i02 = item_ct1.get_group(1); + const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index + + if (i01 >= ne01) { + return; + } + + const int i12 = i03 % ne12; + const int i11 = i02 % ne11; + const int i10 = i01; + + const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12})); + + const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03}); + char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + + for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) { + const char * src_elem = src0_row + col * src_type_size; + char * dst_elem = dst_row_ptr + col * dst_type_size; + convert(src_elem, dst_elem); + } +} + +template +static void set_rows_sycl( + const char * src0_d, const int64_t * src1_d, char * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + const size_t src_type_size, const size_t dst_type_size, + queue_ptr stream) { + + constexpr int max_threads_per_row = 64; // KEEPING 64 for now + const int threads_per_row = std::min((int)ne00, max_threads_per_row); + + constexpr int max_threads_per_block = 64; + const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row); + + const sycl::range<3> block_size(1, rows_per_block, threads_per_row); + const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block); + + sycl_parallel_for( + stream, + sycl::nd_range<3>(grid_size * block_size, block_size), + [=](sycl::nd_item<3> item_ct1) { + k_set_rows( + src0_d, src1_d, dst_d, + ne00, ne01, ne11, ne12, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + src_type_size, dst_type_size, + item_ct1 + ); + } + ); +} + + +void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t * src1_dd = static_cast(src1->data); + + dpct::queue_ptr stream = ctx.stream(); + switch (dst->type) { + case GGML_TYPE_F32: + set_rows_sycl( + (const char *)src0->data, src1_dd, (char *)dst->data, + ne00, ne01, ne02, ne03, + ne11, ne12, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + sizeof(float), sizeof(float), + stream + ); + break; + case GGML_TYPE_F16: + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + set_rows_sycl( + (const char *)src0->data, src1_dd, (char *)dst->data, + ne00, ne01, ne02, ne03, + ne11, ne12, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + sizeof(float), sizeof(sycl::half), + stream + ); + break; + default: + GGML_ABORT("Unsupported tensor type!"); + break; + } +} diff --git a/ggml/src/ggml-sycl/set_rows.hpp b/ggml/src/ggml-sycl/set_rows.hpp new file mode 100644 index 000000000..27fcc8f90 --- /dev/null +++ b/ggml/src/ggml-sycl/set_rows.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_SET_ROWS_HPP +#define GGML_SYCL_SET_ROWS_HPP + +#include "common.hpp" + +void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_SET_ROWS_HPP