apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// clear the KV cache
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
}
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
//LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return {tokens, -1, logit_history, prob_history};
}
@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
std::vector<float> logits;
if (num_batches > 1) {
@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
int n_outputs = 0;
batch.n_tokens = 0;
batch.clear();
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
batch.token [idx] = tokens[seq_start + k];
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
const llama_pos pos = j*n_batch + k;
bool output = pos >= first;
batch.add_text(tokens[seq_start + k], pos, seq, output);
n_outputs += batch.logits[idx] != 0;
n_outputs += output ? 1 : 0;
}
batch.n_tokens += batch_size;
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_INF("%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
}
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history};
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
}
int n_outputs = batch_view.n_outputs;
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 4);
common_batch batch(n_ctx, 4);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 4; ++s) {
@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
// TODO: don't evaluate the last token of each sequence
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
i0 = i1 - 1;
}
llama_batch_free(batch);
LOG("\n");
}
@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 2);
common_batch batch(n_ctx, 2);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
size_t i1 = i0;
size_t i_logits = 0;
common_batch_clear(batch);
batch.clear();
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
int n_logits = 0;
@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 2; ++s) {
// TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
n_logits += 1;
}
}
@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
common_batch batch(n_ctx, max_seq);
std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
for (size_t i = 0; i < cur_task.common_prefix; ++i) {
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
// TODO: don't evaluate the last token of each sequence
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
i0 = i1 - 1;
}
llama_batch_free(batch);
if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
float p = 1.f*n_correct/n_done;
@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return;
}
@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {