diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7e74bbf33..9131507bb 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3211,8 +3211,8 @@ static void ggml_compute_forward_reglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3252,8 +3252,8 @@ static void ggml_compute_forward_reglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3318,8 +3318,8 @@ static void ggml_compute_forward_geglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3359,8 +3359,8 @@ static void ggml_compute_forward_geglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3425,8 +3425,8 @@ static void ggml_compute_forward_swiglu_f32( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -3466,8 +3466,8 @@ static void ggml_compute_forward_swiglu_f16( const int nc = src0->ne[0] / 2; const int nr = ggml_nrows(src0); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= nr); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); // rows per thread const int dr = (nr + nth - 1)/nth; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c98564a31..77ef81545 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -230,8 +230,8 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(dst->ne[0] >= nc); - GGML_ASSERT(ggml_nrows(dst) >= ggml_nrows(src0)); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); if (src0->type == GGML_TYPE_F16) { unary_gated_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);