From 4d0dcd4a06080e796e6742a88f2ffa7fc41b28b8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 8 Jul 2025 10:15:21 +0300 Subject: [PATCH] cuda : fix rope with partial rotation and non-cont src (#14580) * cuda : fix rope non-cont ggml-ci * cont : fix multi-rope + add test ggml-ci * sycl : try fix ggml-ci * cont : fix sycl + clean-up cuda ggml-ci --- ggml/src/ggml-cuda/rope.cu | 48 ++++++++++++++++--------------------- ggml/src/ggml-sycl/rope.cpp | 33 ++++++++++++------------- tests/test-backend-ops.cpp | 20 +++++++++++----- 3 files changed, 50 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 18f691b2d..d058504cd 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -50,21 +50,19 @@ static __global__ void rope_norm( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0; const int ix = channel_x*s2 + row_x*s1 + i0; + if (i0 >= n_dims) { + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -94,21 +92,19 @@ static __global__ void rope_neox( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -138,21 +134,19 @@ static __global__ void rope_multi( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index e44c6b6ef..1b60226dc 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0; const int i2 = channel0 * s2 + row0 * s1 + i0; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0 / 2; const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i + i0 / 2) = *reinterpret_cast *>(x + i2 + i0 / 2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const } const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = (row_dst * ne0) + (i0 / 2); const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + idst + i0 / 2) = *reinterpret_cast *>(x + i0 / 2 + ix); + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 652856a35..b54bcc8a3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5323,12 +5323,12 @@ static std::vector> make_test_cases_eval() { for (bool fw : {true, false}) { // fw == forward bool all = true; - for (float v : { 0, 1 }) { - for (float fs : { 1.0f, 1.4245f }) { - for (float ef : { 0.0f, 0.7465f }) { - for (float af : { 1.0f, 1.4245f }) { - for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (bool ff : {false, true}) { // freq_factors + for (float fs : { 1.0f, 1.4245f }) { + for (float ef : { 0.0f, 0.7465f }) { + for (float af : { 1.0f, 1.4245f }) { + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (bool ff : {false, true}) { // freq_factors + for (float v : { 0, 1 }) { test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { @@ -5341,13 +5341,21 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) }