add OP sigmoid (#12056)

Co-authored-by: Judd <foldl@boxvest.com>
Author: Judd
Date:   2025-02-25 19:32:20 +08:00 (committed by GitHub)
parent 393fca629e
commit c132239bfb
3 changed files with 35 additions and 0 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -249,6 +249,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
+    vk_pipeline pipeline_sigmoid_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2189,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
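An annotated restatement of the new registration call may help orient readers; the argument roles below are inferred from the neighbouring element-wise ops and from sigmoid.comp further down, not from the function's documented signature:

    // Sketch only: argument roles are inferred, not authoritative.
    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32,
        "sigmoid_f32",                 // pipeline name
        sigmoid_f32_len,               // byte size of the SPIR-V blob from vulkan-shaders-gen
        sigmoid_f32_data,              // the SPIR-V blob itself
        "main",                        // shader entry point
        2,                             // two bindings: src buffer X, dst buffer D
        sizeof(vk_op_push_constants),  // push constants (carry KX, the element count)
        {512, 1, 1},                   // dispatch granularity, matching local_size_x = 512
        {},                            // no specialization constants
        1);                            // same trailing alignment argument as the sibling ops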
@@ -5342,6 +5344,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                     return ctx->device->pipeline_tanh_f32;
                 }
                 break;
+            case GGML_UNARY_OP_SIGMOID:
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                    return ctx->device->pipeline_sigmoid_f32;
+                }
+                break;
             default:
                 break;
         }
@@ -7335,6 +7342,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             break;
         default:
             return false;
@@ -7551,6 +7559,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
             break;
         default:
@@ -7738,6 +7747,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             buf = tensor->buffer;
             break;
         default:
@@ -8439,6 +8449,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             return ggml_is_contiguous(op->src[0]);
         default:
             return false;
@@ -9105,6 +9116,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         case GGML_UNARY_OP_TANH:
             tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
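With the hunks above, the op is selectable (get_pipeline), buildable (build_graph), runnable (compute_forward, supports_op), and verifiable against the CPU path (check_results). As an end-to-end illustration, a minimal host program against ggml's public API might look like the following sketch; it is not part of this commit, and the device index 0 and context sizing are assumptions:

    // Run ggml_sigmoid on the Vulkan backend and read the result back.
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-vulkan.h"
    #include <cstdio>

    int main() {
        ggml_init_params ip = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,    // tensor data lives in the backend buffer
        };
        ggml_context * ctx = ggml_init(ip);

        ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        ggml_tensor * y = ggml_sigmoid(ctx, x);       // GGML_UNARY_OP_SIGMOID

        ggml_backend_t vk = ggml_backend_vk_init(0);  // device 0 (assumption)
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, vk);

        const float in[4] = { -2.0f, -1.0f, 0.0f, 1.0f };
        ggml_backend_tensor_set(x, in, 0, sizeof(in));

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_backend_graph_compute(vk, gf);           // dispatches pipeline_sigmoid_f32

        float out[4];
        ggml_backend_tensor_get(y, out, 0, sizeof(out));
        for (int i = 0; i < 4; ++i) {
            printf("%f\n", out[i]);                   // ~0.1192, 0.2689, 0.5000, 0.7311
        }

        ggml_backend_buffer_free(buf);
        ggml_backend_free(vk);
        ggml_free(ctx);
        return 0;
    }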

ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp

@@ -0,0 +1,20 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+    data_d[i] = D_TYPE(1. / (1 + exp(-1. * data_a[i])));
+}
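For reference, the kernel computes sigmoid(x) = 1 / (1 + e^-x) element-wise over a flat array: p.KX (from generic_head.comp) is the total element count for the bounds check, and the constant 262144 is 512 * 512, folding the y and z grid dimensions into the flat index when a tensor needs more invocations than one grid dimension allows. A hedged host-side reference, useful for eyeballing the shader's output (not part of the commit):

    // CPU reference matching the shader's math over a flat array.
    #include <cmath>
    #include <vector>

    std::vector<float> sigmoid_reference(const std::vector<float> & a) {
        std::vector<float> d(a.size());
        for (size_t i = 0; i < a.size(); ++i) {
            d[i] = 1.0f / (1.0f + std::exp(-a[i]));  // same as data_d[i] in sigmoid.comp
        }
        return d;
    }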

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -482,6 +482,7 @@ void process_shaders() {
     string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});