vulkan: Add fusion support for RMS_NORM+MUL (#14366)

* vulkan: Add fusion support for RMS_NORM+MUL

- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow
for computing the whole graph and just testing one node's results. Add rms_norm_mul tests
and enable a llama test.

* extract some common fusion logic

* fix -Winconsistent-missing-override

* move ggml_can_fuse to a common function

* build fix

* C and C++ versions of can_fuse

* move use count to the graph to avoid data races and double increments when used in multiple threads

* use hash table lookup to find node index

* change use_counts to be indexed by hash table slot

* minimize hash lookups

style fixes

* last node doesn't need single use.
fix type.
handle mul operands being swapped.

* remove redundant parameter

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Jeff Bolz
2025-06-29 02:43:36 -05:00
committed by GitHub
parent 27208bf657
commit bd9c981d72
8 changed files with 263 additions and 56 deletions

View File

@@ -301,6 +301,7 @@ struct ggml_cgraph {
struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
struct ggml_tensor ** grad_accs; // accumulators for node gradients
struct ggml_tensor ** leafs; // tensors with constant data
int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot
struct ggml_hash_set visited_hash_set;
@@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
// return true if the node's results are only used by N other nodes
// and can be fused into their calculations.
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
// check the use count against how many we're replacing
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
return false;
}
// if node is a view, some other node might be using the intermediate result
// via the view source.
if (node->view_src) {
return false;
}
// If the user requested output for the node, can't fuse
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
return false;
}
return true;
}
// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
// and are fusable. Nodes are considered fusable according to this function if:
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
// - all nodes except the last are a src of the following node.
// - all nodes are the same shape.
// TODO: Consider allowing GGML_OP_NONE nodes in between
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
if (node_idx + num_ops > cgraph->n_nodes) {
return false;
}
for (int i = 0; i < num_ops; ++i) {
struct ggml_tensor * node = cgraph->nodes[node_idx + i];
if (node->op != ops[i]) {
return false;
}
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
return false;
}
if (i > 0) {
struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
if (node->src[0] != prev && node->src[1] != prev) {
return false;
}
if (!ggml_are_same_shape(node, prev)) {
return false;
}
}
}
return true;
}
#ifdef __cplusplus
}
#endif
#ifdef __cplusplus
#include <initializer_list>
#include <vector>
// nicer C++ syntax for ggml_can_fuse
inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);