mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-20 06:36:48 -04:00
cuBLAS: fall back to pageable memory if pinned alloc fails (#1233)
* cuBLAS: fall back to pageable memory if pinned alloc fails * cuBLAS: do not use pinned memory if env variable GGML_CUDA_NO_PINNED is set
This commit is contained in:
14
ggml-cuda.cu
14
ggml-cuda.cu
@@ -355,8 +355,18 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
|
||||
}
|
||||
|
||||
void * ggml_cuda_host_malloc(size_t size) {
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
|
||||
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void * ptr = nullptr;
|
||||
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
||||
size/1024.0/1024.0, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user