compile ok

Xuan Son Nguyen
2025-03-13 22:56:35 +01:00
parent 9fb2d81eab
commit 65f0184517
9 changed files with 46 additions and 29 deletions

View File

@@ -607,7 +607,7 @@ struct common_batch {
             n_outputs++;
         }
     }
-    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
         llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
         tokens.push_back({token, seq_ids[0], logits});
         if (logits) {
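
The vector-of-sequences overload is renamed to add_text_multi_seq; a single-sequence add_text overload remains in use elsewhere in this commit (see the perplexity hunk below). A minimal usage sketch of the two helpers, with signatures assumed from this diff rather than a published API:

    // sketch only: 'batch' is a common_batch; overloads assumed from this diff
    // shared prompt prefix token, credited to four sequences at once:
    batch.add_text_multi_seq(token, /*pos*/ 0, { 0, 1, 2, 3 }, /*logits*/ false);
    // regular token belonging to a single sequence:
    batch.add_text(token, /*pos*/ 1, /*seq_id*/ 0, /*logits*/ true);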

View File

@@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }

View File

@@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
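
Both llava-style eval_tokens call sites get the same treatment: instead of a temporary llama_batch from llama_batch_get_one, each chunk is wrapped in an owning llama_batch_ext_ptr and decoded with llama_decode_ext. A condensed sketch of the whole loop, with the ext-batch API assumed from the call sites above:

    // sketch of the chunked prompt-eval pattern (ext-batch API assumed from this diff)
    #include <algorithm>
    #include <vector>
    static bool eval_tokens(llama_context * ctx, std::vector<llama_token> & tokens,
                            int n_batch, int * n_past) {
        const int N = (int) tokens.size();
        for (int i = 0; i < N; i += n_batch) {
            const int n_eval = std::min(N - i, n_batch);
            // the smart pointer frees the batch when it goes out of scope
            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
            if (llama_decode_ext(ctx, batch.get()) != 0) {
                return false;
            }
            *n_past += n_eval;
        }
        return true;
    }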

View File

@@ -96,16 +96,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        auto batch = llama_batch_get_one(&tokens[i], n_eval);
-        // TODO: add mrope pos ids somewhere else
-        pos.resize(batch.n_tokens * 4);
-        std::fill(pos.begin(), pos.end(), 0);
-        for (int j = 0; j < batch.n_tokens * 3; j ++) {
-            pos[j] = *st_pos_id + (j % batch.n_tokens);
-        }
-        batch.pos = pos.data();
-        if (llama_decode(ctx_llama, batch)) {
+        // TODO: add mrope pos ids somewhere else
+        int n_tokens = n_eval;
+        pos.resize(n_tokens * 4);
+        std::fill(pos.begin(), pos.end(), 0);
+        for (int j = 0; j < n_tokens * 3; j ++) {
+            pos[j] = *st_pos_id + (j % n_tokens);
+        }
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
+        for (int j = 0; j < n_eval; j++) {
+            llama_token token = tokens[i + j];
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
+        }
+        llama_batch_ext_set_output_last(batch.get());
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
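
The Qwen2-VL path cannot use the from-text helper because M-RoPE assigns four position values to every token (pos is sized n_tokens * 4, with the first three sections filled with the running text position), so the batch is assembled token by token. A condensed sketch of that manual pattern, assuming llama_batch_ext_init(n_tokens, n_seq_max) as used above:

    // sketch: manual batch assembly when per-token positions are needed
    llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
    llama_seq_id seq_id = 0;
    for (int j = 0; j < n_eval; j++) {
        // one token at pos[j] in sequence 0, no logits requested yet
        llama_batch_ext_add_text(batch.get(), tokens[i + j], pos[j], &seq_id, 1, false);
    }
    llama_batch_ext_set_output_last(batch.get()); // logits for the last token only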

View File

@@ -92,8 +92,10 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_decode_ext(ctx, batch0.get());
+    llama_decode_ext(ctx, batch1.get());

     for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
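
The prompt is deliberately decoded in two batches: everything up to the last token first, then the last token alone, so the final decode produces the logits used to sample the first new token. A hypothetical helper capturing the recurring pattern (decode_prompt_split is not part of the commit; arguments mirror the call sites above):

    // hypothetical helper, illustrating the two-batch prompt eval above
    static void decode_prompt_split(llama_context * ctx, std::vector<llama_token> & inp) {
        const int n_input = (int) inp.size();
        llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text(inp.data(), n_input - 1, 0, 0));
        llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
        llama_decode_ext(ctx, batch0.get()); // prompt minus the last token
        llama_decode_ext(ctx, batch1.get()); // last token: its logits seed sampling
    }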

View File

@@ -548,7 +548,8 @@ int main(int argc, char ** argv) {
             int enc_input_size = embd_inp.size();
             llama_token * enc_input_buf = embd_inp.data();

-            if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
+            if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -668,7 +669,8 @@ int main(int argc, char ** argv) {
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+                llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+                if (llama_decode_ext(ctx, batch.get())) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;
                 }

View File

@@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             }

             for (int k = 0; k < batch_size; ++k) {
-                const int idx = seq*n_ctx + k;
                 const llama_pos pos = j*n_batch + k;
                 bool output = pos >= first;
                 batch.add_text(tokens[seq_start + k], pos, seq, output);
@@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             }

             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1155,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             }

             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -1163,7 +1162,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                     n_logits += 1;
                 }
             }
@@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         batch.clear();
         for (int i = 0; i < batch_size; i++) {
-            batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
+            batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
         }

         if (llama_decode_ext(ctx, batch.get())) {
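
The scoring functions in perplexity.cpp all share one batch-building shape: the common prefix is added once and credited to every candidate sequence, then each continuation goes to its own sequence, with logits requested only where needed. A condensed sketch of that shared pattern (container names are illustrative, taken from the hellaswag hunk):

    // condensed sketch of the shared scoring pattern (names illustrative)
    for (size_t i = 0; i < common_prefix; ++i) {
        // decode the prefix once, shared by all four candidate sequences
        batch.add_text_multi_seq(seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
    }
    llama_batch_ext_set_output_last(batch.get()); // logits for the last prefix token
    for (int s = 0; s < 4; ++s) {
        for (size_t i = common_prefix; i < seq_tokens[s].size(); ++i) {
            const bool needs_logits = i < seq_tokens[s].size() - 1;
            batch.add_text_multi_seq(seq_tokens[s][i], i, { s0 + s }, needs_logits);
        }
    }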

View File

@@ -113,7 +113,8 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

     // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0));
+    llama_decode_ext(ctx_tgt, batch.get());

     // note: keep the last token separate!
     llama_token id_last = inp.back();

View File

@@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
     }

     common_init();
-#ifdef 0
+#if 0
     if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
@@ -166,9 +166,12 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0));
+    llama_decode_ext(ctx_tgt, batch0.get());
+    llama_decode_ext(ctx_tgt, batch1.get());
+    llama_decode_ext(ctx_dft, batch2.get());

     const auto t_enc_end = ggml_time_us();
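
Every converted call site holds its batch in a llama_batch_ext_ptr, so the batch is freed automatically when it leaves scope and no explicit free appears in the hunks. Presumably the pointer is a unique_ptr with a custom deleter along these lines (the deleter shape and a llama_batch_ext_free function are assumptions, not taken from this diff):

    // assumed shape of the owning pointer used throughout this commit
    #include <memory>
    struct llama_batch_ext_deleter {
        void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
    };
    typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;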