diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 55390d42e..27c9ab74b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -663,22 +663,14 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Split into two equal parts
                 int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
-                    ctx0, cur, split_point,
-                    cur->ne[1], cur->nb[1], 0
-                ));
-                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
-                    ctx0, cur, split_point,
-                    cur->ne[1], cur->nb[1],
-                    split_point * ggml_element_size(cur)
-                ));
+                // TODO: these conts should not be needed
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
 
-                // Apply GELU activation function to the first part
-                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
-                cb(output_ffn_up, "ffn_gelu", il);
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
 
-                // Element-wise multiplication between the activated part and the gate part
-                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_geglu", il);
             } break;
     }
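
Not part of the patch: the block below is a minimal plain-C++ sketch of what this GEGLU branch computes, as a reading aid. The gelu/geglu helper names and the main driver are hypothetical; the GELU here is the tanh approximation (the variant ggml_gelu implements, to my knowledge); and the input is taken as a row-major contiguous rows x (2*half) matrix, matching ggml's convention that ne[0] is the innermost dimension. The two ggml_view_2d calls in the patch select the first and second half of each row; those views are strided, which is presumably why the ggml_cont copies are there, though per the TODO they may not be strictly necessary.

// Sketch only, not code from this patch. Helper names are hypothetical.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// tanh-approximation GELU (the variant ggml_gelu implements).
static float gelu(float x) {
    const float sqrt_2_over_pi = 0.7978845608f;
    return 0.5f * x * (1.0f + std::tanh(sqrt_2_over_pi * (x + 0.044715f * x * x * x)));
}

// GEGLU over a rows x (2*half) row-major matrix, mirroring the graph code:
// x0 = first half of each row (activated), x1 = second half (the gate),
// out = gelu(x0) * x1, element-wise.
static std::vector<float> geglu(const std::vector<float> & in, int64_t rows, int64_t half) {
    std::vector<float> out(rows * half);
    for (int64_t r = 0; r < rows; ++r) {
        for (int64_t c = 0; c < half; ++c) {
            const float x0 = in[r * 2 * half + c];        // view at byte offset 0
            const float x1 = in[r * 2 * half + half + c]; // view at offset split_point
            out[r * half + c] = gelu(x0) * x1;
        }
    }
    return out;
}

int main() {
    // one row of width 4 -> split_point = 2: x0 = {1, 2}, x1 = {3, 4}
    const std::vector<float> in = {1.0f, 2.0f, 3.0f, 4.0f};
    const std::vector<float> out = geglu(in, /*rows=*/1, /*half=*/2);
    std::printf("%f %f\n", out[0], out[1]); // gelu(1)*3, gelu(2)*4
    return 0;
}

Note that the patch itself does not change numerics: both the old and new versions build the same gelu(x0) * x1 graph, and only the variable names and view construction are condensed.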