mirror of https://github.com/ggml-org/llama.cpp.git
commit 70e8b48e6a
parent cfa9c7a47a
committed by Akarshan

    more constraints and use 64bit ints

    ggml-ci
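The move from int to int64_t guards the index math once the flattened element count k approaches INT_MAX (~2.1e9): with 32-bit indices the computed offsets wrap around and the kernel would address out of bounds. A minimal host-side sketch of the failure mode on typical two's-complement targets; the element count is illustrative, not from the commit:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical large tensor: ~3e9 elements exceeds INT32_MAX.
        const int64_t n_elements = 3000000000LL;

        // 32-bit arithmetic wraps: the "last index" comes out negative...
        const int     i32 = (int) (n_elements - 1);
        // ...while 64-bit arithmetic keeps the true value.
        const int64_t i64 = n_elements - 1;

        printf("int: %d, int64_t: %lld\n", i32, (long long) i64);
        return 0;
    }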
@@ -199,21 +199,21 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 /* gated ops */
 
 template <float (*op)(float), typename T>
-static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int k, const int n, const int o) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
     // perform base op on first half of row and multiply with gate in second half
-    const int j = (i / n) * o + (i % n);
+    const int64_t j = (i / n) * o + (i % n);
     dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
 }
 
 template <float (*op)(float), typename T>
-static void unary_gated_cuda(const T * x, T * dst, const int k, const int n, const int o, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
     unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
 }
 
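For context, the kernel's index math: the output is flattened to k = rows * n elements; for output element i, i / n selects the row and i % n the column, while o is the input row stride in elements (at least 2*n in this version of the kernel), so x[j] is the base value and x[j + n] its matching gate. A host-side reference of the same mapping, with SiLU chosen as an illustrative base op (the kernel itself is op-agnostic):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Reference of unary_gated_op_kernel's indexing on the host: each output
    // row of n values comes from an input row of o values whose first n
    // entries are the base half and next n entries the gate half.
    static std::vector<float> gated_ref(const std::vector<float> & x,
                                        int64_t rows, int64_t n, int64_t o) {
        std::vector<float> dst(rows * n);
        for (int64_t i = 0; i < rows * n; ++i) {
            const int64_t j = (i / n) * o + (i % n);   // base element
            const float   v = x[j];
            const float   g = x[j + n];                // gate element
            dst[i] = (v / (1.0f + std::exp(-v))) * g;  // SiLU(v) * gate (illustrative)
        }
        return dst;
    }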
@@ -222,10 +222,12 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const void * src0_d = src0->data;
     void * dst_d = dst->data;
-    const int nc = src0->ne[0] / 2;
+    const int64_t nc = src0->ne[0] / 2;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
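The two added asserts make explicit what the flat indexing relies on: src0->nb[0] == ggml_element_size(src0) pins the element stride within src0's rows to one element (row padding is still allowed and handled through the stride o), while ggml_is_contiguous(dst) is what makes the plain dst[i] store valid. A hedged sketch of the equivalent stride checks for a 2D layout; the field names mimic ggml's ne/nb convention but this is not the library's implementation:

    #include <cstdint>
    #include <cstddef>

    // 2D layout described by an element size and per-dimension byte strides.
    struct layout_2d {
        int64_t ne[2];  // elements per dimension
        size_t  nb[2];  // byte stride per dimension
        size_t  esize;  // bytes per element
    };

    // src0 side: elements inside a row are packed (nb[0] == element size),
    // but the row stride nb[1] may still include padding.
    static bool row_packed(const layout_2d & t) {
        return t.nb[0] == t.esize;
    }

    // dst side: fully contiguous, so flat indexing dst[i] is valid.
    static bool fully_contiguous(const layout_2d & t) {
        return t.nb[0] == t.esize && t.nb[1] == t.nb[0] * (size_t) t.ne[0];
    }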