examples : allow extracting embeddings from decoder contexts (#13797)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-05-26 14:03:54 +03:00
committed by GitHub
parent 22229314fc
commit 79c137f776
4 changed files with 10 additions and 16 deletions

View File

@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to process\n", __func__);
}
for (int i = 0; i < batch.n_tokens; i++) {
@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_encode(ctx, batch, out, s, n_embd);
batch_process(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
p += s;
s = 0;
@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
// final batch
float * out = emb + p * n_embd;
batch_encode(ctx, batch, out, s, n_embd);
batch_process(ctx, batch, out, s, n_embd);
// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
batch_add_seq(query_batch, query_tokens, 0);
std::vector<float> query_emb(n_embd, 0);
batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);