diff --git a/common/speculative.cpp b/common/speculative.cpp
index 585850aae..62ec5bfd8 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -209,7 +209,7 @@ llama_tokens common_speculative_gen_draft(
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
+        llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
@@ -226,7 +226,7 @@ llama_tokens common_speculative_gen_draft(
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
     llama_batch_ext_clear(batch.get());
-    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
+    llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true);
 
     prompt.push_back(id_last);
@@ -265,7 +265,7 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
 
         // evaluate the drafted tokens on the draft model
         llama_decode_ext(ctx, batch.get());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2b06914ea..b745dd044 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2849,7 +2849,7 @@ struct server_context {
             slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
             std::array seq_id = { slot.id };
-            llama_batch_ext_add_text_token(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
+            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
 
             slot.n_past += 1;
@@ -3057,7 +3057,7 @@ struct server_context {
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                     std::array seq_id = { slot.id };
-                    llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
+                    llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3255,10 +3255,10 @@ struct server_context {
                 // construct the speculation batch
                 llama_batch_ext_clear(slot.batch_spec.get());
                 std::array seq_id = { slot.id };
-                llama_batch_ext_add_text_token(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
+                llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
 
                 for (size_t i = 0; i < draft.size(); ++i) {
-                    llama_batch_ext_add_text_token(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
+                    llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
                 }
 
                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
diff --git a/include/llama.h b/include/llama.h
index 86aa40d8c..dab1aea2b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -905,7 +905,7 @@ extern "C" {
     //  0 : success
     // -1 : not enough space in the batch
     // -2 : embd is already set, cannot add text tokens
-    LLAMA_API int32_t llama_batch_ext_add_text_token(
+    LLAMA_API int32_t llama_batch_ext_add_text(
             struct llama_batch_ext * batch,
             llama_token token,
             llama_pos pos,
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 36a3d00be..b63d4ec7f 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -344,7 +344,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
         int32_t seq_id) {
     llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
     for (int32_t i = 0; i < n_tokens; i++) {
-        llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false);
+        llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
     }
     return batch;
 }
@@ -404,7 +404,7 @@ int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
     return batch->n_tokens;
 }
 
-int32_t llama_batch_ext_add_text_token(
+int32_t llama_batch_ext_add_text(
     struct llama_batch_ext * batch,
     llama_token token,
     llama_pos pos,
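For context on the renamed entry point, here is a minimal sketch of how `llama_batch_ext_add_text` composes with the other `llama_batch_ext_*` calls visible in these hunks, using the return codes documented in the `include/llama.h` hunk above. The helper name `decode_tokens`, the `llama_batch_ext_free` cleanup call, and the return-value convention of `llama_decode_ext` are assumptions for illustration; only the calls and argument orders that appear in the diff itself are confirmed by this change.

```cpp
#include "llama.h"

// Hedged sketch (not part of this diff): decode n_tokens tokens on a single
// sequence through the renamed llama_batch_ext_add_text(). Per the header
// comment above: 0 = success, -1 = batch full, -2 = embd already set.
static int32_t decode_tokens(llama_context * ctx, const llama_token * tokens, int32_t n_tokens, llama_pos pos0) {
    llama_seq_id seq_id = 0;

    // second argument (max sequences per token) is 1, as in llama_batch_ext_init_from_text above
    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);

    for (int32_t i = 0; i < n_tokens; i++) {
        // request logits only for the last token, mirroring the call sites above
        const bool want_logits = (i == n_tokens - 1);
        if (llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, want_logits) != 0) {
            break; // -1: not enough space in the batch, -2: embd is already set
        }
    }

    // return value assumed to follow the usual 0-on-success llama_decode convention
    const int32_t status = llama_decode_ext(ctx, batch);

    llama_batch_ext_free(batch); // assumed counterpart to llama_batch_ext_init; not shown in this diff
    return status;
}
```

The call sites in `server.cpp` follow the same pattern but pass `seq_id.data()`/`seq_id.size()` from a `std::array`, since a token can belong to multiple sequences; the single-`&seq_id` form above matches the one-sequence usage in `common/speculative.cpp` and `src/llama-batch.cpp`.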