mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-14 20:29:41 -04:00
embeddings: fix extraction of CLS pooling results (#14927)
* embeddings: fix extraction of CLS pooling results
* merge RANK pooling into CLS case for inputs
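The change folds the separate CLS/RANK and LAST branches into one loop that tracks a single target row per sequence: the row with the smallest position for CLS/RANK pooling, the row with the largest position for LAST pooling. Below is a minimal standalone sketch of that selection rule, with made-up example data (illustrative only, not llama.cpp source); the actual diff follows.

// Illustrative sketch only (not llama.cpp source): the per-sequence row
// selection rule introduced by this commit. For CLS/RANK pooling the row
// with the smallest position wins, for LAST pooling the row with the
// largest position wins. The batch layout below is made-up example data.
#include <cstdio>
#include <vector>

int main() {
    // example batch: token i belongs to sequence seq[i] at position pos[i]
    const std::vector<int> seq = {0, 0, 0, 1, 1};
    const std::vector<int> pos = {0, 1, 2, 0, 1};
    const int n_seqs_unq = 2;

    for (int mode = 0; mode < 2; ++mode) {
        const bool last = (mode == 1); // false: CLS/RANK, true: LAST

        std::vector<int> target_pos(n_seqs_unq, -1);
        std::vector<int> target_row(n_seqs_unq, -1);

        for (int i = 0; i < (int) seq.size(); ++i) {
            const int s = seq[i];
            if (target_pos[s] == -1 ||
                ( last && pos[i] >= target_pos[s]) ||
                (!last && pos[i] <  target_pos[s])) {
                target_pos[s] = pos[i]; // remember the best position seen so far
                target_row[s] = i;      // and the batch row that holds it
            }
        }

        printf("%s:", last ? "CLS/RANK" : "CLS/RANK");
        for (int s = 0; s < n_seqs_unq; ++s) {
            printf(" seq %d -> row %d", s, target_row[s]);
        }
        printf("\n");
    }
    return 0;
}

For the example batch this selects rows 0 and 3 (the first token of each sequence) for CLS/RANK pooling and rows 2 and 4 (the last token of each sequence) for LAST pooling.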
@@ -188,12 +188,12 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs_unq   = ubatch->n_seqs_unq;
 
     if (cparams.embeddings && (
                 cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
                 )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
@@ -201,25 +201,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
-            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[i][s];
-                const int32_t seq_idx = ubatch->seq_idx[seq_id];
-
-                data[seq_idx] = i;
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        GGML_ASSERT(cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
-
-        uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
-
-        std::vector<int> last_pos(n_seqs_unq, -1);
-        std::vector<int> last_row(n_seqs_unq, -1);
-
+        std::vector<int> target_pos(n_seqs_unq, -1);
+        std::vector<int> target_row(n_seqs_unq, -1);
+
+        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+
         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
             const llama_seq_id seq_id = ubatch->seq_id[i][s];
             const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-            if (pos >= last_pos[seq_idx]) {
-                last_pos[seq_idx] = pos;
-                last_row[seq_idx] = i;
+            if (
+                (target_pos[seq_idx] == -1) ||
+                ( last && pos >= target_pos[seq_idx]) ||
+                (!last && pos <  target_pos[seq_idx])
+            ) {
+                target_pos[seq_idx] = pos;
+                target_row[seq_idx] = i;
             }
         }
     }
 
     for (int s = 0; s < n_seqs_unq; ++s) {
-        if (last_row[s] >= 0) {
-            data[s] = last_row[s];
+        if (target_row[s] >= 0) {
+            data[s] = target_row[s];
         }
     }
 }