ggml : add more generic custom op, remove deprecated custom ops (ggml/1183)

* ggml : add more generic ggml_custom op

* ggml : remove deprecated custom ops
This commit is contained in:
Diego Devesa
2025-04-09 12:31:34 +02:00
committed by Georgi Gerganov
parent e4bf72d631
commit 459895c326
6 changed files with 132 additions and 485 deletions

View File

@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_MAP_UNARY:
{
ggml_unary_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_compute_forward_map_unary(params, tensor, fun);
}
break;
case GGML_OP_MAP_BINARY:
{
ggml_binary_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_compute_forward_map_binary(params, tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM1_F32:
{
ggml_custom1_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_compute_forward_map_custom1_f32(params, tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM2_F32:
{
ggml_custom2_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_compute_forward_map_custom2_f32(params, tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM3_F32:
{
ggml_custom3_op_f32_t fun;
memcpy(&fun, tensor->op_params, sizeof(fun));
ggml_compute_forward_map_custom3_f32(params, tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_map_custom3(params, tensor);
}
break;
case GGML_OP_CUSTOM:
{
ggml_compute_forward_custom(params, tensor);
}
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
ggml_compute_forward_cross_entropy_loss(params, tensor);
@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32:
case GGML_OP_MAP_CUSTOM2_F32:
case GGML_OP_MAP_CUSTOM3_F32:
{
n_tasks = 1;
} break;
@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = MIN(p.n_tasks, n_threads);
}
} break;
case GGML_OP_CUSTOM:
{
struct ggml_custom_op_params p;
memcpy(&p, node->op_params, sizeof(p));
if (p.n_tasks == GGML_N_TASKS_MAX) {
n_tasks = n_threads;
} else {
n_tasks = MIN(p.n_tasks, n_threads);
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW: