sycl: refactor quantization to q8_1 (#14815)

* sycl: quantization to q8_1 refactor

* Refactored src1 copy logic in op_mul_mat
Authored by Alberto Cabrera Pérez on 2025-07-28 11:05:53 +01:00, committed by GitHub
parent a5771c9eea · commit afc0e89698
3 changed files with 184 additions and 206 deletions
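The core of the change: the runtime flag convert_src1_to_q8_1 (plus a separate reorder_q8_tensor bool threaded into the quantization launcher) is replaced by a template-template parameter on ggml_sycl_op_mul_mat, so each instantiation bakes its quantization strategy in at compile time and the unused path is dead-code eliminated. A minimal standalone sketch of the dispatch pattern, with simplified functors standing in for the real ones in quantize.hpp (the no-arg operator() and op_mul_mat_sketch are illustrative, not the ggml signatures):

    #include <cstdio>
    #include <type_traits>

    // Hypothetical stand-ins for the real functors in quantize.hpp.
    template <int ElementsPerWI> struct no_quantize_q8_1 {
        void operator()() const {}                               // no-op: selects the float path
    };
    template <int ElementsPerWI> struct quantize_q8_1 {
        void operator()() const { std::puts("quantizing to q8_1"); }
    };

    constexpr int QK8_1 = 32, WARP_SIZE = 32;

    // Mirrors ggml_sycl_op_mul_mat: the functor template itself is the switch.
    template <template <int> typename quantize_f>
    void op_mul_mat_sketch() {
        constexpr bool quantize_enabled =
            !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>, no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
        if constexpr (quantize_enabled) {
            quantize_f<QK8_1 / WARP_SIZE>{}();                   // allocate q8_1 buffer + quantize
        }
        // ... rest of the mul_mat pipeline ...
    }

    int main() {
        op_mul_mat_sketch<quantize_q8_1>();      // quantized path
        op_mul_mat_sketch<no_quantize_q8_1>();   // float path; quantize code compiled out
    }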

ggml/src/ggml-sycl/backend.hpp (View File)

@@ -28,6 +28,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp (View File)

@@ -44,6 +44,7 @@
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"

 static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
-template<int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1;  // block index
-    const int iqs = i_padded % QK8_1; // quant index
-
-    typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
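For reference, q8_1 stores each group of QK8_1 = 32 floats as 32 int8 quants plus two fp16 values: the scale d = amax / 127 and the sum of the unquantized inputs (kernels that pair q8_1 with formats carrying a per-block minimum, such as q4_1, use the sum to fold the offset correction into the dot product). A host-side scalar sketch of the same math, using a simplified float-based block struct rather than ggml's fp16 one:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    constexpr int QK8_1 = 32;

    // Simplified stand-in for ggml's block_q8_1 (the real struct stores d/s as fp16).
    struct block_q8_1_ref {
        float  d;          // scale: amax / 127
        float  s;          // sum of the 32 source floats
        int8_t qs[QK8_1];  // quantized values
    };

    // Quantize one block of 32 floats, mirroring the kernel's amax/round logic.
    void quantize_block_q8_1_ref(const float * x, block_q8_1_ref & y) {
        float amax = 0.0f, sum = 0.0f;
        for (int i = 0; i < QK8_1; ++i) {
            sum += x[i];
            amax = std::max(amax, std::fabs(x[i]));
        }
        const float d = amax / 127.0f;
        for (int i = 0; i < QK8_1; ++i) {
            y.qs[i] = amax == 0.0f ? 0 : (int8_t) std::round(x[i] / d);  // guard 0/0
        }
        y.d = d;   // d == 0 for an all-zero block, as in the kernel above
        y.s = sum;
    }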
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-    auto subgroup_id = it.get_group(0);
-    auto wi_id = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto row = subgroup_id / num_blocks_per_row;
-    auto col = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI> wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-
-    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
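The reordered layout above is an SoA variant of the same data: within each row's region, the kx int8 quants come first, back to back, followed by the per-block (d, sum) half2 pairs, while the per-row stride stays (kx_padded / QK8_1) * sizeof(block_q8_1) bytes, the same total as the plain block layout. Illustrative host-side offset helpers matching the kernel's arithmetic (hypothetical names, not ggml API; assumes sizeof(block_q8_1) == 36 and sizeof(sycl::half2) == 4):

    #include <cstddef>

    // Per row: [ kx int8 quants ][ (kx / QK8_1) half2 (d, sum) pairs ]
    constexpr int    QK8_1              = 32;
    constexpr size_t sizeof_block_q8_1  = 2 * 2 + QK8_1;  // half2 ds + 32 quants = 36 bytes

    size_t row_offset(int row, int kx_padded) {
        return (size_t) row * (kx_padded / QK8_1) * sizeof_block_q8_1;  // row stride unchanged
    }
    size_t quant_offset(int row, int col_block, int kx_padded) {
        return row_offset(row, kx_padded) + (size_t) QK8_1 * col_block; // quants first
    }
    size_t ds_offset(int row, int col_block, int kx, int kx_padded) {
        return row_offset(row, kx_padded) + kx + (size_t) col_block * 4; // ds pairs after kx quants
    }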
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }

-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
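As a concrete example of the deleted launcher's reorder path: with kx = 4096, ky = 2, QK8_1 = 32 and WARP_SIZE = 32, there are ky * (kx / QK8_1) = 2 * 128 = 256 quant blocks, so the kernel launches 256 sub-group-sized work-groups of 32 work-items (8192 total), each work-item quantizing QK8_1 / WARP_SIZE = 1 float. The non-reorder path instead launched a 3D nd_range with SYCL_QUANTIZE_BLOCK_SIZE-wide work-groups tiled over the padded row length.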
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     peer_access_enabled = enable_peer_access;
 }

+template <template <int> typename quantize_f>
 static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op,
-                                 const bool convert_src1_to_q8_1) try {
+                                 ggml_sycl_op_mul_mat_t op) try {

     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
@@ -2470,6 +2331,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         }
     }

+    constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
+                                                      no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
+
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
             continue;
@@ -2495,20 +2358,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
         }

-        if (convert_src1_to_q8_1) {
+        if constexpr (quantize_enabled) {
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);

             if (src1_on_device && src1_is_contiguous) {
-                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                      /*num_src=*/2, " : converting src1 to Q8_1");
-                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
-                /*
-                DPCT1010:90: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
+                try {
+                    quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                } catch (sycl::exception const & exc) {
+                    std::cerr << "quantize_row_q8_1_sycl error: " << exc.what() << ", exception caught at file: " << __FILE__
+                              << ", line: " << __LINE__ << std::endl;
+                    std::exit(1);
+                }
             }
         }
@@ -2524,11 +2386,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     // here an event is recorded that signals that the main device has finished calculating the input data
     if (split && used_devices > 1) {
         ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
         SYCL_CHECK(CHECK_TRY_ERROR(
             *src0_extra->events[ctx.device][0] =
                 ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2552,11 +2409,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
         // wait for main GPU data if necessary
         if (split && (i != ctx.device || is != 0)) {
-            /*
-            DPCT1009:163: SYCL uses exceptions to report errors and does not
-            use the error codes. The original code was commented out and a
-            warning string was inserted. You need to rewrite this code.
-            */
             SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
                 {*src0_extra->events[ctx.device][0]})));
         }
@@ -2582,39 +2434,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten

             // copy src0, src1 to device if necessary
             if (src1_is_contiguous) {
                 if (i != ctx.device) {
-                    if (convert_src1_to_q8_1) {
+                    if constexpr (quantize_enabled) {
                         char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                        SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                            src1_ddq_i, src1_ddq_i_source,
-                            src1_ncols * src1_padded_col_size * q8_1_ts /
-                            q8_1_bs).wait()));
+                        SYCL_CHECK(
+                            CHECK_TRY_ERROR(stream
+                                                ->memcpy(src1_ddq_i, src1_ddq_i_source,
+                                                         src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
+                                                .wait()));
                     } else {
                         float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                        src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-                        SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
-                            src1_ddf_i, src1_ddf_i_source,
-                            src1_ncols * ne10 * sizeof(float))));
+                        src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
+                        SYCL_CHECK(
+                            CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
+                                                           src1_ncols * ne10 * sizeof(float))));
                     }
                 }
-            } else if (src1_on_device && !src1_is_contiguous) {
-                SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                    src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
-            } else {
-                GGML_ABORT("fatal error");
-            }
-
-            if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
-                                                     /*num_src=*/2, " : converting src1 to Q8_1");
-                quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
-                /*
-                DPCT1010:92: SYCL uses exceptions to report errors and does
-                not use the error codes. The call was replaced with 0. You
-                need to rewrite this code.
-                */
-                SYCL_CHECK(0);
+            } else {
+                if (src1_on_device) {
+                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
+                                                       src1_col_0 + src1_ncols, stream));
+                } else {
+                    GGML_ABORT("src1 is non-contiguous and not on device");
+                }
+
+                if constexpr (quantize_enabled) {
+                    scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
+                                                         /*num_src=*/2, " : converting src1 to Q8_1");
+                    try {
+                        quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
+                                                              src1_padded_col_size, stream);
+                    } catch (const sycl::exception & exc) {
+                        std::cerr << "quantize_row_q8_1_sycl error: " << exc.what()
+                                  << ", exception caught at file: " << __FILE__ << ", line: " << __LINE__ << std::endl;
+                        std::exit(1);
+                    }
+                }
             }

             if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2626,12 +2481,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             // do the computation
             SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                                           dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-            /*
-            DPCT1010:93: SYCL uses exceptions to report errors and does not
-            use the error codes. The call was replaced with 0. You need to
-            rewrite this code.
-            */
-            SYCL_CHECK(0);

             // copy dst to host or other device if necessary
             if (!dst_on_device) {
@@ -2662,12 +2511,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             // add event for the main device to wait on until other device is done
             if (split && (i != ctx.device || is != 0)) {
-                /*
-                DPCT1024:94: The original code returned the error code that
-                was further consumed by the program logic. This original
-                code was replaced with 0. You may need to rewrite the
-                program logic consuming the error code.
-                */
                 SYCL_CHECK(CHECK_TRY_ERROR(
                     *src0_extra->events[i][is] =
                         stream->ext_oneapi_submit_barrier()));
@@ -3351,19 +3194,20 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        constexpr bool convert_src1_to_q8_1 = false;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
     } else if (use_mul_mat_vec_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
+        ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+        if (extra && extra->optimized_feature.reorder) {
+            ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        } else {
+            ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
+        }
     } else if (use_mul_mat_q) {
-        constexpr bool convert_src1_to_q8_1 = true;
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
     } else {
-        constexpr bool convert_src1_to_q8_1 = false;
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
+        ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
     }
 }

ggml/src/ggml-sycl/quantize.hpp (new file, View File)

@@ -0,0 +1,133 @@
+/***************************************************************************
+ *
+ *  Copyright (C) 2025 Codeplay Software Ltd.
+ *  Copyright (C) 2025 Intel Corporation
+ *
+ *  MIT License
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ *
+ *  quantize.hpp
+ *
+ *  Description:
+ *     Sycl backend specific quantization functions
+ **************************************************************************/
+
+#pragma once
+
+#include <sycl/nd_item.hpp>
+
+#include "ggml-sycl/dpct/helper.hpp"
+
+template <int ElementsPerWI>
+__dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x,
+                                               sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d,
+                                               float & sum, const sycl::nd_item<1> & it) {
+    auto subgroup_id = it.get_group(0);
+    auto wi_id = it.get_local_id(0);
+
+    sycl::vec<float, ElementsPerWI> wi_f32_vals;
+
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
+
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
+    }
+
+    sum = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>());
+    amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>());
+    d = amax == 0 ? 1 : amax / 127;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
+
+    d = amax == 0 ? 0 : d;
+}
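Note the amax == 0 guard: d is temporarily set to 1 so the rounding loop never divides by zero, then stored as 0 so consumers see an all-zero block. A quick host-side round-trip check of the scale math (plain C++ standing in for the SYCL calls): with amax = 2 and x = 0.5, d = 2/127, q = round(x / d) = round(31.75) = 32, and q * d ≈ 0.504, an error of at most d/2:

    #include <cassert>
    #include <cmath>

    int main() {
        const float amax = 2.0f, x = 0.5f;
        const float d = amax / 127.0f;                    // scale, as in quantize_q8_1_impl
        const int   q = (int) std::round(x / d);          // 32
        assert(q == 32);
        assert(std::fabs(q * d - x) <= d / 2 + 1e-6f);    // error bounded by half a quant step
        return 0;
    }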
+
+// No-op functor used to select the non-quantizing code path in ggml_sycl_op_mul_mat at compile time.
+template <int ElementsPerWI> struct no_quantize_q8_1 {
+    void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {}
+};
+
+template <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa {
+    __dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx,
+                                    const int kx_padded, const sycl::nd_item<1> & it) const {
+        /*
+            Quantizes and reorders the resulting q8 tensor in a per-row fashion.
+            Each sub-group calculates one quant block, i.e. QK8_1 quant values plus the d and sum values.
+        */
+        auto subgroup_id = it.get_group(0);
+        auto wi_id = it.get_local_id(0);
+
+        sycl::vec<int8_t, ElementsPerWI> quantized_values;
+        float d = 0.0f;
+        float sum = 0.0f;
+        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
+
+        const int num_blocks_per_row = kx / QK8_1;
+        auto row = subgroup_id / num_blocks_per_row;
+        auto col = subgroup_id % num_blocks_per_row;
+        auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+        auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
+
+        auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+
+        auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
+        if (wi_id == 0) {
+            *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+        }
+    }
+};
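Reading a value back out of the reordered tensor walks the same offsets; a hypothetical host-side helper for illustration (not part of ggml; half_to_float stands in for fp16 decoding):

    #include <cstdint>
    #include <cstring>

    constexpr int    QK8_1       = 32;
    constexpr size_t BLOCK_BYTES = 4 + QK8_1;  // sizeof(block_q8_1): half2 ds + 32 quants

    // Recover element (row, k) from a reordered q8_1 buffer.
    float dequant_soa(const void * buf, int row, int k, int kx, int kx_padded,
                      float (*half_to_float)(uint16_t)) {
        const char * base = (const char *) buf + (size_t) row * (kx_padded / QK8_1) * BLOCK_BYTES;
        const int    col  = k / QK8_1;
        const int8_t q    = ((const int8_t *) base)[k];             // quants: first kx bytes of the row
        uint16_t     d_bits;
        std::memcpy(&d_bits, base + kx + col * 4, sizeof(d_bits));  // ds pairs follow the quants; d first
        return q * half_to_float(d_bits);
    }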
+
+template <int ElementsPerWI> struct quantize_q8_1 {
+    __dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded,
+                                    const sycl::nd_item<1> & it) const {
+        auto subgroup_id = it.get_group(0);
+        auto wi_id = it.get_local_id(0);
+
+        const int num_blocks_per_row = kx / QK8_1;
+        auto row = subgroup_id / num_blocks_per_row;
+        const int pitch = kx_padded / QK8_1;
+
+        sycl::vec<int8_t, ElementsPerWI> quantized_values;
+        float d = 0.0f;
+        float sum = 0.0f;
+        quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it);
+
+        block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor;
+        auto block_id = subgroup_id % num_blocks_per_row + row * pitch;
+        int8_t * qs = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]);
+        *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values;
+
+        if (wi_id == 0) {
+            quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum));
+        }
+    }
+};
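For contrast with the reordered layout, quantize_q8_1 writes the canonical array-of-blocks layout, where each block's ds sits alongside its quants. The real block_q8_1 is defined in ggml's common headers with fp16 fields; approximately:

    #include <cstdint>

    // Approximate shape of ggml's block_q8_1 (illustrative; see ggml-common.h).
    struct block_q8_1_approx {
        uint16_t d;        // fp16 scale
        uint16_t s;        // fp16 sum of the block's source values
        int8_t   qs[32];   // QK8_1 quants
    };                     // 36 bytes; quantized row pitch = (kx_padded / 32) * 36
    static_assert(sizeof(block_q8_1_approx) == 36, "ds (4 bytes) + 32 quants");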
+
+template <template <int> typename quantize_f>
+void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                            dpct::queue_ptr stream) {
+    static_assert(QK8_1 % WARP_SIZE == 0);
+
+    auto local_range = std::size_t(WARP_SIZE);
+    auto num_quant_blocks = ky * (kx / QK8_1);
+    auto global_range = num_quant_blocks * local_range;
+
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+    stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                         [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             quantize_f<QK8_1 / WARP_SIZE>()(x, vy, kx, kx_padded, it);
+                         });
+}
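Usage at the call sites shown above (illustrative variable names; the functor chosen at compile time decides the output layout):

    // AoS q8_1 blocks (mul_mat_q / non-reordered mmvq):
    quantize_row_q8_1_sycl<quantize_q8_1>(src1_f32, src1_q8, ne10, nrows1, src1_padded_col_size, stream);

    // Reordered SoA layout, picked when src0 carries the reorder optimization:
    quantize_row_q8_1_sycl<quantize_and_reorder_q8_1_soa>(src1_f32, src1_q8, ne10, nrows1, src1_padded_col_size, stream);

    // no_quantize_q8_1 is never launched; it exists only so ggml_sycl_op_mul_mat can
    // detect the float path via std::is_same_v and skip the q8_1 buffer entirely.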