llama : reorganize source code + improve CMake (#8006)

* scripts : update sync [no ci]

* files : relocate [no ci]

* ci : disable kompute build [no ci]

* cmake : fixes [no ci]

* server : fix mingw build

ggml-ci

* cmake : minor [no ci]

* cmake : link math library [no ci]

* cmake : build normal ggml library (not object library) [no ci]

* cmake : fix kompute build

ggml-ci

* make,cmake : fix LLAMA_CUDA + replace GGML_CDEF_PRIVATE

ggml-ci

* move public backend headers to the public include directory (#8122)

* move public backend headers to the public include directory

* nix test

* spm : fix metal header

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* scripts : fix sync paths [no ci]

* scripts : sync ggml-blas.h [no ci]

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov
2024-06-26 18:33:02 +03:00
committed by GitHub
parent 8854044561
commit f3f65429c4
345 changed files with 2555 additions and 1937 deletions

View File

@@ -0,0 +1,104 @@
#include "argsort.cuh"
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols_pad) {
return;
}
const float * x_row = x + row * ncols;
extern __shared__ int dst_row[];
// initialize indices
dst_row[col] = col;
__syncthreads();
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ASSERT(false);
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}