Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-20 17:49:18 +00:00)
server : fix pooled embedding output (#14645)
@@ -2581,12 +2581,14 @@ struct server_context {
                 continue;
             }

-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            if (embd == NULL) {
+            const float * embd = nullptr;
+            if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) {
                 embd = llama_get_embeddings_ith(ctx, i);
+            } else {
+                embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
             }

-            if (embd == NULL) {
+            if (embd == nullptr) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
@@ -2594,12 +2596,12 @@ struct server_context {
             }

             // normalize only when there is pooling
-            // TODO: configurable
             if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
                 common_embd_normalize(embd, embd_res.data(), n_embd, 2);
                 res->embedding.push_back(embd_res);
+                break;
             } else {
-                res->embedding.push_back({ embd, embd + n_embd });
+                res->embedding.emplace_back(embd, embd + n_embd);
             }
         }

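The fix, in short: the old code always asked llama_get_embeddings_seq() for a sequence-level embedding and only fell back to the per-token llama_get_embeddings_ith() when that returned NULL. The new code makes the choice explicit, keyed on the context's pooling mode, and the added break; appears to be the core of the fix: with pooling enabled there is exactly one embedding per sequence, so the loop stops after emitting it instead of pushing a duplicate for every token in the batch. A minimal sketch of the selection logic using the public llama.h calls that appear in the diff (the helper name get_embd and its signature are illustrative, not part of the repo):

#include "llama.h"

// Illustrative helper (not in the repo): choose the embedding source the
// same way the patched server loop does. `i` is the token's index in the
// batch; `seq_id` is the sequence that token belongs to.
static const float * get_embd(llama_context * ctx, int32_t i, llama_seq_id seq_id) {
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        // no pooling: the model exposes one embedding per output token
        return llama_get_embeddings_ith(ctx, i);
    }
    // pooled (mean/cls/last/rank): one embedding for the whole sequence
    return llama_get_embeddings_seq(ctx, seq_id);
}

In the pooled branch the server then L2-normalizes the vector via common_embd_normalize(embd, embd_res.data(), n_embd, 2) before pushing it, which matches the "normalize only when there is pooling" comment kept in the diff.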