mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00

commit 0b2703fc57 (parent f8705a2399), committed by Akarshan

    implement swapped variants (cpu/cuda)
@@ -1101,23 +1101,37 @@ extern "C" {
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
     GGML_API struct ggml_tensor * ggml_glu(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
-            enum ggml_glu_op op);
+            enum ggml_glu_op op,
+            bool swapped);
 
     GGML_API struct ggml_tensor * ggml_reglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_geglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_swiglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
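Editor's note (a sketch, not part of the diff): how the extended API is meant to be called. Here ctx and cur are assumed to already exist, with cur holding the two halves concatenated along dim 0 (so cur->ne[0] is even); the plain helpers expect the gate in the second half of each row, while the _swapped helpers read it from the first half.

    struct ggml_tensor * out    = ggml_swiglu(ctx, cur);          // gate in second half of each row
    struct ggml_tensor * out_sw = ggml_swiglu_swapped(ctx, cur);  // gate in first half of each row

    // equivalent lower-level form with an explicit op and swapped flag
    struct ggml_tensor * out_glu = ggml_glu(ctx, cur, GGML_GLU_OP_SWIGLU, /*swapped =*/ true);

    // in every case the result keeps half the row width: out->ne[0] == cur->ne[0] / 2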
@@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f32(nc,
                 (float *) ((char *) dst->data + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f32(nc,
                 (float *) ((char *) dst->data + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f32(nc,
                 (float *) ((char *) dst->data + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
        for (int k = 0; k < nc; k++) {
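For orientation, a minimal reference sketch (not the ggml code) of what the CPU hunks above implement: each input row of width 2*nc is split into an activated half and a gate half, and the swapped flag only decides which half is which before the ggml_vec_* helpers are called with two separate pointers.

    // reference semantics for one row; act stands in for relu/gelu/silu
    static void glu_row_ref(int nc, const float * row, float * out,
                            bool swapped, float (*act)(float)) {
        const float * x = row + (swapped ? nc : 0);  // half fed through the activation
        const float * g = row + (swapped ? 0 : nc);  // half used as the gate
        for (int i = 0; i < nc; ++i) {
            out[i] = act(x[i]) * g[i];
        }
    }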
@@ -254,27 +254,27 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
-void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
     for (; i + 15 < n; i += 16) {
-        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
     }
 #elif defined(__AVX2__) && defined(__FMA__)
     for (; i + 7 < n; i += 8) {
-        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
     }
 #elif defined(__SSE2__)
     for (; i + 3 < n; i += 4) {
-        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
     }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
-        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
     }
 #endif
     for (; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]) * x[i + n];
+        y[i] = ggml_silu_f32(x[i]) * g[i];
     }
 }
 
@@ -905,57 +905,57 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
-inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) {
+inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f;
+        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
     }
 }
 
-inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
     }
 }
 
 #ifdef GGML_GELU_FP16
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
         if (x[i] <= -10.0f) {
             y[i] = 0.0f;
         } else if (x[i] >= 10.0f) {
-            y[i] = x[i] * x[i + n];
+            y[i] = x[i] * g[i];
         } else {
             ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
             memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n];
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
         }
     }
 }
 #else
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]) * x[i + n];
+        y[i] = ggml_gelu_f32(x[i]) * g[i];
     }
 }
 #endif
 
-inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g);
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
     }
 }
 
-void ggml_vec_swiglu_f32(const int n, float * y, const float * x);
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
 
-inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
     }
 }
 
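For reference, the vector helpers above compute, element-wise over the activated half x and the gate g: reglu gives y[i] = max(x[i], 0) * g[i]; geglu gives y[i] = GELU(x[i]) * g[i] (through the fp16 GELU table when GGML_GELU_FP16 is defined); swiglu gives y[i] = SiLU(x[i]) * g[i] = (x[i] / (1 + exp(-x[i]))) * g[i].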
@@ -199,7 +199,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 /* gated ops */
 
 template <float (*op)(float), typename T>
-static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) {
     const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -208,13 +208,13 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t
 
     // perform base op on first half of row and multiply with gate in second half
     const int64_t j = (i / n) * o + (i % n);
-    dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
+    dst[i] = (T)(op((float)x[j]) * (float)g[j]);
 }
 
 template <float (*op)(float), typename T>
-static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
     const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
-    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o);
 }
 
 template <float (*op)(float)>
@@ -235,10 +235,26 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
 
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
     if (src0->type == GGML_TYPE_F16) {
-        unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
+        unary_gated_cuda<op>(
+            (const half *)src0_d + (swapped ? nc : 0),
+            (const half *)src0_d + (swapped ? 0 : nc),
+            (half *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(half),
+            stream);
     } else {
-        unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
+        unary_gated_cuda<op>(
+            (const float *)src0_d + (swapped ? nc : 0),
+            (const float *)src0_d + (swapped ? 0 : nc),
+            (float *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(float),
+            stream);
     }
 }
 
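A note on the CUDA indexing (illustrative, not from the sources): the kernel only ever sees two base pointers x and g, which the host has already offset by 0 or nc according to the swapped flag, so swapped itself never reaches device code. Inside the kernel, output element i is mapped to its source element through the row stride o = src0->nb[1] / sizeof(element), which can exceed the logical row width for non-contiguous views.

    // hypothetical helper mirroring the kernel's index math
    static long long gated_src_index(long long i, long long n, long long o) {
        return (i / n) * o + (i % n);  // row = i / n, column = i % n
    }
    // e.g. n = 4, o = 10: output element 6 (row 1, column 2) reads source element 12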
@@ -2643,13 +2643,15 @@ struct ggml_tensor * ggml_exp_inplace(
 struct ggml_tensor * ggml_glu(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        enum ggml_glu_op op) {
+        enum ggml_glu_op op,
+        bool swapped) {
     GGML_ASSERT(ggml_is_contiguous_1(a));
 
     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
 
     ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
 
     result->op = GGML_OP_GLU;
     result->src[0] = a;
@@ -2662,7 +2664,13 @@ struct ggml_tensor * ggml_glu(
 struct ggml_tensor * ggml_reglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
 }
 
 // ggml_geglu
@@ -2670,7 +2678,13 @@ struct ggml_tensor * ggml_reglu(
 struct ggml_tensor * ggml_geglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
 }
 
 // ggml_swiglu
@@ -2678,7 +2692,13 @@ struct ggml_tensor * ggml_geglu(
 struct ggml_tensor * ggml_swiglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
 }
 
 // ggml_norm
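A small sketch (an assumed helper, not part of the diff) of how a backend can recover both parameters from a GGML_OP_GLU node; it mirrors the CPU path's ggml_get_op_params_i32 calls and the CUDA path's direct read of dst->op_params.

    static void glu_read_params(const struct ggml_tensor * dst,
                                enum ggml_glu_op * op, bool * swapped) {
        *op      = (enum ggml_glu_op) ggml_get_op_params_i32(dst, 0);  // param 0: gating op
        *swapped = ggml_get_op_params_i32(dst, 1) != 0;                // param 1: swapped flag
    }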
@@ -1110,16 +1110,18 @@ struct test_glu : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne_a;
     int v; // view (1 : non-contiguous a)
+    bool swapped;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne_a, v);
+        return VARS_TO_STR4(type, ne_a, v, swapped);
     }
 
     test_glu(ggml_glu_op op,
             ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
-            int v = 0)
-        : op(op), type(type), ne_a(ne_a), v(v) {}
+            int v = 0,
+            bool swapped = false)
+        : op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
@@ -1135,7 +1137,7 @@ struct test_glu : public test_case {
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * out = ggml_glu(ctx, a, op);
+        ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
         ggml_set_name(out, "out");
 
         return out;
@@ -4009,8 +4011,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (int v : {0, 1}) {
             for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
-                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
-                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
+                for (bool swapped : {false, true}) {
+                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
+                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
+                }
             }
         }
     }