diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 8b952db43..d0d5ac9a4 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -48,11 +48,11 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { - return { block_index * (traits::qk / traits::qr), 0 }; + return { block_index * (QK4_0 / QR4_0), 0 }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; + return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } @@ -71,14 +71,12 @@ template <> struct block_q_t { } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); - return { nblocks * (QK_K / 2), + auto nblocks = (nrows * (ncols / QK_K)); + return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE), (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } - - constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } }; template <> struct block_q_t { @@ -90,22 +88,23 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { - auto low_bits_index = block_index * (traits::qk / traits::qr); + auto low_bits_index = block_index * (QK_K / QR6_K); // the index of high bits it's after all low bits auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); return { low_bits_index, high_bits_index }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); + auto nblocks = (nrows * (ncols / QK_K)); auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); auto block_scales = total_qs_bytes + block_index * (QK_K / 16); - auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half); return { block_scales, sb_scale }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 0a5d49994..4088ddb54 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -350,11 +350,9 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int & iqs) { - const int ib = ibx_offset.first / (QK_K / 2); - const uint8_t * base = static_cast(vbq); const uint8_t * qs = base + ibx_offset.first; - const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE; + const uint8_t * scs = base + d_offset.first; const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); @@ -427,13 +425,11 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int iqs) { - const int ib = ibx_offset.first / (QK_K / 2); - const uint8_t * base = static_cast(vbq); const uint8_t * ql = base + ibx_offset.first; const uint8_t * qh = base + ibx_offset.second; const int8_t * scales = reinterpret_cast(base + d_offset.first); - const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib; + const ggml_half * d = (const ggml_half *) (base + d_offset.second); const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4); const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);