mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-13 14:29:17 +00:00
cont : use returend tensors from the graph build
ggml-ci
This commit is contained in:
@ -1855,7 +1855,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
auto ctx = graph_init();
|
||||
auto res = graph_build(ctx, ubatch, false);
|
||||
|
||||
auto & gf = res.gf;
|
||||
auto * gf = res.gf;
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
@ -1863,29 +1863,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
|
||||
input_set(ubatch);
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * t_logits = ggml_graph_node(gf, -1);
|
||||
struct ggml_tensor * t_embd = ggml_graph_node(gf, -2);
|
||||
|
||||
if (n_outputs == 0) {
|
||||
// no output
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
} else if (cparams.embeddings) {
|
||||
t_logits = nullptr; // do not extract logits for embedding case
|
||||
t_embd = nullptr;
|
||||
for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
|
||||
if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
|
||||
t_embd = ggml_graph_node(gf, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(t_embd != nullptr && "missing embeddings tensor");
|
||||
} else {
|
||||
t_embd = nullptr; // do not extract embeddings when not needed
|
||||
GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
|
||||
}
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
switch (compute_status) {
|
||||
@ -1914,8 +1891,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
|
||||
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
|
||||
|
||||
if (t_embd && res.t_embd_pooled) {
|
||||
t_embd = res.t_embd_pooled;
|
||||
}
|
||||
|
||||
// extract logits
|
||||
if (t_logits) {
|
||||
if (t_logits && n_outputs > 0) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
@ -1930,7 +1914,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd) {
|
||||
if (t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
@ -2103,32 +2087,12 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||
auto ctx = graph_init();
|
||||
auto res = graph_build(ctx, ubatch, false);
|
||||
|
||||
auto & gf = res.gf;
|
||||
auto * gf = res.gf;
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
|
||||
// the output embeddings after the final encoder normalization
|
||||
struct ggml_tensor * t_embd = nullptr;
|
||||
|
||||
// there are two cases here
|
||||
if (llama_model_has_decoder(&model)) {
|
||||
// first case is an encoder-decoder T5 model where embeddings are passed to decoder
|
||||
t_embd = ggml_graph_node(gf, -1);
|
||||
GGML_ASSERT(strcmp(t_embd->name, "result_norm") == 0 && "missing result_output tensor");
|
||||
} else {
|
||||
// second case is an encoder-only T5 model
|
||||
if (cparams.embeddings) {
|
||||
// only output embeddings if required
|
||||
t_embd = ggml_graph_node(gf, -1);
|
||||
if (strcmp(t_embd->name, "result_embd_pooled") != 0) {
|
||||
t_embd = ggml_graph_node(gf, -2);
|
||||
}
|
||||
GGML_ASSERT(strcmp(t_embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
|
||||
}
|
||||
}
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
@ -2142,6 +2106,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||
return -3;
|
||||
}
|
||||
|
||||
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
|
Reference in New Issue
Block a user