compile ok

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:56:35 +01:00
parent 9fb2d81eab
commit 65f0184517
9 changed files with 46 additions and 29 deletions

View File

@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
}
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
const llama_pos pos = j*n_batch + k;
bool output = pos >= first;
batch.add_text(tokens[seq_start + k], pos, seq, output);
@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
// TODO: don't evaluate the last token of each sequence
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@ -1155,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
@ -1163,7 +1162,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
for (int s = 0; s < 2; ++s) {
// TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
n_logits += 1;
}
}
@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
for (size_t i = 0; i < cur_task.common_prefix; ++i) {
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
n_logits += 1;
@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
// TODO: don't evaluate the last token of each sequence
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
batch.clear();
for (int i = 0; i < batch_size; i++) {
batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
}
if (llama_decode_ext(ctx, batch.get())) {