diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 9c4e24023..a15f62361 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -538,6 +538,9 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_REGLU, + GGML_UNARY_OP_GEGLU, + GGML_UNARY_OP_SWIGLU, GGML_UNARY_OP_COUNT, }; @@ -1086,6 +1089,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 1d3cd009a..2eaa5b581 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2144,6 +2144,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_REGLU: + case GGML_UNARY_OP_GEGLU: + case GGML_UNARY_OP_SWIGLU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index eff4a53e3..be4353955 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3194,6 +3194,327 @@ void ggml_compute_forward_silu_back( } } +// ggml_compute_forward_reglu + +static void ggml_compute_forward_reglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_reglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_reglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_reglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_reglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_reglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_reglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu + +static void ggml_compute_forward_geglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_geglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_geglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_swiglu + +static void ggml_compute_forward_swiglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_swiglu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(src0->ne[0] / 2 == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_swiglu_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_swiglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -7987,6 +8308,18 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_REGLU: + { + ggml_compute_forward_reglu(params, dst); + } break; + case GGML_UNARY_OP_GEGLU: + { + ggml_compute_forward_geglu(params, dst); + } break; + case GGML_UNARY_OP_SWIGLU: + { + ggml_compute_forward_swiglu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index f7614568e..bfb2d5d36 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * x[i + n]; + } +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 09dbade21..48d13a60d 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } +inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f; + } +} + +inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f); + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i] * x[i + n]; + } else { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n]; + } + } +} +#else +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]) * x[i + n]; + } +} +#endif + +inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float g = GGML_FP16_TO_FP32(x[i + n]); + y[i] = ggml_table_gelu_f16[i16[i]] * g; + } +} + +void ggml_vec_swiglu_f32(const int n, float * y, const float * x); + +inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float v = GGML_FP16_TO_FP32(x[i]); + float g = GGML_FP16_TO_FP32(x[i + n]); + y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f8e7c595b..ae538180b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1103,9 +1103,12 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "REGLU", + "GEGLU", + "SWIGLU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); +static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -2612,6 +2615,57 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_reglu + +struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + +// ggml_geglu + +struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + +// ggml_swiglu + +struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0] / 2, a->ne[1]); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 48589a50a..0cd91bafd 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -582,32 +582,19 @@ ggml_tensor * llm_graph_context::build_ffn( } break; case LLM_FFN_SWIGLU: { - // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_silu(ctx0, x0); - cb(cur, "ffn_silu", il); - - cur = ggml_mul(ctx0, x0, x1); - cb(cur, "ffn_mul", il); + cur = ggml_swiglu(ctx0, cur); + cb(cur, "ffn_swiglu", il); } break; case LLM_FFN_GEGLU: { - // Split into two equal parts - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_gelu(ctx0, x0); - cb(x0, "ffn_gelu", il); - - cur = ggml_mul(ctx0, x0, x1); + cur = ggml_geglu(ctx0, cur); cb(cur, "ffn_geglu", il); } break; + case LLM_FFN_REGLU: + { + cur = ggml_reglu(ctx0, cur); + cb(cur, "ffn_reglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { diff --git a/src/llama-graph.h b/src/llama-graph.h index b433f266d..20f67fc41 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -38,6 +38,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, LLM_FFN_GEGLU, + LLM_FFN_REGLU, }; enum llm_ffn_gate_type {