From 5fd160bbd9d70b94b5b11b0001fd7f477005e4a0 Mon Sep 17 00:00:00 2001
From: Reese Levine
Date: Wed, 6 Aug 2025 15:14:40 -0700
Subject: [PATCH] ggml: Add basic SET_ROWS support in WebGPU (#15137)

* Begin work on set_rows

* Work on set rows

* Add error buffers for reporting unsupported SET_ROWS indices

* Remove extra comments
---
 .github/workflows/build.yml                |   2 -
 ggml/src/ggml-webgpu/ggml-webgpu.cpp       | 237 +++++++++++++++---
 .../ggml-webgpu/wgsl-shaders/set_rows.wgsl |  82 ++++++
 3 files changed, 286 insertions(+), 35 deletions(-)
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 63e40c358..3d4f837e2 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -179,7 +179,6 @@ jobs:
       - name: Test
         id: cmake_test
         run: |
-          export LLAMA_SET_ROWS=0
           cd build
           ctest -L main --verbose --timeout 900
 
@@ -438,7 +437,6 @@ jobs:
       - name: Test
         id: cmake_test
         run: |
-          export LLAMA_SET_ROWS=0
           cd build
           # This is using llvmpipe and runs slower than other backends
           ctest -L main --verbose --timeout 3600
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 5009e26a2..ba1addc8d 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -19,18 +19,21 @@
 #include <vector>
 
 #ifdef GGML_WEBGPU_DEBUG
-#  define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+#  define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
+#  define WEBGPU_DEBUG_BUF_ELEMS 32
 #else
 #  define WEBGPU_LOG_DEBUG(msg) ((void) 0)
 #endif  // GGML_WEBGPU_DEBUG
 
 /* Constants */
 
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
-#define WEBGPU_MUL_MAT_WG_SIZE           64
-#define WEBGPU_NUM_PARAM_BUFS            100
-#define WEBGPU_PARAMS_BUF_SIZE_BYTES     256
-#define WEBGPU_STORAGE_BUF_BINDING_MULT  4  // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16
+#define WEBGPU_MUL_MAT_WG_SIZE               64
+#define WEBGPU_NUM_PARAM_BUFS                100
+#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
+#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
+#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
 
 /* End Constants */
 
@@ -54,46 +57,42 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
                                       wgpu::BufferUsage usage,
                                       const char * label);
 
-struct webgpu_param_bufs {
+struct webgpu_pool_bufs {
     wgpu::Buffer host_buf;
     wgpu::Buffer dev_buf;
 };
 
 // Holds a pool of parameter buffers for WebGPU operations
-struct webgpu_param_buf_pool {
-    std::vector<webgpu_param_bufs> free;
+struct webgpu_buf_pool {
+    std::vector<webgpu_pool_bufs> free;
 
     std::mutex mutex;
 
     std::condition_variable cv;
 
-    void init(wgpu::Device device) {
-        for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
+    void init(wgpu::Device      device,
+              int               num_bufs,
+              size_t            buf_size,
+              wgpu::BufferUsage dev_buf_usage,
+              wgpu::BufferUsage host_buf_usage) {
+        for (int i = 0; i < num_bufs; i++) {
             wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device,
-                                      host_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
-                                      "ggml_webgpu_host_params_buf");
-            ggml_webgpu_create_buffer(device,
-                                      dev_buf,
-                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                      "ggml_webgpu_dev_params_buf");
+            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
+            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
             free.push_back({ host_buf, dev_buf });
         }
     }
 
-    webgpu_param_bufs alloc_bufs() {
+    webgpu_pool_bufs alloc_bufs() {
         std::unique_lock<std::mutex> lock(mutex);
         cv.wait(lock, [this] { return !free.empty(); });
-        webgpu_param_bufs bufs = free.back();
+        webgpu_pool_bufs bufs = free.back();
         free.pop_back();
         return bufs;
     }
 
-    void free_bufs(std::vector<webgpu_param_bufs> bufs) {
+    void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
         std::lock_guard<std::mutex> lock(mutex);
         free.insert(free.end(), bufs.begin(), bufs.end());
         cv.notify_all();
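The change above generalizes the existing parameter-buffer pool so the same type can also back the new SET_ROWS error buffers. Stripped of the WebGPU types, it is a small blocking free-list; a minimal standalone sketch of the same pattern (the names here are illustrative, not from the patch):

    // Blocking free-list: alloc() waits until release() returns an item.
    #include <condition_variable>
    #include <mutex>
    #include <vector>

    template <typename T> struct blocking_pool {
        std::vector<T>          free_list;
        std::mutex              mutex;
        std::condition_variable cv;

        T alloc() {
            std::unique_lock<std::mutex> lock(mutex);
            cv.wait(lock, [this] { return !free_list.empty(); });  // block until an item is available
            T item = free_list.back();
            free_list.pop_back();
            return item;
        }

        void release(std::vector<T> items) {
            std::lock_guard<std::mutex> lock(mutex);
            free_list.insert(free_list.end(), items.begin(), items.end());
            cv.notify_all();  // wake any threads blocked in alloc()
        }
    };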
"ggml_webgpu_dev_pool_buf"); free.push_back({ host_buf, dev_buf }); } } - webgpu_param_bufs alloc_bufs() { + webgpu_pool_bufs alloc_bufs() { std::unique_lock lock(mutex); cv.wait(lock, [this] { return !free.empty(); }); - webgpu_param_bufs bufs = free.back(); + webgpu_pool_bufs bufs = free.back(); free.pop_back(); return bufs; } - void free_bufs(std::vector bufs) { + void free_bufs(std::vector bufs) { std::lock_guard lock(mutex); free.insert(free.end(), bufs.begin(), bufs.end()); cv.notify_all(); @@ -121,10 +120,12 @@ struct webgpu_context_struct { bool device_init = false; - webgpu_param_buf_pool param_buf_pool; + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; wgpu::ComputePipeline memset_pipeline; wgpu::ComputePipeline mul_mat_pipeline; + wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline cpy_pipeline; size_t memset_bytes_per_thread; @@ -136,9 +137,16 @@ struct webgpu_context_struct { std::vector staged_command_bufs; // Parameter buffers associated with the staged command buffers - std::vector staged_param_bufs; + std::vector staged_param_bufs; + // Buffers associated with set_rows operations, used to store potential errors + std::vector staged_set_row_error_bufs; std::vector callback_futures; + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif }; typedef std::shared_ptr webgpu_context; @@ -249,20 +257,55 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { return; } ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + if (ctx->staged_set_row_error_bufs.size() > 0) { + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + for (auto & error_bufs : ctx->staged_set_row_error_bufs) { + // Copy the error buffer to the host buffer + encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize()); + } + wgpu::CommandBuffer commands = encoder.Finish(); + ctx->queue.Submit(1, &commands); + } + ctx->staged_command_bufs.clear(); - std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); + std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); + std::vector staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs); // Free the staged parameter buffers once the submission completes - wgpu::Future f = ctx->queue.OnSubmittedWorkDone( + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); } - // Free the staged parameter buffers + // Free the staged buffers ctx->param_buf_pool.free_bufs(staged_param_bufs); }); - ctx->callback_futures.push_back({ f }); + ctx->callback_futures.push_back({ p_f }); + + // Check for errrors in SET_ROWS operations + for (auto & error_bufs : staged_set_row_error_bufs) { + wgpu::Future f = error_bufs.host_buf.MapAsync( + wgpu::MapMode::Read, + 0, + error_bufs.host_buf.GetSize(), + wgpu::CallbackMode::AllowSpontaneous, + [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data); + } else { + const uint32_t * error_data = (const uint32_t *) 
error_bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->set_rows_error_buf_pool.free_bufs({ error_bufs }); + } + }); + ctx->callback_futures.push_back({ f }); + } } static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, @@ -283,13 +326,34 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, UINT64_MAX); } +#ifdef GGML_WEBGPU_DEBUG +// This function adds debugging information to shaders, as WebGPU does not support printing directly. +// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and +// debug statements in the shader, and then call this function after encoding the commands and submitting them. +static void ggml_backend_webgpu_debug(webgpu_context & ctx) { + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ctx->queue.Submit(1, &commands); + + ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); + const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange(); + std::cout << "debug data:"; + for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) { + std::cout << " " << i << ": " << debug_data[i]; + } + std::cout << "\n"; + ctx->debug_host_buf.Unmap(); +} +#endif + static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, wgpu::ComputePipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, bool submit_and_wait = false) { - webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); + webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); @@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); } +static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + // For set rows specifically, we need to check if src and idx are empty tensors. 
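A hedged sketch of the workflow the comment above describes, for an op being debugged (the binding index 5 and the matching shader-side declaration are assumptions for illustration, not part of this patch):

    // Host side: expose the debug buffer to the kernel being debugged.
    entries.push_back({ .binding = 5,
                        .buffer  = ctx->debug_dev_buf,
                        .offset  = 0,
                        .size    = ctx->debug_dev_buf.GetSize() });
    // ... encode and submit the op as usual, then read the values back:
    ggml_backend_webgpu_debug(ctx);  // copies debug_dev_buf to the host and prints it

The shader would declare a matching storage buffer at @binding(5) and write whatever values it wants printed into it.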
 
 static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &                  ctx,
                                                   wgpu::ComputePipeline &           pipeline,
                                                   std::vector<uint32_t>             params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t                          wg_x,
                                                   bool                              submit_and_wait = false) {
-    webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
     uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
@@ -429,6 +493,76 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
 }
 
+static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
+    // For set_rows specifically, we need to check whether src and idx are empty tensors.
+    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
+        return;
+    }
+
+    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
+    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+        error_bufs.host_buf.Unmap();
+    }
+
+    size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
+    // assumes power of 2 offset alignment
+    size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align to minimum offset alignment
+    src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t idx_offset       = ggml_backend_webgpu_tensor_offset(idx);
+    size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset       = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+
+    std::vector<uint32_t> params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)),
+                                     (uint32_t) (idx_misalignment / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
+                                     // Convert byte-strides to element-strides
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     // Shape of src
+                                     (uint32_t) src->ne[0],
+                                     (uint32_t) src->ne[1],
+                                     (uint32_t) src->ne[2],
+                                     (uint32_t) src->ne[3],
+                                     // Shape of idx
+                                     (uint32_t) (idx->ne[1]),
+                                     (uint32_t) (idx->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_backend_webgpu_tensor_buf(src),
+          .offset  = src_offset,
+          .size    = ggml_nbytes(src) + src_misalignment },
+        { .binding = 1,
+          .buffer  = ggml_backend_webgpu_tensor_buf(idx),
+          .offset  = idx_offset,
+          .size    = ggml_nbytes(idx) + idx_misalignment },
+        { .binding = 2,
+          .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+          .offset  = dst_offset,
+          .size    = ggml_nbytes(dst) + dst_misalignment },
+        { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
+    };
+
+    size_t   max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
+
+    std::lock_guard lock(ctx->mutex);
+    ctx->staged_set_row_error_bufs.push_back(error_bufs);
+
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+}
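The misalignment handling above is plain power-of-two arithmetic: bind at the aligned-down offset and pass the remainder to the shader in elements. A standalone sketch of the math (the 256-byte alignment is an assumed example value for minStorageBufferOffsetAlignment):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const size_t align  = 256;   // assumed minStorageBufferOffsetAlignment
        size_t       offset = 1000;  // a tensor's byte offset within its buffer

        size_t misalignment = offset & (align - 1);  // bytes past the boundary: 1000 & 255 = 232
        offset &= ~(align - 1);                      // align down: 1000 -> 768

        // The binding uses offset 768; the shader indexes from element 232/4 = 58 (f32).
        printf("binding offset=%zu, shader element offset=%zu\n", offset, misalignment / sizeof(float));
        return 0;
    }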
"mul_mat"); } +static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + ggml_webgpu_create_pipeline( + webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants); +} + static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants(1); constants[0].key = "wg_size"; @@ -827,11 +974,35 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); // Create buffer pool for shader parameters - webgpu_ctx->param_buf_pool.init(webgpu_ctx->device); + webgpu_ctx->param_buf_pool.init(webgpu_ctx->device, + WEBGPU_NUM_PARAM_BUFS, + WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device, + WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); ggml_webgpu_init_memset_pipeline(webgpu_ctx); ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_set_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); + +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->device, + webgpu_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->device, + webgpu_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, + "debug_dev_buf"); +#endif webgpu_ctx->device_init = true; } @@ -882,7 +1053,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_VIEW: case GGML_OP_PERMUTE: return true; - case GGML_OP_CPY: + case GGML_OP_CPY | GGML_OP_SET_ROWS: return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_MUL_MAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl new file mode 100644 index 000000000..4bd6f94a2 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -0,0 +1,82 @@ +enable f16; + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var idx: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var error: atomic; + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(4) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + return; + } + var i = gid.x; + let i_src3 = i / (params.ne2 * params.n_rows); + let i_dst3 = i / (params.ne2 * 3); + + i = i % (params.ne2 * params.n_rows); 
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
new file mode 100644
index 000000000..4bd6f94a2
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
@@ -0,0 +1,82 @@
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<u32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f16>;
+
+@group(0) @binding(3)
+var<storage, read_write> error: atomic<u32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_idx: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_idx0: u32,
+    stride_idx1: u32,
+    stride_idx2: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of src
+    ne0: u32,
+    n_rows: u32,
+    ne2: u32,
+    ne3: u32,
+
+    // Shape of idx
+    idx1: u32,
+    idx2: u32,
+};
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+        return;
+    }
+
+    // Decompose the flat invocation id into src coordinates (row, dim2, dim3).
+    var i = gid.x;
+    let i_src3 = i / (params.ne2 * params.n_rows);
+    i = i % (params.ne2 * params.n_rows);
+    let i_src2 = i / params.n_rows;
+    let i_src1 = i % params.n_rows;
+
+    let i_idx2 = i_src3 % params.idx2;
+    let i_idx1 = i_src2 % params.idx1;
+    let i_idx0 = i_src1;
+
+    // idx values are int64, read as two u32 words (low word first).
+    let idx_base = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
+    let idx_val = idx[idx_base];
+    let idx_high_val = idx[idx_base + 1];
+
+    if (idx_high_val != 0) {
+        // Upper bits of index are not zero, output will be incorrect
+        atomicStore(&error, 1);
+        return;
+    }
+
+    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
+
+    for (var j: u32 = 0; j < params.ne0; j++) {
+        dst[i_dst_row + j] = f16(src[i_src_row + j]);
+    }
+}
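For checking the indexing logic, the shader's flat-id decomposition and idx broadcasting correspond to this hypothetical host-side mirror (illustrative only, not part of the patch):

    #include <cstdint>

    // One shader invocation per src row: decompose the flat id, then wrap
    // the coordinates into the (possibly smaller) idx tensor via modulo.
    void decompose(uint32_t gid, uint32_t n_rows, uint32_t ne2,
                   uint32_t idx1, uint32_t idx2) {
        uint32_t i      = gid;
        uint32_t i_src3 = i / (ne2 * n_rows);
        i               = i % (ne2 * n_rows);
        uint32_t i_src2 = i / n_rows;
        uint32_t i_src1 = i % n_rows;

        uint32_t i_idx2 = i_src3 % idx2;  // broadcast along dim 3
        uint32_t i_idx1 = i_src2 % idx1;  // broadcast along dim 2
        uint32_t i_idx0 = i_src1;         // one index entry per source row
        (void) i_idx0; (void) i_idx1; (void) i_idx2;
        // The shader then looks up the destination row via idx[i_idx...]
        // and copies ne0 elements, converting f32 -> f16.
    }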