server : fix pooled embedding output (#14645)

Author: Douglas Hanley
Date: 2025-07-12 06:21:02 -04:00
Committed by: GitHub
Parent: b3ad3a0191
Commit: 0c1df14b5f


@@ -2581,12 +2581,14 @@ struct server_context {
                 continue;
             }

-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            if (embd == NULL) {
+            const float * embd = nullptr;
+            if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
                 embd = llama_get_embeddings_ith(ctx, i);
+            } else {
+                embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
             }

-            if (embd == NULL) {
+            if (embd == nullptr) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
@@ -2594,12 +2596,12 @@ struct server_context {
             }

             // normalize only when there is pooling
-            // TODO: configurable
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                 common_embd_normalize(embd, embd_res.data(), n_embd, 2);
                 res->embedding.push_back(embd_res);
+                break;
             } else {
-                res->embedding.push_back({ embd, embd + n_embd });
+                res->embedding.emplace_back(embd, embd + n_embd);
             }
         }
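In short, the patch makes the embedding source explicit: with pooling disabled the per-token embedding is read, otherwise the pooled per-sequence embedding is read, and the added break stops the token loop after that single pooled result. Below is a minimal sketch of the same selection against the public llama.h API; the helper name is hypothetical and not part of the patch.

    #include "llama.h"

    // Hypothetical helper mirroring the patched selection logic: `ctx` must have
    // been created with embeddings enabled, `i` is the index of a token whose
    // output flag was set in the batch, `seq_id` is that token's sequence.
    static const float * select_embedding(llama_context * ctx, int32_t i, llama_seq_id seq_id) {
        if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
            return llama_get_embeddings_ith(ctx, i);   // one embedding per output token
        }
        return llama_get_embeddings_seq(ctx, seq_id);  // one pooled embedding per sequence
    }

Since the pooled embedding is identical for every token of the sequence, emitting it once and breaking out of the loop avoids duplicated rows in the embedding response.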