diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f8238f315..ec5478db8 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -314,6 +314,13 @@ extern "C" { #endif + // Function type used in fatal error callbacks + typedef void (*ggml_abort_callback_t)(const char * error_message); + + // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) + // Returns the old callback for chaining + GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback); + GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4) GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c51cb57cc..4227fb101 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -202,19 +202,34 @@ void ggml_print_backtrace(void) { } #endif +static ggml_abort_callback_t g_abort_callback = NULL; + +// Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) +GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) { + ggml_abort_callback_t ret_val = g_abort_callback; + g_abort_callback = callback; + return ret_val; +} + void ggml_abort(const char * file, int line, const char * fmt, ...) { fflush(stdout); - fprintf(stderr, "%s:%d: ", file, line); + char message[2048]; + int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line); va_list args; va_start(args, fmt); - vfprintf(stderr, fmt, args); + vsnprintf(message + offset, sizeof(message) - offset, fmt, args); va_end(args); - fprintf(stderr, "\n"); + if (g_abort_callback) { + g_abort_callback(message); + } else { + // default: print error and backtrace to stderr + fprintf(stderr, "%s\n", message); + ggml_print_backtrace(); + } - ggml_print_backtrace(); abort(); }