mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00
implement unary REGLU/GEGLU/SWIGLU cuda ops
This commit is contained in:
committed by
Akarshan
parent
bb2fda70ae
commit
a1a7b6dfa9
@ -2246,6 +2246,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
ggml_cuda_op_exp(ctx, dst);
|
ggml_cuda_op_exp(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_REGLU:
|
||||||
|
ggml_cuda_op_reglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_GEGLU:
|
||||||
|
ggml_cuda_op_geglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
ggml_cuda_op_swiglu(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -3039,6 +3048,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_EXP:
|
case GGML_UNARY_OP_EXP:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
|
case GGML_UNARY_OP_REGLU:
|
||||||
|
case GGML_UNARY_OP_GEGLU:
|
||||||
|
case GGML_UNARY_OP_SWIGLU:
|
||||||
|
return ggml_is_contiguous_1(op->src[0]);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -196,6 +196,62 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
ggml_cuda_op_unary<op_log>(ctx, dst);
|
ggml_cuda_op_unary<op_log>(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* gated ops */
|
||||||
|
|
||||||
|
template <float (*op)(float), typename T>
|
||||||
|
static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k, const int n, const int o) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// perform base op on first half of row and multiply with gate in second half
|
||||||
|
const int j = (i / n) * o + (i % n);
|
||||||
|
dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <float (*op)(float), typename T>
|
||||||
|
static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
||||||
|
unary_gated_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <float (*op)(float)>
|
||||||
|
void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const void * src0_d = src0->data;
|
||||||
|
void * dst_d = dst->data;
|
||||||
|
const int nc = src0->ne[0] / 2;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
GGML_ASSERT(dst->ne[0] >= nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0));
|
||||||
|
|
||||||
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
|
||||||
|
} else {
|
||||||
|
unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
/* silu_back */
|
/* silu_back */
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||||
|
@ -57,3 +57,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
Reference in New Issue
Block a user