ggml : update ggml_rope_multi (#12665)

* update `rope_multi`:

1. add `ggml_rope_multi_inplace`;
2. use `GGML_MROPE_SECTIONS` instead of 4.

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author:       Judd
Date:         2025-08-13 18:45:15 +08:00
Committed by: GitHub
Parent:       d8914fc47e
Commit:       c24f4e2688

2 changed files with 63 additions and 41 deletions
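For context, a minimal call-site sketch of the updated API. This is not code from the commit: the helper name, the tensor names, the `n_rot` parameter and the 16/24/24/0 section split are illustrative placeholders; the trailing RoPE defaults mirror the ones ggml_rope() passes to ggml_rope_impl().

#include "ggml.h"

// Hypothetical helper (not part of this commit): apply M-RoPE to a Q/K tensor during graph build.
// `cur` has cur->ne[2] == n_tokens; `pos` is a GGML_TYPE_I32 vector carrying
// GGML_MROPE_SECTIONS (= 4) position ids per token, which is what the updated assert in
// ggml_rope_impl checks; `sections` describes how the rotated dimensions are split across
// the positional components.
static struct ggml_tensor * apply_mrope(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,   // [n_embd_head, n_head, n_tokens, ...]
        struct ggml_tensor  * pos,   // [GGML_MROPE_SECTIONS * n_tokens], GGML_TYPE_I32
        int                   n_rot) {
    int sections[GGML_MROPE_SECTIONS] = { 16, 24, 24, 0 }; // illustrative split only

    return ggml_rope_multi(ctx, cur, pos, NULL, n_rot, sections, GGML_ROPE_TYPE_MROPE,
                           0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
}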


@@ -241,6 +241,8 @@
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24
 
+#define GGML_MROPE_SECTIONS 4
+
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -1660,7 +1662,7 @@ extern "C" {
         struct ggml_tensor * b,
         struct ggml_tensor * c,
         int n_dims,
-        int sections[4],
+        int sections[GGML_MROPE_SECTIONS],
         int mode,
         int n_ctx_orig,
         float freq_base,
@@ -1686,6 +1688,22 @@ extern "C" {
         float beta_fast,
         float beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        int n_dims,
+        int sections[GGML_MROPE_SECTIONS],
+        int mode,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow);
+
 GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
         struct ggml_context * ctx,
         struct ggml_tensor * a,


@@ -3885,6 +3885,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_tensor * b,
         struct ggml_tensor * c,
         int n_dims,
+        int sections[GGML_MROPE_SECTIONS],
         int mode,
         int n_ctx_orig,
         float freq_base,
@@ -3898,15 +3899,19 @@ static struct ggml_tensor * ggml_rope_impl(
 
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+    bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    if (mrope_used) {
+        GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+    } else {
+        GGML_ASSERT(a->ne[2] == b->ne[0]);
+    }
 
     if (c) {
         GGML_ASSERT(c->type == GGML_TYPE_F32);
         GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
-    int sections[4] = {0, 0, 0, 0};
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
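The widened assert fixes the expected layout of the position tensor `b` when M-RoPE is active: GGML_MROPE_SECTIONS position ids per token. A hedged sketch of filling such a buffer follows; the component-major ordering (all temporal ids, then height, then width, then the spare slot) is an assumption based on how the CPU rope kernel indexes positions, not something this hunk shows, and the helper name is made up.

#include <stdint.h>

// Hypothetical helper: pack 4 position components into the flat I32 buffer that `b` wraps.
// Assumed component-major layout; b->ne[0] must end up equal to 4*n_tokens,
// which is exactly what the new assert verifies.
static void fill_mrope_positions(int32_t * pos, int n_tokens,
                                 const int32_t * pos_t, const int32_t * pos_h,
                                 const int32_t * pos_w, const int32_t * pos_e) {
    for (int i = 0; i < n_tokens; ++i) {
        pos[i + 0*n_tokens] = pos_t[i]; // temporal
        pos[i + 1*n_tokens] = pos_h[i]; // height
        pos[i + 2*n_tokens] = pos_w[i]; // width
        pos[i + 3*n_tokens] = pos_e[i]; // extra/unused component
    }
}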
@@ -3916,7 +3921,11 @@ static struct ggml_tensor * ggml_rope_impl(
     memcpy(params + 8, &attn_factor, sizeof(float));
     memcpy(params + 9, &beta_fast, sizeof(float));
     memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(params + 11, &sections, sizeof(int)*4);
+    if (mrope_used) {
+        memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
+    } else {
+        memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
+    }
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_ROPE;
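For reference, a hedged sketch of the consuming side: how a kernel could read these packed op params back out of the destination tensor. The offsets mirror the packing above (ints at indices 0-4, floats at 5-10, sections at 11); the reader function itself is hypothetical.

#include <string.h>
#include "ggml.h"

// Hypothetical reader for the ROPE op params laid out by ggml_rope_impl above.
static void read_rope_op_params(const struct ggml_tensor * dst,
                                int32_t * n_dims, int32_t * mode,
                                int32_t sections[GGML_MROPE_SECTIONS]) {
    const int32_t * params = (const int32_t *) dst->op_params;

    *n_dims = params[1];
    *mode   = params[2];

    // non-mrope calls now zero-fill this range, so sections[] comes back all zeros for them
    memcpy(sections, params + 11, sizeof(int32_t) * GGML_MROPE_SECTIONS);
}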
@@ -3934,7 +3943,7 @@ struct ggml_tensor * ggml_rope(
         int n_dims,
         int mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
     );
 }
 
@@ -3944,7 +3953,7 @@ struct ggml_tensor * ggml_rope_multi(
         struct ggml_tensor * b,
         struct ggml_tensor * c,
         int n_dims,
-        int sections[4],
+        int sections[GGML_MROPE_SECTIONS],
         int mode,
         int n_ctx_orig,
         float freq_base,
@@ -3953,36 +3962,31 @@ struct ggml_tensor * ggml_rope_multi(
         float attn_factor,
         float beta_fast,
         float beta_slow) {
-    // Multimodal Rotary Position Embedding
-    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
 
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
-
-    if (c) {
-        GGML_ASSERT(c->type == GGML_TYPE_F32);
-        GGML_ASSERT(c->ne[0] >= n_dims / 2);
-    }
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params + 5, &freq_base, sizeof(float));
-    memcpy(params + 6, &freq_scale, sizeof(float));
-    memcpy(params + 7, &ext_factor, sizeof(float));
-    memcpy(params + 8, &attn_factor, sizeof(float));
-    memcpy(params + 9, &beta_fast, sizeof(float));
-    memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*4);
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = GGML_OP_ROPE;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
+struct ggml_tensor * ggml_rope_multi_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        int n_dims,
+        int sections[GGML_MROPE_SECTIONS],
+        int mode,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_inplace(
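With ggml_rope_multi reduced to a thin wrapper, the only difference between the two entry points is the inplace flag forwarded to ggml_rope_impl: the copying variant writes into a duplicate of the input (ggml_dup_tensor), the new in-place variant returns a view of it (ggml_view_tensor), mirroring the existing ggml_rope()/ggml_rope_inplace() split. A small hedged sketch, reusing the placeholder names from the sketch near the top of this page:

// Hypothetical snippet: `rotated` is a view of `cur`, so the rotation overwrites cur's
// data when the graph is computed; use ggml_rope_multi() instead to keep `cur` intact.
struct ggml_tensor * rotated = ggml_rope_multi_inplace(
        ctx, cur, pos, NULL, n_rot, sections, GGML_ROPE_TYPE_MROPE,
        0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);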
@@ -3992,7 +3996,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int n_dims,
         int mode) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
     );
 }
 
@@ -4011,7 +4015,7 @@ struct ggml_tensor * ggml_rope_ext(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
@@ -4031,7 +4035,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }
@@ -4050,7 +4054,7 @@ struct ggml_tensor * ggml_rope_custom(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, false
     );
 }
@@ -4069,7 +4073,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, true
     );
 }