llama: Add support for RWKV v7 architecture (#12412)

* ggml: Add op l2_norm

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: Add support for RWKV7 and ARWKV7 models

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix inference with RWKV6Qwen2

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: add more (a)rwkv7 variants in size

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code-format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* fix MUSA build

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Author: Molly Sophia <mollysophia379@gmail.com>
Committed: 2025-03-18 07:27:50 +08:00 (by GitHub)
Parent: 60c902926c
Commit: 7dfad387e3
35 changed files with 2948 additions and 438 deletions
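For context before the diff: the two ops this change introduces, GGML_OP_L2_NORM and GGML_OP_RWKV_WKV7, are ordinary ggml graph operations, and the result-check path at the end of the diff shows their call shapes. The sketch below illustrates how they might be built into a graph; the tensor names/roles (r, w, k, v, a, b, state) and the eps value are assumptions for illustration, not lifted from the model code.

// Hedged sketch (C, ggml API): constructing the two new ops in a graph.
// Argument counts match the ggml_vk_check_results_0 calls in this diff;
// tensor roles and the eps value are illustrative assumptions.
#include "ggml.h"

static struct ggml_tensor * rwkv7_timemix_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,                        // activations to normalize
        struct ggml_tensor  * r, struct ggml_tensor * w,
        struct ggml_tensor  * k, struct ggml_tensor * v,
        struct ggml_tensor  * a, struct ggml_tensor * b,
        struct ggml_tensor  * state) {
    // GGML_OP_L2_NORM: L2-normalize the rows of x with a small eps (value assumed)
    struct ggml_tensor * xn = ggml_l2_norm(ctx, x, 1e-6f);
    (void) xn; // not wired further in this sketch

    // GGML_OP_RWKV_WKV7: fused RWKV v7 time-mix kernel over seven inputs
    return ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);
}

On the Vulkan side, both wkv6 and wkv7 are routed through the shared ggml_vk_op_f32_wkv helper shown in the diff, which binds six or seven source buffers depending on the version.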

@@ -304,6 +304,7 @@ struct vk_device_struct {
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
vk_pipeline pipeline_gelu_f32;
vk_pipeline pipeline_gelu_quick_f32;
vk_pipeline pipeline_silu_f32;
@@ -328,6 +329,7 @@ struct vk_device_struct {
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -629,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
uint32_t H;
};
struct vk_op_rwkv_wkv7_push_constants {
uint32_t B;
uint32_t T;
uint32_t C;
uint32_t H;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2263,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2374,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
for (auto &c : compiles) {
@@ -5473,6 +5485,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rms_norm_back_f32;
}
return nullptr;
case GGML_OP_L2_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_l2_norm_f32;
}
return nullptr;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_SILU:
@@ -5612,6 +5629,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv6_f32;
}
return nullptr;
case GGML_OP_RWKV_WKV7:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5859,6 +5881,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUM_ROWS:
@@ -6108,23 +6131,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * k = dst->src[0];
const ggml_tensor * v = dst->src[1];
const ggml_tensor * r = dst->src[2];
const ggml_tensor * tf = dst->src[3];
const ggml_tensor * td = dst->src[4];
const ggml_tensor * state = dst->src[5];
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
for (int i = 0; i < num_srcs; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(!ggml_is_quantized(k->type));
GGML_ASSERT(!ggml_is_quantized(v->type));
GGML_ASSERT(!ggml_is_quantized(r->type));
GGML_ASSERT(!ggml_is_quantized(tf->type));
GGML_ASSERT(!ggml_is_quantized(td->type));
GGML_ASSERT(!ggml_is_quantized(state->type));
GGML_ASSERT(dst->buffer != nullptr);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
@@ -6133,89 +6150,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
for (int i = 0; i < num_srcs; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
ggml_vk_sync_buffers(subctx);
vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
for (int i = 0; i < num_srcs; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
dst_uma = d_D != nullptr;
}
if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_offset = vk_tensor_offset(k) + k->view_offs;
uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
for (int i = 0; i < num_srcs; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer;
r_offset = vk_tensor_offset(r) + r->view_offs;
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
const uint64_t dst_size = ggml_nbytes(dst);
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
const uint64_t r_size = ggml_nbytes(r);
const uint64_t tf_size = ggml_nbytes(tf);
const uint64_t td_size = ggml_nbytes(td);
const uint64_t state_size = ggml_nbytes(state);
const uint64_t dst_size = ggml_nbytes(dst);
std::array<uint32_t, 3> elements = {
(uint32_t)(pc.B * pc.H),
1,
1
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
vk_subbuffer{ d_R, r_offset, r_size },
vk_subbuffer{ d_TF, tf_offset, tf_size },
vk_subbuffer{ d_TD, td_offset, td_size },
vk_subbuffer{ d_State, state_offset, state_size },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
}
}
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6224,7 +6225,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];
ggml_vk_op_f32_rwkv6(
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
@@ -6232,6 +6233,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
6,
dryrun
);
}
static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[6]->ne[1];
ggml_vk_op_f32_wkv(
ctx, subctx, dst,
{
(uint32_t)n_seqs,
(uint32_t)seq_length,
(uint32_t)n_embed,
(uint32_t)n_heads,
},
7,
dryrun
);
}
@@ -6533,6 +6554,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
@@ -7528,6 +7554,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
@@ -7544,6 +7571,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@@ -7590,6 +7618,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_UNARY:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -7707,6 +7736,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_L2_NORM:
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(node)) {
@@ -7797,6 +7830,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
break;
case GGML_OP_RWKV_WKV7:
ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@@ -7870,6 +7908,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
@@ -7889,6 +7928,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@@ -8806,6 +8846,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_SUB:
@@ -8835,6 +8876,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_POOL_2D:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
return true;
@@ -9219,6 +9261,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
} else if (tensor->op == GGML_OP_SILU_BACK) {
tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_L2_NORM) {
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9338,6 +9383,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],