From bc6f187e9c0d40ca355e088708e4323bac2828da Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 18 Feb 2025 14:24:17 +0200
Subject: [PATCH] cont : use returned tensors from the graph build

ggml-ci
---
 src/llama-context.cpp | 60 ++++++++++---------
 1 file changed, 13 insertions(+), 47 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d39263d28..b508a4f8d 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1855,7 +1855,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         auto ctx = graph_init();
         auto res = graph_build(ctx, ubatch, false);
 
-        auto & gf = res.gf;
+        auto * gf = res.gf;
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
@@ -1863,29 +1863,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         input_set(ubatch);
 
-        // the output is always the last tensor in the graph
-        struct ggml_tensor * t_logits = ggml_graph_node(gf, -1);
-        struct ggml_tensor * t_embd = ggml_graph_node(gf, -2);
-
-        if (n_outputs == 0) {
-            // no output
-            t_logits = nullptr;
-            t_embd = nullptr;
-        } else if (cparams.embeddings) {
-            t_logits = nullptr; // do not extract logits for embedding case
-            t_embd = nullptr;
-            for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-                if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    t_embd = ggml_graph_node(gf, i);
-                    break;
-                }
-            }
-            GGML_ASSERT(t_embd != nullptr && "missing embeddings tensor");
-        } else {
-            t_embd = nullptr; // do not extract embeddings when not needed
-            GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
-        }
-
         const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
             switch (compute_status) {
@@ -1914,8 +1891,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
+        auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
+        auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
+
+        if (t_embd && res.t_embd_pooled) {
+            t_embd = res.t_embd_pooled;
+        }
+
         // extract logits
-        if (t_logits) {
+        if (t_logits && n_outputs > 0) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(logits != nullptr);
@@ -1930,7 +1914,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         }
 
         // extract embeddings
-        if (t_embd) {
+        if (t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -2103,32 +2087,12 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     auto ctx = graph_init();
     auto res = graph_build(ctx, ubatch, false);
 
-    auto & gf = res.gf;
+    auto * gf = res.gf;
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     input_set(ubatch);
 
-    // the output embeddings after the final encoder normalization
-    struct ggml_tensor * t_embd = nullptr;
-
-    // there are two cases here
-    if (llama_model_has_decoder(&model)) {
-        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        t_embd = ggml_graph_node(gf, -1);
-        GGML_ASSERT(strcmp(t_embd->name, "result_norm") == 0 && "missing result_output tensor");
-    } else {
-        // second case is an encoder-only T5 model
-        if (cparams.embeddings) {
-            // only output embeddings if required
-            t_embd = ggml_graph_node(gf, -1);
-            if (strcmp(t_embd->name, "result_embd_pooled") != 0) {
-                t_embd = ggml_graph_node(gf, -2);
-            }
-            GGML_ASSERT(strcmp(t_embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        }
-    }
-
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
@@ -2142,6 +2106,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
             return -3;
     }
 
+    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
+
     // extract embeddings
     if (t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
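
Note (not part of the patch): instead of searching the built graph for nodes named "result_output", "result_norm" or "result_embd_pooled", decode() and encode() now read res.gf, res.t_logits, res.t_embd and res.t_embd_pooled directly from the graph-build result. The sketch below only illustrates the shape such a result object would need for the diff to compile; the struct name, helper name and comments are hypothetical, while the field names and the old node names come from the diff itself.

    // Illustrative sketch only, assuming a graph-build result like this; the
    // field names (gf, t_logits, t_embd, t_embd_pooled) are the ones the diff
    // reads, everything else is hypothetical.

    struct ggml_cgraph;  // opaque ggml types, forward-declared so the sketch is self-contained
    struct ggml_tensor;

    struct graph_build_result_sketch {
        ggml_cgraph * gf            = nullptr; // full compute graph
        ggml_tensor * t_logits      = nullptr; // previously found by name "result_output"
        ggml_tensor * t_embd        = nullptr; // embeddings after the final norm ("result_norm")
        ggml_tensor * t_embd_pooled = nullptr; // pooled embeddings ("result_embd_pooled"), if produced
    };

    // Mirrors the selection added to encode(): prefer pooled embeddings when
    // the graph produced them, otherwise fall back to the raw embeddings.
    static ggml_tensor * encode_output_tensor(const graph_build_result_sketch & res) {
        return res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
    }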