apply in various places

Author: Xuan Son Nguyen
Date:   2025-03-01 20:42:18 +01:00
parent 1d6ba97789
commit 46596caf6d
12 changed files with 142 additions and 133 deletions

View File

@ -565,6 +565,52 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// convenience wrapper around llama_batch_ext, providing a way to recover embedding positions
// this is meant to be temporary
struct common_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id;
bool logits;
};
std::vector<batch_token> tokens;
common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_logits_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
llama_batch_ext * get() {
return batch.get();
}
common_batch get_view(int32_t offset, int32_t n_tokens) {
common_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
}
return view;
}
};
//
// Token utils
//

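For reference, a minimal usage sketch of this wrapper (the helper name is illustrative; it assumes an already-created llama_context and the signatures declared above):

// sketch: feed a tokenized prompt through common_batch, requesting logits only for the last token
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
    common_batch batch((int32_t) prompt.size(), 1);
    for (size_t i = 0; i < prompt.size(); ++i) {
        batch.add_text(prompt[i], (llama_pos) i, /*seq_id=*/0, /*logits=*/false);
    }
    batch.set_logits_last(); // output logits only for the final prompt token
    return llama_decode_ext(ctx, batch.get()) == 0;
}
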
View File

@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
const int32_t n_kv_max = llama_n_ctx(ctx);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
// warm up
{
for (int i = 0; i < 16; ++i) {
common_batch_add(batch, 0, i, { 0 }, false);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
continue;
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
common_batch_add(batch, 0, i, { j }, false);
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_logits_last(batch);
const auto t_pp_start = ggml_time_us();
@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int j = 0; j < pl; ++j) {
common_batch_add(batch, 0, pp + i, { j }, true);
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
LOG("\n");
llama_perf_context_print(ctx);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_free(ctx);
llama_model_free(model);

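The hand-built llama_batch view is replaced here by llama_batch_ext_get_view; a minimal sketch of that chunked-decode pattern (helper name is illustrative, signatures as in this hunk):

// sketch: split a large llama_batch_ext into n_batch-sized views and decode each one
static bool decode_in_chunks(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
    const int32_t n_tokens_all = llama_batch_ext_get_n_tokens(batch);
    for (int32_t i = 0; i < n_tokens_all; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_all - i);
        // the view references the parent batch's data; llama_batch_ext_ptr frees the view object
        llama_batch_ext_ptr view(llama_batch_ext_get_view(batch, i, n_tokens));
        if (llama_decode_ext(ctx, view.get()) != 0) {
            return false;
        }
    }
    return true;
}
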
View File

@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
if (llama_encode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
decoder_start_token_id = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
llama_batch_ext_clear(batch);
llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_logits_last(batch);
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_cur = batch.n_tokens;
int n_cur = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;
const auto t_main_start = ggml_time_us();
while (n_cur <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
streams[i] += common_token_to_piece(ctx, new_token_id);
i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_cur, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true);
n_decode += 1;
}
// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);

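A condensed sketch of one generation step after this migration (the helper name and the 'next' container are illustrative): clear the extended batch, append one token per still-active stream with logits requested, then decode:

// sketch: one step of parallel generation with llama_batch_ext
static bool step_parallel(llama_context * ctx, llama_batch_ext * batch,
                          const std::vector<std::pair<llama_seq_id, llama_token>> & next, llama_pos n_cur) {
    llama_batch_ext_clear(batch);
    for (const auto & p : next) {
        const llama_seq_id seq_id = p.first;
        llama_batch_ext_add_text(batch, p.second, n_cur, &seq_id, 1, /*logits=*/true);
    }
    if (llama_batch_ext_get_n_tokens(batch) == 0) {
        return false; // all streams finished
    }
    return llama_decode_ext(ctx, batch) == 0;
}
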
View File

@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
if (llama_decode_ext(ctx, batch.get())) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}

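Unlike llama_batch_get_one, llama_batch_ext_init_from_text allocates a batch, so it is wrapped in llama_batch_ext_ptr to free it on scope exit; a minimal sketch (helper name is illustrative; the trailing 0, 0 presumably select position 0 and sequence 0, as elsewhere in this diff):

// sketch: one-shot evaluation of a token span; the smart pointer frees the batch automatically
static bool eval_tokens(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), (int32_t) tokens.size(), 0, 0));
    if (llama_decode_ext(ctx, batch.get()) != 0) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
    return true;
}
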
View File

@ -25,14 +25,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
return lines;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
batch.add_text(tokens[i], i, seq_id, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
@ -40,21 +40,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
llama_kv_cache_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.tokens[i].logits) {
continue;
}
@ -68,8 +68,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
struct common_batch batch = common_batch(n_batch, 1);
// count number of embeddings
int n_embd_count = 0;
@ -197,12 +197,12 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (batch.get_n_tokens() + n_toks > n_batch) {
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
s = 0;
common_batch_clear(batch);
batch.clear();
}
// add to batch
@ -318,7 +318,6 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
// clean up
llama_batch_free(batch);
llama_backend_free();
return 0;

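With common_batch, each position's seq_id and logits flag come from the wrapper's shadow vector instead of raw llama_batch arrays; a condensed sketch of the pooled-embedding readback (helper name is illustrative, accessors as declared earlier in this commit):

// sketch: read back one pooled embedding per sequence after decoding a common_batch
static void collect_seq_embeddings(llama_context * ctx, common_batch & batch, float * out, int n_embd) {
    for (int32_t i = 0; i < batch.get_n_tokens(); ++i) {
        if (!batch.tokens[i].logits) {
            continue; // only positions flagged for output carry embeddings
        }
        const llama_seq_id seq_id = batch.tokens[i].seq_id;
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
        for (int d = 0; d < n_embd; ++d) {
            out[seq_id * n_embd + d] = embd[d];
        }
    }
}
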
View File

@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

View File

@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
const std::string input_string = instruction + sentences[i];
@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1, j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_model_n_embd(model);
@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
#endif
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
return result;
}
@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
common_batch_clear(bat);
llama_batch_ext_clear(bat);
{
const int32_t n_inputs = inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
}
}
inputs.clear();
llama_decode(ctx, bat);
llama_decode_ext(ctx, bat);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
if (token == eos_token) {
break;
@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::printf("\n");
}
llama_batch_free(bat);
llama_batch_ext_free(bat);
return result;
}

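Since the extended batch no longer exposes n_tokens as a field, the generation loop reads it through llama_batch_ext_get_n_tokens before sampling; a minimal sketch (helper name is illustrative):

// sketch: decode the pending batch and sample from its last output position
static llama_token decode_and_sample_last(llama_context * ctx, llama_sampler * smpl, llama_batch_ext * bat) {
    llama_decode_ext(ctx, bat);
    const int32_t i_last = llama_batch_ext_get_n_tokens(bat) - 1;
    return llama_sampler_sample(smpl, ctx, i_last);
}
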
View File

@ -500,7 +500,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -514,14 +514,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return false;
}
@ -534,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();

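imatrix decodes each chunk with logits requested for every position; a minimal sketch of one chunk (helper name is illustrative, signatures as used in this hunk):

// sketch: decode one chunk of tokens, requesting logits for all positions
static bool eval_chunk(llama_context * ctx, llama_batch_ext * batch,
                       const llama_token * tokens, int batch_size, llama_pos pos0) {
    llama_batch_ext_clear(batch);
    const llama_seq_id seq_id = 0;
    for (int i = 0; i < batch_size; i++) {
        llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, /*logits=*/true);
    }
    return llama_decode_ext(ctx, batch) == 0;
}
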
View File

@ -353,7 +353,8 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}

View File

@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
llama_decode_ext(ctx, batch.get());
n_processed += n_tokens;
}
@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}

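test_gen now allocates a one-token extended batch per step; a condensed sketch (helper name and n_vocab parameter are illustrative, init_from_text arguments as above):

// sketch: generation timing loop; each step decodes a fresh single-token batch
static void bench_gen(llama_context * ctx, llama_token first, int n_gen, int n_vocab) {
    llama_token token = first;
    for (int i = 0; i < n_gen; i++) {
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
        llama_decode_ext(ctx, batch.get());
        llama_synchronize(ctx); // wait for the step to finish before starting the next one
        token = std::rand() % n_vocab; // random next token, as in the benchmark
    }
}
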
View File

@ -91,8 +91,10 @@ int main(int argc, char ** argv){
const auto t_enc_start = ggml_time_us();
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());
const auto t_enc_end = ggml_time_us();
@ -108,7 +110,7 @@ int main(int argc, char ** argv){
std::vector<llama_token> draft;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@ -194,8 +196,9 @@ int main(int argc, char ** argv){
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_clear(batch_tgt);
llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
@ -205,13 +208,13 @@ int main(int argc, char ** argv){
common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (size_t i = 1; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt);
llama_decode_ext(ctx, batch_tgt);
++n_past;
draft.erase(draft.begin());
@ -243,7 +246,7 @@ int main(int argc, char ** argv){
common_sampler_free(smpl);
llama_batch_free(batch_tgt);
llama_batch_ext_free(batch_tgt);
llama_backend_free();

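Lookup decoding rebuilds the target batch each round from the sampled token plus the draft, all flagged for logits so every drafted position can be verified; a condensed sketch (helper name is illustrative):

// sketch: build the target-model batch for one lookup-decoding round
// draft[0] is the token already sampled from the model; the rest are speculative
static void build_target_batch(llama_batch_ext * batch_tgt, const std::vector<llama_token> & draft, llama_pos n_past) {
    const llama_seq_id seq_id = 0;
    llama_batch_ext_clear(batch_tgt);
    for (size_t i = 0; i < draft.size(); ++i) {
        llama_batch_ext_add_text(batch_tgt, draft[i], n_past + (llama_pos) i, &seq_id, 1, /*logits=*/true);
    }
}
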
View File

@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result {
}
};
struct server_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id;
bool logits;
};
std::vector<batch_token> tokens;
server_batch() = default;
server_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_logits_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
server_batch get_view(int32_t offset, int32_t n_tokens) {
server_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
}
return view;
}
};
struct server_slot {
int id;
int id_task = -1;
@ -1253,7 +1212,7 @@ struct server_slot {
// only used for completion/embedding/infill/rerank
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
server_batch batch_spec;
common_batch batch_spec;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
@ -1825,7 +1784,7 @@ struct server_context {
llama_context_params cparams_dft;
server_batch batch;
common_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
@ -1950,7 +1909,7 @@ struct server_context {
slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
@ -1986,7 +1945,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
}
metrics.init();
@ -2104,7 +2063,7 @@ struct server_context {
}
if (slot.ctx_dft) {
slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
}
slot.state = SLOT_STATE_STARTED;
@ -2412,7 +2371,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, server_batch & batch) {
void send_embedding(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
@ -2456,7 +2415,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_rerank(const server_slot & slot, server_batch & batch) {
void send_rerank(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
@ -3155,9 +3114,9 @@ struct server_context {
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
server_batch batch_view = batch.get_view(i, n_tokens);
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode_ext(ctx, batch_view.batch.get());
const int ret = llama_decode_ext(ctx, batch_view.get());
metrics.on_decoded(slots);
if (ret != 0) {
@ -3301,7 +3260,7 @@ struct server_context {
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
llama_decode_ext(ctx, slot.batch_spec.batch.get());
llama_decode_ext(ctx, slot.batch_spec.get());
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
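
With server_batch removed, the server reuses the shared common_batch and slices it with get_view; a condensed sketch of the main decode loop (helper name is illustrative, wrapper as defined in common.h above):

// sketch: decode a large common_batch in n_batch-sized views, as the server's update loop does
static int decode_server_batch(llama_context * ctx, common_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
        common_batch batch_view = batch.get_view(i, n_tokens);
        const int ret = llama_decode_ext(ctx, batch_view.get());
        if (ret != 0) {
            return ret; // caller handles the error / retry logic
        }
    }
    return 0;
}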