opt : fix n_outputs

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-04-25 11:45:21 +03:00
parent 4e73b81a67
commit cee751c450

View File

@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
//}
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
n_outputs = ubatch.n_tokens;
printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
GGML_ABORT("TODO: handle this error");
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
};
ctx_compute_opt = ggml_init(params);
}
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
//llama_set_inputs(*lctx, ubatch);
res->set_inputs(&ubatch);