diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4e72d4c0d..f921211f5 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1955,6 +1955,17 @@ void llama_context::opt_epoch_iter(
             //}

             llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);

+            n_outputs = ubatch.n_tokens;
+
+            printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
@@ -1970,7 +1981,7 @@ void llama_context::opt_epoch_iter(
                 };
                 ctx_compute_opt = ggml_init(params);
             }
-            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
             //llama_set_inputs(*lctx, ubatch);
             res->set_inputs(&ubatch);
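
For reference, a minimal sketch of how the touched code reads with both hunks applied (simplified; it assumes only the calls that appear in the patch above, and omits the debug printf, the ctx_compute_opt initialization, and label setup). The point of the second hunk is that the outputs tensor handed to ggml_opt_prepare_alloc is now taken explicitly from the graph result via res->get_logits(), rather than assumed to be the last node of the compute graph:

// Sketch only, not part of the patch: one training ubatch in opt_epoch_iter.
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
n_outputs = ubatch.n_tokens;

// reserve KV cache cells for this ubatch before the graph is built
if (!kv_self->find_slot(ubatch)) {
    GGML_ABORT("TODO: handle this error");
}

auto * gf  = graph_init();
auto   res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

// hand ggml-opt the token inputs and the logits tensor from the graph result;
// previously the outputs tensor was taken as ggml_graph_node(gf, -1), i.e. the
// last graph node, which is not guaranteed to be the logits
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
ggml_opt_alloc(opt_ctx, train);
res->set_inputs(&ubatch);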