mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 04:15:21 +00:00
metal : add memory pool for temp allocs (#12850)
* metal : add memory pool for temp allocs (wip) [no ci]
* cont : free buffers from the heap
* cont : resize heap [no ci]
* cont : refactor heap [no ci]
* cont : heap for each cmd buffer [no ci]
* cont : fix free
* wip
* cont : fix alignment [no ci]
* cont : not working .. [no ci]
* cont : heap allocation now works [no ci]
* cont : use MTLHeapTypePlacement
  ggml-ci
* metal : use dynamic MTLHeap allocations
  ggml-ci
* metal : add comments
* metal : disable softmax use of mem_pool
  ggml-ci
* metal : final touches
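At its core the change leans on Metal placement heaps: the backend in ggml-metal.m reserves an MTLHeap of type MTLHeapTypePlacement, hands out MTLBuffer views at explicit offsets via newBufferWithLength:options:offset:, and toggles the heap's purgeable state so the OS may reclaim the memory between graph computes. The standalone program below is not part of the diff; it is a minimal sketch of that Apple API pattern, with all file names, buffer sizes and offsets chosen purely for illustration.

// Standalone demo of the placement-heap sub-allocation pattern used by the new
// ggml_metal_heap code; not part of the commit, all names and sizes are illustrative.
// Build (manual retain/release, same style as ggml-metal.m):
//   clang placement_heap_demo.m -framework Foundation -framework Metal -o placement_heap_demo
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

int main(void) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (device == nil) {
            return 1;
        }

        // a private placement heap: the caller decides where each buffer lives inside it
        MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
        desc.storageMode  = MTLStorageModePrivate;
        desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
        desc.type         = MTLHeapTypePlacement;
        desc.size         = 1 << 20; // 1 MiB

        id<MTLHeap> heap = [device newHeapWithDescriptor:desc];
        [desc release];
        if (heap == nil) {
            return 1;
        }

        // bump-allocate two scratch buffers at explicit, 32-byte aligned offsets
        const size_t alignment = 32;
        size_t offs = 0;

        const size_t size0 = (4096  + alignment - 1) & ~(alignment - 1);
        const size_t size1 = (16384 + alignment - 1) & ~(alignment - 1);

        id<MTLBuffer> buf0 = [heap newBufferWithLength:size0 options:MTLResourceStorageModePrivate offset:offs];
        offs += size0;

        id<MTLBuffer> buf1 = [heap newBufferWithLength:size1 options:MTLResourceStorageModePrivate offset:offs];
        offs += size1;

        NSLog(@"heap size = %lu bytes, bump offset = %lu bytes", (unsigned long) [heap size], (unsigned long) offs);

        // once the scratch data is no longer needed, let the OS reclaim the pages if it wants to
        [buf0 release];
        [buf1 release];
        [heap setPurgeableState:MTLPurgeableStateVolatile];

        [heap release];
    }

    return 0;
}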
@@ -490,7 +490,259 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
+//
+// ggml_metal_heap
+//
+
+struct ggml_metal_heap {
+    // number of times the heap was unused
+    int n_unused;
+
+    // total number of buffer allocations in this heap across all computes
+    int64_t n_alloc;
+
+    // current offset in the heap - we reset this after each node in order to reuse the memory
+    size_t offs;
+
+    // the currently allocated MTLBuffer objects in this heap
+    id<MTLHeap> obj;
+
+    NSMutableArray * bufs;
+};
+
+static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
+
+    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+    desc.storageMode  = MTLStorageModePrivate;
+    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+    desc.type         = MTLHeapTypePlacement;
+    desc.size         = size;
+
+    heap->n_unused = 0;
+    heap->n_alloc  = 0;
+
+    heap->obj = [device newHeapWithDescriptor:desc];
+    if (!heap->obj) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+        free(heap);
+
+        return false;
+    }
+
+    [desc release];
+
+    heap->bufs = [[NSMutableArray alloc] init];
+
+    return heap;
+}
+
+static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
+    heap->offs = 0;
+
+    // count how many graph computes the heap ended up being unused
+    if ([heap->bufs count] > 0) {
+        heap->n_unused = 0;
+    } else {
+        heap->n_unused++;
+    }
+
+    for (id<MTLBuffer> buf in heap->bufs) {
+        [buf release];
+    }
+    [heap->bufs removeAllObjects];
+
+    // tell the OS that it can reuse this memory if needed
+    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+}
+
+static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
+    if (heap == nil) {
+        return;
+    }
+
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj  release];
+    [heap->bufs release];
+
+    free(heap);
+}
+
+@interface ggml_metal_heap_ptr : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_heap * data;
+
+@end
+
+@implementation ggml_metal_heap_ptr
+@end
+
+//
+// ggml_metal_mem_pool
+//
+
+struct ggml_metal_mem_pool {
+    id<MTLDevice> device;
+
+    int n_heaps; // total number of heaps ever created (including those that were removed)
+
+    NSMutableArray * heaps;
+    NSMutableArray * heaps_to_remove;
+};
+
+static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
+    struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
+
+    mem_pool->n_heaps = 0;
+
+    mem_pool->heaps           = [[NSMutableArray alloc] init];
+    mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+    return mem_pool;
+}
+
+static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
+    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+    size_t size_all = 0;
+    size_t size_cur = 0;
+
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
+        GGML_LOG_DEBUG("%s: n_alloc:  %" PRId64 "\n", __func__, ptr.data->n_alloc);
+        GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
+        GGML_LOG_DEBUG("%s: size:     %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s: bufs:     %zu\n", __func__, [ptr.data->bufs count]);
+
+        if ([ptr.data->bufs count] > 0) {
+            size_cur += [ptr.data->obj size];
+        }
+        size_all += [ptr.data->obj size];
+
+        ggml_metal_heap_free(ptr.data);
+        [ptr release];
+    }
+    [mem_pool->heaps release];
+    [mem_pool->heaps_to_remove release];
+
+    if (size_all > 0) {
+        GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+    }
+
+    free(mem_pool);
+}
+
+static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
+    for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+        ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+        struct ggml_metal_heap * heap = ptr.data;
+        ggml_metal_heap_reset(heap);
+
+        // if the heap hasn't been used for a while, remove it
+        if (heap->n_unused >= 128) {
+            [mem_pool->heaps_to_remove addObject:@(i)];
+        }
+    }
+
+    if (mem_pool->heaps_to_remove.count > 0) {
+        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+            NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+            ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+            struct ggml_metal_heap * heap = ptr.data;
+            ggml_metal_heap_free(heap);
+
+            [mem_pool->heaps removeObjectAtIndex:index];
+            [ptr release];
+        }
+
+        [mem_pool->heaps_to_remove removeAllObjects];
+    }
+}
+
+static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        ptr.data->offs = 0;
+    }
+}
+
+static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
+    const size_t alignment = 32;
+
+    const size_t size_aligned = GGML_PAD(size, alignment);
+
+    // try one of the existing heaps
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        struct ggml_metal_heap * heap = ptr.data;
+        if (heap->offs + size_aligned <= [heap->obj size]) {
+            // if this is the first buffer in the heap for the current command buffer, tell the OS that
+            // it cannot free the memory used by the heap
+            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+            if ([heap->bufs count] == 0) {
+                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+            }
+
+            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+            if (buf == nil) {
+                GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+                return nil;
+            }
+
+            heap->n_alloc++;
+            heap->offs += size_aligned;
+
+            [heap->bufs addObject:buf];
+
+            return buf;
+        }
+    }
+
+    // create a new heap that can fit this buffer
+    ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
+
+    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
+    if (heap == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+    heap_ptr.data = heap;
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+    if (buf == nil) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    heap->n_alloc++;
+    heap->offs += size_aligned;
+
+    [heap->bufs addObject:buf];
+
+    [mem_pool->heaps addObject:heap_ptr];
+    mem_pool->n_heaps++;
+
+    return buf;
+}
+
+struct ggml_metal_command_buffer {
+    id<MTLCommandBuffer> obj;
+
+    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+    struct ggml_metal_mem_pool * mem_pool;
+};
+
 struct ggml_backend_metal_context {
+    id<MTLDevice> device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
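Taken together, the helpers above implement a small bump allocator over dynamically created placement heaps, one pool per command buffer. The sketch below condenses how the remaining hunks drive that API; it is illustrative only (the ggml_metal_mem_pool_example function and its size parameter are hypothetical, not part of the commit), assuming it is placed inside ggml-metal.m where the updated struct and the static helpers are visible.

// Condensed, illustrative view of how the rest of this diff drives the helpers above.
// Assumes it sits inside ggml-metal.m after the updated ggml_backend_metal_context
// definition; "size" stands in for whatever scratch allocation a kernel needs.
static void ggml_metal_mem_pool_example(struct ggml_backend_metal_context * ctx, id<MTLDevice> device, size_t size) {
    // ggml_metal_init: one pool per command buffer slot
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
        ctx->cmd_bufs[i].mem_pool->device = device;
    }

    struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[0].mem_pool;

    // encode_async (once per command buffer, per graph compute): release the previous
    // run's buffers, mark idle heaps purgeable, prune heaps unused for 128 computes
    ggml_metal_mem_pool_reset(mem_pool);

    // ggml_metal_encode_node (once per node): rewind the heap offsets so consecutive
    // nodes reuse the same memory instead of growing the heaps
    ggml_metal_mem_pool_clear(mem_pool);

    // heap-backed scratch buffer for the node being encoded (padded to a multiple of 32 bytes)
    id<MTLBuffer> tmp = ggml_metal_mem_pool_alloc(mem_pool, size);
    if (tmp == nil) {
        // the real encoder treats this as a failure: ggml_metal_encode_node returns false
        // and the command buffer stops encoding further nodes
    }

    // ggml_metal_free: the pools are torn down together with the backend context
    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
    }
}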
@@ -515,7 +767,7 @@ struct ggml_backend_metal_context {
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -705,8 +957,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
     id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
 
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
+    ctx->device = device;
+
     ctx->queue = [device newCommandQueue];
     if (ctx->queue == nil) {
         GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
@@ -768,7 +1022,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1181,6 +1438,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     [ctx->queue release];
 
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        // ctx->cmd_bufs[i].obj is auto released
+
+        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+    }
+
     dispatch_release(ctx->d_queue);
 
     free(ctx);
@@ -1486,10 +1749,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static void ggml_metal_encode_node(
+static bool ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
-        id<MTLComputeCommandEncoder> encoder) {
+        id<MTLComputeCommandEncoder> encoder,
+        struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -1505,7 +1769,7 @@ static void ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1516,7 +1780,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -1527,6 +1791,8 @@ static void ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }
 
+    ggml_metal_mem_pool_clear(mem_pool);
+
     const int64_t ne00 = src0 ? src0->ne[0] : 0;
     const int64_t ne01 = src0 ? src0->ne[1] : 0;
     const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2173,6 +2439,56 @@ static void ggml_metal_encode_node(
             const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
             const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+            // use this branch to test the ggml_metal_mem_pool functionality
+#if 0
+            // cpy to tmp buffer in MTLHeap
+
+            id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
+            if (!h_src0) {
+                GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
+                return false;
+            }
+
+            offs_src0 = 0;
+
+            ggml_metal_kargs_cpy args_cpy = {
+                /*.ne00 =*/ ne00,
+                /*.ne01 =*/ ne01,
+                /*.ne02 =*/ ne02,
+                /*.ne03 =*/ ne03,
+                /*.nb00 =*/ nb00,
+                /*.nb01 =*/ nb01,
+                /*.nb02 =*/ nb02,
+                /*.nb03 =*/ nb03,
+                /*.ne0  =*/ ne00,
+                /*.ne1  =*/ ne01,
+                /*.ne2  =*/ ne02,
+                /*.ne3  =*/ ne03,
+                /*.nb0  =*/ nb00,
+                /*.nb1  =*/ nb01,
+                /*.nb2  =*/ nb02,
+                /*.nb3  =*/ nb03,
+            };
+
+            if (src0->type == GGML_TYPE_F16) {
+                [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+            } else {
+                [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+            }
+            [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+            [encoder setBuffer:h_src0  offset:0         atIndex:2];
+
+            GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+            int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
+
+            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+            id<MTLBuffer> h_src0 = id_src0;
+#endif
+            // softmax
+
             ggml_metal_kargs_soft_max args = {
                 /*.ne00 =*/ ne00,
                 /*.ne01 =*/ ne01,
@@ -2185,11 +2501,11 @@ static void ggml_metal_encode_node(
             };
 
             [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
             if (id_src1) {
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
             } else {
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
             }
             [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
             [encoder setBytes:&args length:sizeof(args) atIndex:3];
@@ -4601,6 +4917,8 @@ static void ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -4654,25 +4972,25 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 
     // the main thread commits the first few commands immediately
-    // command_buffer[n_cb]
+    // cmd_buf[n_cb]
     {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[n_cb] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        [command_buffer enqueue];
+        [cmd_buf enqueue];
         ctx->encode_async(n_cb);
     }
 
     // prepare the rest of the command buffers asynchronously
-    // command_buffer[0.. n_cb)
+    // cmd_buf[0.. n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[cb_idx] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
         // enqueue all of the command buffers if we don't need to abort
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
         }
     }
 
@@ -4681,14 +4999,14 @@ static enum ggml_status ggml_metal_graph_compute(
     // wait for completion and check status of each command buffer
     // needed to detect if the device ran out-of-memory for example (#1881)
     {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
             if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return GGML_STATUS_FAILED;
@@ -4696,20 +5014,20 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 
     for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return GGML_STATUS_FAILED;
         }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
         if (!next_buffer) {
             continue;
         }
@@ -5092,8 +5410,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
         int node_start = 0;
         int node_end = n_nodes_0;
@@ -5105,22 +5424,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
        for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            ggml_metal_encode_node(backend, idx, encoder);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
+
+            if (!res) {
+                break;
+            }
         }
 
         [encoder endEncoding];
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }