CUDA: add fused rms norm (#14800)

This commit is contained in:
Aman Gupta
2025-07-23 09:25:42 +08:00
committed by GitHub
parent acd6cb1c41
commit 8c988fa41d
4 changed files with 144 additions and 9 deletions

View File

@@ -55,6 +55,7 @@
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
@@ -2765,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
}
#endif
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
//rms norm only supports F32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
}
//rms_norm kernel assumes contigous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
@@ -2774,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2781,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {