mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-26 19:23:37 -04:00
* Minimal setup of webgpu backend with dawn. Just prints out the adapter and segfaults
* Initialize webgpu device
* Making progress on setting up the backend
* Finish more boilerplate/utility functions
* Organize file and work on alloc buffer
* Add webgpu_context to prepare for actually running some shaders
* Work on memset and add shader loading
* Work on memset polyfill
* Implement set_tensor as webgpu WriteBuffer, remove host_buffer stubs since webgpu doesn't support it
* Implement get_tensor and buffer_clear
* Finish rest of setup
* Start work on compute graph
* Basic mat mul working
* Work on emscripten build
* Basic WebGPU backend instructions
* Use EMSCRIPTEN flag
* Work on passing ci, implement 4d tensor multiplication
* Pass thread safety test
* Implement permuting for mul_mat and cpy
* minor cleanups
* Address feedback
* Remove division by type size in cpy op
* Fix formatting and add github action workflows for vulkan and metal (m-series) webgpu backends
* Fix name
* Fix macos dawn prefix path
61 lines
1.4 KiB
WebGPU Shading Language
61 lines
1.4 KiB
WebGPU Shading Language
// The destination buffer stores half-precision values, which requires the
// f16 language extension.
enable f16;

// Source elements (f32). Declared read_write, although the entry point in this
// file only reads from it.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;

// Destination elements (f16), written by the entry point.
@group(0) @binding(1) var<storage, read_write> dst: array<f16>;
|
|
|
|
// Uniform parameters describing one strided copy.
// Field order and types define the uniform buffer layout and must match
// whatever the host side writes into binding 2.
struct Params {
    // Total number of elements to copy.
    ne: u32,

    // Starting offsets into the source and destination buffers, in elements.
    offset_src: u32,
    offset_dst: u32,

    // Source strides, in elements, one per logical dimension (may be permuted).
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Destination strides, in elements, one per logical dimension (may be permuted).
    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical extent of each dimension; shared by source and destination.
    ne0: u32,
    ne1: u32,
    ne2: u32,
    ne3: u32,
};
|
|
|
|
// Copy description (element count, offsets, strides, shape) for this dispatch.
@group(0) @binding(2) var<uniform> params: Params;

// Workgroup size along x. It has no default here, so a value must be supplied
// as a pipeline-overridable constant when the compute pipeline is created.
override wg_size: u32;
|
|
// Strided f32 -> f16 copy: each invocation converts exactly one element.
//
// gid.x is treated as a flat index over the logical shape
// (ne0, ne1, ne2, ne3), with ne0 varying fastest. The flat index is split
// into per-dimension coordinates, which are then combined with the source
// and destination strides to locate the element in each buffer.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;

    // Invocations past the last element have nothing to do.
    if (flat >= params.ne) {
        return;
    }

    // Number of elements spanned by one step along dimension 2 and dimension 3.
    let slice_elems = params.ne1 * params.ne0;
    let cube_elems = params.ne2 * slice_elems;

    // Peel off coordinates from the slowest-varying dimension to the fastest.
    let i3 = flat / cube_elems;
    let rem3 = flat % cube_elems;

    let i2 = rem3 / slice_elems;
    let rem2 = rem3 % slice_elems;

    let i1 = rem2 / params.ne0;
    let i0 = rem2 % params.ne0;

    // Element positions relative to each buffer's starting offset.
    let src_idx = i0 * params.stride_src0
                + i1 * params.stride_src1
                + i2 * params.stride_src2
                + i3 * params.stride_src3;

    let dst_idx = i0 * params.stride_dst0
                + i1 * params.stride_dst1
                + i2 * params.stride_dst2
                + i3 * params.stride_dst3;

    // Load as f32, narrow to f16, and store.
    let value = src[params.offset_src + src_idx];
    dst[params.offset_dst + dst_idx] = f16(value);
}
|