opt : fix n_outputs
ggml-ci
@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
             //}
 
             llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
+            n_outputs = ubatch.n_tokens;
+
+            printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
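This hunk sets n_outputs to the full micro-batch size before the graph is built. A minimal standalone sketch of the idea (not llama.cpp code), under the assumption that a training step needs logits for every token of the ubatch while plain decoding only produces logits for tokens flagged as outputs; count_outputs is a hypothetical helper used purely for illustration:

// illustration only, not llama.cpp code
#include <cstdint>
#include <vector>

// How many tokens of a micro-batch produce logits. For training every token
// feeds the loss, so the count is the ubatch size -- which is what the hunk
// above assigns to n_outputs. For plain decoding only tokens whose output
// flag is set produce logits (assumption about the decode path).
static uint32_t count_outputs(const std::vector<int8_t> & output_flags, bool training) {
    if (training) {
        return (uint32_t) output_flags.size();
    }
    uint32_t n = 0;
    for (int8_t f : output_flags) {
        n += f != 0;
    }
    return n;
}

int main() {
    const std::vector<int8_t> flags = {0, 0, 0, 1}; // only the last token flagged
    return count_outputs(flags, /*training =*/ true) == 4 &&
           count_outputs(flags, /*training =*/ false) == 1 ? 0 : 1;
}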
@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
                 };
                 ctx_compute_opt = ggml_init(params);
             }
-            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
             //llama_set_inputs(*lctx, ubatch);
             res->set_inputs(&ubatch);
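The second hunk changes the outputs tensor handed to ggml_opt_prepare_alloc from the graph's last node to the result's logits tensor, presumably because ggml_graph_node(gf, -1) is not guaranteed to be the logits. A hedged sketch of the resulting per-ubatch setup, assuming the ggml-opt.h interface from this tree; res_tokens and res_logits are hypothetical parameters standing in for res->get_tokens() and res->get_logits():

#include "ggml.h"
#include "ggml-opt.h"

// Sketch of the per-ubatch setup after this change: register the input and
// output tensors for the optimizer context, then allocate the (optionally
// backward) graph for this step.
static void opt_step_prepare(
        ggml_opt_context_t    opt_ctx,
        struct ggml_context * ctx_compute,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * res_tokens,   // stands in for res->get_tokens()
        struct ggml_tensor  * res_logits,   // stands in for res->get_logits()
        bool                  train) {
    // the logits tensor is passed explicitly as the optimizer's outputs
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, res_tokens, res_logits);
    ggml_opt_alloc(opt_ctx, train);
}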