diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1b9cc4aec..702192b79 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; if (cparams.embeddings && ( - cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || - cparams.pooling_type == LLAMA_POOLING_TYPE_RANK - )) { + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK || + cparams.pooling_type == LLAMA_POOLING_TYPE_LAST + )) { GGML_ASSERT(cls); GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); uint32_t * data = (uint32_t *) cls->data; memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); - for (int i = 0; i < n_tokens; i += n_seq_tokens) { - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; - const int32_t seq_idx = ubatch->seq_idx[seq_id]; + std::vector target_pos(n_seqs_unq, -1); + std::vector target_row(n_seqs_unq, -1); - data[seq_idx] = i; - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - GGML_ASSERT(cls); - GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); - - uint32_t * data = (uint32_t *) cls->data; - memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); - - std::vector last_pos(n_seqs_unq, -1); - std::vector last_row(n_seqs_unq, -1); + bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST; for (int i = 0; i < n_tokens; ++i) { const llama_pos pos = ubatch->pos[i]; @@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const llama_seq_id seq_id = ubatch->seq_id[i][s]; const int32_t seq_idx = ubatch->seq_idx[seq_id]; - if (pos >= last_pos[seq_idx]) { - last_pos[seq_idx] = pos; - last_row[seq_idx] = i; + if ( + (target_pos[seq_idx] == -1) || + ( last && pos >= target_pos[seq_idx]) || + (!last && pos < target_pos[seq_idx]) + ) { + target_pos[seq_idx] = pos; + target_row[seq_idx] = i; } } } for (int s = 0; s < n_seqs_unq; ++s) { - if (last_row[s] >= 0) { - data[s] = last_row[s]; + if (target_row[s] >= 0) { + data[s] = target_row[s]; } } }