ggml : move AMX to the CPU backend (#10570)

* ggml : move AMX to the CPU backend

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Diego Devesa
2024-11-29 21:54:58 +01:00
committed by GitHub
parent b782e5c7d4
commit 7cc2d2c889
64 changed files with 514 additions and 801 deletions

View File

@ -3,6 +3,7 @@
#include "ggml-cpu.h"
#include "ggml-cpu-aarch64.h"
#include "ggml-impl.h"
#include "amx/amx.h"
#include <cctype>
#include <string>
#include <vector>
@ -134,12 +135,16 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backen
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts;
#ifdef GGML_USE_CPU_HBM
bufts.push_back(ggml_backend_cpu_hbm_buffer_type());
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
if (ggml_backend_amx_buffer_type()) {
bufts.push_back(ggml_backend_amx_buffer_type());
}
#endif
#ifdef GGML_USE_CPU_AARCH64
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
if (ggml_backend_cpu_aarch64_buffer_type()) {
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
}
#endif
bufts.push_back(NULL);
@ -456,12 +461,27 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
if (op->op == GGML_OP_NONE || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_VIEW || op->op == GGML_OP_PERMUTE || op->op == GGML_OP_TRANSPOSE) {
return true;
}
if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
if (op->op != GGML_OP_MUL_MAT || src0->type == ggml_aarch64_get_optimal_repack_type(src0)) {
return false;
}
}
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
if (src0 && src0->buffer && ggml_backend_amx_buft_is_amx(src0->buffer->buft)) {
return ggml_backend_amx_device_supports_op(op);
}
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_amx_buft_is_amx(op->src[i]->buffer->buft)) {
return false;
}
}
#endif
for (int i = 1; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
return false;
@ -491,7 +511,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
}
static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
bool supported = ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
supported = supported || ggml_backend_amx_buft_is_amx(buft);
#endif
return supported;
GGML_UNUSED(dev);
}