diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index a15f62361..c663b53f9 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -519,6 +519,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
 
+        GGML_OP_GLU,
+
         GGML_OP_COUNT,
     };
 
@@ -538,13 +540,18 @@ extern "C" {
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
-        GGML_UNARY_OP_REGLU,
-        GGML_UNARY_OP_GEGLU,
-        GGML_UNARY_OP_SWIGLU,
 
         GGML_UNARY_OP_COUNT,
     };
 
+    enum ggml_glu_op {
+        GGML_GLU_OP_REGLU,
+        GGML_GLU_OP_GEGLU,
+        GGML_GLU_OP_SWIGLU,
+
+        GGML_GLU_OP_COUNT,
+    };
+
     enum ggml_object_type {
         GGML_OBJECT_TYPE_TENSOR,
         GGML_OBJECT_TYPE_GRAPH,
@@ -660,6 +667,7 @@ extern "C" {
     GGML_API const char * ggml_op_symbol(enum ggml_op op);
 
     GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
     GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
@@ -761,6 +769,7 @@ extern "C" {
     GGML_API void    ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
     GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+    GGML_API enum ggml_glu_op   ggml_get_glu_op(const struct ggml_tensor * tensor);
 
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1089,6 +1098,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // gated linear unit ops
+    // A: n columns, r rows,
+    // result is n / 2 columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_glu_op      op);
+
     GGML_API struct ggml_tensor * ggml_reglu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2eaa5b581..a985b6634 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1941,6 +1941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2144,9 +2148,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_REGLU:
-                case GGML_UNARY_OP_GEGLU:
-                case GGML_UNARY_OP_SWIGLU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
                     {
                         n_tasks = n_threads;
                     } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 9131507bb..5ce11915d 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8308,15 +8308,31 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
-        case GGML_UNARY_OP_REGLU:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+//ggml_compute_forward_glu
+
+void ggml_compute_forward_glu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_glu_op op = ggml_get_glu_op(dst);
+
+    switch (op) {
+        case GGML_GLU_OP_REGLU:
             {
                 ggml_compute_forward_reglu(params, dst);
             } break;
-        case GGML_UNARY_OP_GEGLU:
+        case GGML_GLU_OP_GEGLU:
             {
                 ggml_compute_forward_geglu(params, dst);
             } break;
-        case GGML_UNARY_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU:
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 2d8544d7d..f11a808f6 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -93,6 +93,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
 void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 885c56492..aecf5d2ab 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2246,13 +2246,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_EXP:
                     ggml_cuda_op_exp(ctx, dst);
                     break;
-                case GGML_UNARY_OP_REGLU:
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
                     ggml_cuda_op_reglu(ctx, dst);
                     break;
-                case GGML_UNARY_OP_GEGLU:
+                case GGML_GLU_OP_GEGLU:
                     ggml_cuda_op_geglu(ctx, dst);
                     break;
-                case GGML_UNARY_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
                 default:
@@ -3048,9 +3054,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
-                case GGML_UNARY_OP_REGLU:
-                case GGML_UNARY_OP_GEGLU:
-                case GGML_UNARY_OP_SWIGLU:
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c57bad32f..c34f07217 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -984,6 +984,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
 static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1080,6 +1082,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
 static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1103,12 +1107,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
     "EXP",
     "GELU_ERF",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
+
+
+static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18");
+static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
GGML_MEM_ALIGN"); @@ -1213,11 +1223,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) { return GGML_UNARY_OP_NAME[op]; } +const char * ggml_glu_op_name(enum ggml_glu_op op) { + return GGML_GLU_OP_NAME[op]; +} + const char * ggml_op_desc(const struct ggml_tensor * t) { if (t->op == GGML_OP_UNARY) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); } + if (t->op == GGML_OP_GLU) { + enum ggml_glu_op gop = ggml_get_glu_op(t); + return ggml_glu_op_name(gop); + } return ggml_op_name(t->op); } @@ -1736,6 +1754,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); } +enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_GLU); + return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0); +} + const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } @@ -2615,40 +2638,39 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } -// ggml_reglu +// ggml_glu -struct ggml_tensor * ggml_reglu( +struct ggml_tensor * ggml_glu( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a, + enum ggml_glu_op op) { GGML_ASSERT(ggml_is_contiguous_1(a)); int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU); + ggml_set_op_params_i32(result, 0, (int32_t) op); - result->op = GGML_OP_UNARY; + result->op = GGML_OP_GLU; result->src[0] = a; return result; } +// ggml_reglu + +struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu(ctx, a, GGML_GLU_OP_REGLU); +} + // ggml_geglu struct ggml_tensor * ggml_geglu( struct ggml_context * ctx, struct ggml_tensor * a) { - GGML_ASSERT(ggml_is_contiguous_1(a)); - - int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU); - - result->op = GGML_OP_UNARY; - result->src[0] = a; - - return result; + return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU); } // ggml_swiglu @@ -2656,17 +2678,7 @@ struct ggml_tensor * ggml_geglu( struct ggml_tensor * ggml_swiglu( struct ggml_context * ctx, struct ggml_tensor * a) { - GGML_ASSERT(ggml_is_contiguous_1(a)); - - int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0); - - ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU); - - result->op = GGML_OP_UNARY; - result->src[0] = a; - - return result; + return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU); } // ggml_norm diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a19437233..bdfa0d3e5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1072,16 +1072,7 @@ struct test_unary : public test_case { ggml_set_name(a, "a"); } - ggml_tensor * out; - if (op == GGML_UNARY_OP_REGLU) { - out = ggml_reglu(ctx, a); - } else if (op == GGML_UNARY_OP_GEGLU) { - out = ggml_geglu(ctx, a); - } else if (op == GGML_UNARY_OP_SWIGLU) { - out = ggml_swiglu(ctx, a); - } else { - out = ggml_unary(ctx, a, 
op); - } + ggml_tensor * out = ggml_unary(ctx, a, op); ggml_set_name(out, "out"); return out; @@ -1113,6 +1104,51 @@ struct test_unary : public test_case { }; +// GGML_OP_GLU +struct test_glu : public test_case { + const ggml_glu_op op; + const ggml_type type; + const std::array ne_a; + int v; // view (1 : non-contiguous a) + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, v); + } + + test_glu(ggml_glu_op op, + ggml_type type = GGML_TYPE_F32, + std::array ne_a = {128, 2, 2, 2}, + int v = 0) + : op(op), type(type), ne_a(ne_a), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a; + if (v & 1) { + auto ne = ne_a; ne[0] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view_of_a"); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_name(a, "a"); + } + + ggml_tensor * out = ggml_glu(ctx, a, op); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // test extended range of values to check for NaNs in GELU + init_tensor_uniform(t, -150.f, 150.f); + } + } +}; + // GGML_OP_GET_ROWS struct test_get_rows : public test_case { const ggml_type type; @@ -3969,6 +4005,16 @@ static std::vector> make_test_cases_eval() { } } + // glu ops + for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { + for (int v : {0, 1}) { + for (int op = 0; op < GGML_GLU_OP_COUNT; op++) { + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v)); + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v)); + } + } + } + test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); for (ggml_type type : all_types) { for (int b : {1, 7}) {
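
Usage sketch (not part of the patch): a minimal example of how a caller is expected to drive the new GGML_OP_GLU API declared in ggml.h above. Only ggml_glu, ggml_swiglu, ggml_get_glu_op, ggml_glu_op_name and GGML_GLU_OP_SWIGLU come from this diff; the helper name build_swiglu and the tensor name gate_up are illustrative assumptions.

// Illustrative sketch only -- not part of the patch; build_swiglu and gate_up are hypothetical names.
#include "ggml.h"

// gate_up: a contiguous activation tensor with n columns; per the header comment,
// the GLU result keeps n / 2 columns and the same number of rows.
static struct ggml_tensor * build_swiglu(struct ggml_context * ctx, struct ggml_tensor * gate_up) {
    struct ggml_tensor * cur = ggml_glu(ctx, gate_up, GGML_GLU_OP_SWIGLU);

    // equivalent convenience wrapper kept by the patch:
    //     cur = ggml_swiglu(ctx, gate_up);

    // backends and debug code can recover the variant from the node itself:
    //     enum ggml_glu_op op = ggml_get_glu_op(cur);   // GGML_GLU_OP_SWIGLU
    //     const char * name   = ggml_glu_op_name(op);   // "SWIGLU"
    return cur;
}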