diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 5e72f0e1a..947bbc174 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -26,56 +26,52 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
     return lines;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    const struct llama_model * model = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
+    const int n_tokens = llama_batch_ext_get_n_tokens(batch);
+
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, n_tokens, n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
-
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int i = 0; i < n_tokens; i++) {
+            const float * embd = llama_get_embeddings_ith(ctx, i);
             GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
-        }
 
-        float * out = output + embd_pos * n_embd;
-        common_embd_normalize(embd, out, n_embd, embd_norm);
+            float * out = output + i * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
+    } else {
+        for (int s = 0; s < n_seq; s++) {
+            const float * embd = llama_get_embeddings_seq(ctx, s);
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+
+            float * out = output + s * n_embd;
+            common_embd_normalize(embd, out, n_embd, embd_norm);
+        }
     }
 }
 
@@ -171,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
    const int n_prompts = prompts.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // count number of embeddings
     int n_embd_count = 0;
@@ -198,12 +194,12 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.get_n_tokens() + n_toks > n_batch) {
-            float * out = emb + e * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
+
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s;
             s = 0;
-            batch.clear();
+            llama_batch_ext_clear(batch);
         }
 
         // add to batch
@@ -212,8 +208,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + e * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+    batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
         LOG("\n");
@@ -318,6 +313,8 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+
     // clean up
     llama_backend_free();
 
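The change above replaces the RAII-style common_batch wrapper with an opaque llama_batch_ext handle that is filled token by token and must be cleared and freed explicitly. Below is a minimal sketch of that lifecycle for a single pooled prompt. The function signatures are assumed from the calls in this diff, and embed_single_prompt is a hypothetical helper written for illustration, not code from the examples; n_embd is passed in rather than queried so the sketch relies only on API that appears here.

#include <algorithm>
#include <vector>

#include "llama.h"

// hypothetical helper: embed one tokenized prompt as sequence 0 and copy the
// pooled embedding into out (must hold n_embd floats); assumes a pooling type
// other than LLAMA_POOLING_TYPE_NONE
static bool embed_single_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, float * out, int n_embd) {
    // the batch is an opaque handle now: token capacity and max sequence count up front
    llama_batch_ext * batch = llama_batch_ext_init((int32_t) tokens.size(), 1);

    // one call per token: token, position, sequence ids, number of sequence ids, output flag
    llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_ext_add_text(batch, tokens[i], (llama_pos) i, &seq_id, 1, true);
    }

    bool ok = llama_decode_ext(ctx, batch) >= 0;
    if (ok) {
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        ok = embd != nullptr;
        if (ok) {
            std::copy(embd, embd + n_embd, out);
        }
    }

    // unlike common_batch, the handle does not clean up after itself
    llama_batch_ext_free(batch);
    return ok;
}

One consequence worth noting: after llama_batch_ext_clear() runs, llama_batch_ext_get_n_tokens() reports zero, so any bookkeeping that depends on the submitted token count, such as the output offset e in embedding.cpp, has to happen before the clear.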
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index 9fe6f8b64..608666549 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch_ext * batch, const std::vector<llama_token> &
 }
 
 static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
-    const struct llama_model * model = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
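After this change, both examples share the same packing pattern: several prompts enter one batch under distinct sequence ids, and one pooled embedding comes back per sequence. A sketch of that round trip under the same assumed signatures; pack_and_embed and its n_batch parameter are hypothetical, and the sketch assumes the prompts fit within n_batch tokens.

#include <algorithm>
#include <vector>

#include "llama.h"

// hypothetical helper: embed all prompts in one decode call;
// output must hold prompts.size() * n_embd floats
static void pack_and_embed(llama_context * ctx,
                           const std::vector<std::vector<llama_token>> & prompts,
                           float * output, int n_embd, int n_batch) {
    llama_batch_ext * batch = llama_batch_ext_init(n_batch, (int32_t) prompts.size());

    for (size_t s = 0; s < prompts.size(); s++) {
        // positions restart at 0 for every sequence, as in batch_add_seq() above
        for (size_t i = 0; i < prompts[s].size(); i++) {
            llama_seq_id seq_id = (llama_seq_id) s;
            llama_batch_ext_add_text(batch, prompts[s][i], (llama_pos) i, &seq_id, 1, true);
        }
    }

    if (llama_decode_ext(ctx, batch) >= 0) {
        for (size_t s = 0; s < prompts.size(); s++) {
            // one pooled row per sequence id; the examples normalize this
            // with common_embd_normalize instead of copying it raw
            const float * embd = llama_get_embeddings_seq(ctx, (llama_seq_id) s);
            if (embd != nullptr) {
                std::copy(embd, embd + n_embd, output + s * n_embd);
            }
        }
    }

    llama_batch_ext_free(batch);
}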