ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU)
* ggml : fix compile warning
* ggml : remove GLM rope mode (ggml-ci)
* metal : better rope implementation (ggml-ci)
* cuda : better rope implementation (ggml-ci)
* naming : n_orig_ctx -> n_ctx_orig (ggml-ci)
* dev : add reminders to update backends (ggml-ci)
* vulkan : fix ggml_rope_ext() usage
* cuda : fix array size + indents (ggml-ci)
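The change these hunks apply is that ggml_rope_ext() no longer takes a separate current-context argument: the one remaining context-size parameter is the original training context, now consistently named n_ctx_orig (the plain ggml_rope() likewise drops its trailing n_ctx argument). A minimal sketch of a call against the updated signature, with placeholder tensor names (cur, inp_pos) and illustrative hyperparameter values that are not taken from this commit:

    // neox-style RoPE (mode = 2) over the first 64 dimensions
    struct ggml_tensor * rotated = ggml_rope_ext(
            ctx, cur, inp_pos, /* freq_factors */ nullptr,
            /* n_dims      */ 64,
            /* mode        */ 2,
            /* n_ctx_orig  */ 4096,     // original training context (was n_orig_ctx)
            /* freq_base   */ 10000.0f,
            /* freq_scale  */ 1.0f,
            /* ext_factor  */ 0.0f,
            /* attn_factor */ 1.0f,
            /* beta_fast   */ 32.0f,
            /* beta_slow   */ 1.0f);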
@@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
     const std::array<int64_t, 4> ne_a;
     int n_dims;
     int mode;
-    int n_ctx;
+    int n_ctx; // used to generate positions
     float fs; // freq_scale
     float ef; // ext_factor
     float af; // attn_factor
@@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
         }

         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
@@ -1615,7 +1615,7 @@ struct llama_hparams {

    // cparams
    static constexpr uint32_t n_ctx = 512; // user-specified context size
-   static constexpr uint32_t n_orig_ctx = n_ctx;
+   static constexpr uint32_t n_ctx_orig = n_ctx;

    // batch
    int32_t n_tokens;
@@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {

            Qcur = ggml_rope_ext(
                ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
-               hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+               hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            Kcur = ggml_rope_ext(
                ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
-               hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
+               hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

@@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {

            // using mode = 2 for neox mode
            Qcur = ggml_rope_ext(
-               ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+               ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
            );

            Kcur = ggml_rope_ext(
-               ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
+               ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
            );

@@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
            for (float ef : { 0.0f, 0.7465f }) {
                for (float af : { 1.0f, 1.4245f }) {
                    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                       // TODO: ff not supported yet for !neox
-                       test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
-                       if (all) {
-                           test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
-                           test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
-                           test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
-                       }
-
                        for (bool ff : {false, true}) { // freq_factors
+                           test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+                           if (all) {
+                               test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                               test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                               test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                           }
+
                            if (all) {
                                test_cases.emplace_back(new test_rope(type, { 64,  1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                                test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
@@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                                test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                            }
                        }
+
                        all = false;
                    }
                }
@@ -1465,7 +1465,7 @@ int main(int argc, const char ** argv) {
                    continue;
                }

-               struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+               struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));

                GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@@ -1505,7 +1505,7 @@ int main(int argc, const char ** argv) {
                    continue;
                }

-               struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
+               struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));

                GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
                check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
@@ -162,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
        x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);

        // 100, 101, 102, ..., 172
-       struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
+       struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
        // -67, -67, -67, ..., -67
-       struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+       struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens

        // 33, 34, 35, ..., 105
-       struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);
+       struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);

        ggml_cgraph * gf = ggml_new_graph(ctx0);
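Correspondingly, the plain ggml_rope() callers updated in the hunks above simply lose their trailing context-size argument; a minimal sketch, assuming tensors x (activations) and pos (I32 positions) as placeholders:

    // RoPE with default frequency parameters; n_rot dims rotated in the given mode
    struct ggml_tensor * r = ggml_rope(ctx0, x, pos, n_rot, mode);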
|
Reference in New Issue
Block a user