|
|
|
@ -1,6 +1,29 @@
|
|
|
|
|
#include "unary.cuh"
|
|
|
|
|
|
|
|
|
|
static __global__ void neg_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_abs(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dst[i] = fabsf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_sgn(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dst[i] = (T)(x[i] > (T)0.f ? 1.f : ((x[i] < (T)0.f ? -1.f : 0.f)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_neg(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -10,61 +33,67 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = -x[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void step_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_step(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dst[i] = x[i] > 0.0f;
|
|
|
|
|
dst[i] = x[i] > (T)0.0f;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
const float GELU_COEF_A = 0.044715f;
|
|
|
|
|
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_gelu(const T * x, T * dst, const int k) {
|
|
|
|
|
const T GELU_COEF_A = 0.044715f;
|
|
|
|
|
const T SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float xi = x[i];
|
|
|
|
|
dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
|
|
|
|
|
T xi = x[i];
|
|
|
|
|
dst[i] = (T)0.5f*xi*((T)1.0f + (T)tanhf(SQRT_2_OVER_PI*xi*((T)1.0f + GELU_COEF_A*xi*xi)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
|
|
|
|
|
const float GELU_QUICK_COEF = -1.702f;
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_gelu_quick(const T * x, T * dst, int k) {
|
|
|
|
|
const T GELU_QUICK_COEF = -1.702f;
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
|
|
|
|
|
dst[i] = x[i] * ((T)1.0f / ((T)1.0f + (T)expf(GELU_QUICK_COEF * x[i])));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_silu(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
|
|
|
|
dst[i] = x[i] / ((T)1.0f + (T)expf(-x[i]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void silu_back_f32(
|
|
|
|
|
const float * grad, const float * xf, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_silu_back(
|
|
|
|
|
const T * grad, const T * xf, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float xfi = xf[i];
|
|
|
|
|
const float s = 1.0f / (1.0f + expf(-xfi));
|
|
|
|
|
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
|
|
|
|
|
const T xfi = xf[i];
|
|
|
|
|
const T s = (T)1.0f / ((T)1.0f + (T)expf(-xfi));
|
|
|
|
|
dst[i] = grad[i] * s * ((T)1.0f + xfi * ((T)1.0f - s));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_tanh(const T * x, T * dst, int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
@ -72,7 +101,8 @@ static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
|
|
|
|
dst[i] = tanhf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_relu(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -81,34 +111,38 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = fmaxf(x[i], 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_sigmoid(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = 1.0f / (1.0f + expf(-x[i]));
|
|
|
|
|
dst[i] = (T)1.0f / ((T)1.0f + (T)expf(-x[i]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_hardsigmoid(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
|
|
|
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_hardswish(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
|
|
|
|
dst[i] = x[i] * (T)fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_exp(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -117,15 +151,17 @@ static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = expf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_leaky_relu(const T * x, T * dst, const int k, const float negative_slope) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
|
|
|
|
|
dst[i] = (T)fmaxf(x[i], 0) + (T)fminf(x[i], 0.0f) * (T)negative_slope;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_sqr(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -134,7 +170,8 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = x[i] * x[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_sqrt(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -143,7 +180,8 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = sqrtf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void sin_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_sin(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -152,7 +190,8 @@ static __global__ void sin_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = sinf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_cos(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
@ -161,145 +200,248 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
|
|
|
|
dst[i] = cosf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static __global__ void op_log(const T * x, T * dst, const int k) {
|
|
|
|
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
|
|
|
|
|
|
|
|
|
if (i >= k) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
dst[i] = logf(x[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static void abs_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
|
|
|
|
neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_abs<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void sgn_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
|
|
|
|
op_sgn<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static void neg_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
|
|
|
|
op_neg<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static void step_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
|
|
|
|
|
step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_step<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void gelu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
|
|
|
|
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_gelu<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void gelu_quick_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
|
|
|
|
gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_gelu_quick<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void silu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
|
|
|
|
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_silu<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
|
|
|
|
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
|
|
|
|
|
op_silu_back<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void tanh_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
|
|
|
|
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_tanh<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void relu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
|
|
|
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void sigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
|
|
|
|
|
sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_sigmoid<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void hardsigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
|
|
|
|
|
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_hardsigmoid<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void hardswish_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
|
|
|
|
|
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_hardswish<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void exp_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
|
|
|
|
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_exp<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
|
|
|
|
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
|
|
|
|
op_leaky_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void sqr_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
|
|
|
|
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_sqr<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void sqrt_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
|
|
|
|
|
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_sqrt<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void sin_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
|
|
|
|
|
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_sin<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
template <class T>
|
|
|
|
|
static void cos_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
|
|
|
|
|
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
op_cos<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
static void log_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
|
|
|
|
|
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
|
|
|
|
|
op_log<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
abs_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
abs_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
sgn_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
sgn_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
neg_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
neg_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
step_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
step_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
gelu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
gelu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
silu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
silu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
@ -314,179 +456,263 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
gelu_quick_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
gelu_quick_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
tanh_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
tanh_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
sigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
sigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
hardsigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
hardsigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
hardswish_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
hardswish_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
exp_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
exp_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
float negative_slope;
|
|
|
|
|
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
|
|
|
|
|
|
|
|
|
leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
|
|
|
|
|
} else {
|
|
|
|
|
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
sqr_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
sqr_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
sqrt_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
sqrt_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
sin_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
sin_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const float * src0_d = (const float *)src0->data;
|
|
|
|
|
float * dst_d = (float *)dst->data;
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
cos_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
cos_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
const void * src0_d = src0->data;
|
|
|
|
|
void * dst_d = dst->data;
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
if (src0->type == GGML_TYPE_F16) {
|
|
|
|
|
log_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
} else {
|
|
|
|
|
log_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|