mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-30 04:45:17 +00:00
sampling : don't consider -infinity values in top_n_sigma (#13344)
This commit is contained in:
@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
|||||||
// find max logit and calculate mean
|
// find max logit and calculate mean
|
||||||
float max = cur_p->data[0].logit;
|
float max = cur_p->data[0].logit;
|
||||||
float logits_sum = 0;
|
float logits_sum = 0;
|
||||||
|
size_t valid_count = 0;
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
// Only count non-negative infinity values
|
||||||
|
if (cur_p->data[i].logit != -INFINITY) {
|
||||||
if (cur_p->data[i].logit > max) {
|
if (cur_p->data[i].logit > max) {
|
||||||
max = cur_p->data[i].logit;
|
max = cur_p->data[i].logit;
|
||||||
}
|
}
|
||||||
logits_sum += cur_p->data[i].logit;
|
logits_sum += cur_p->data[i].logit;
|
||||||
|
valid_count++;
|
||||||
}
|
}
|
||||||
float mean = logits_sum/cur_p->size;
|
}
|
||||||
|
float mean = valid_count > 0 ? logits_sum/valid_count : 0;
|
||||||
|
|
||||||
// calculate standard deviation
|
// calculate standard deviation
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
// Skip -infinity in std calculation
|
||||||
|
if (cur_p->data[i].logit != -INFINITY) {
|
||||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||||
}
|
}
|
||||||
float std = sqrt(acc/cur_p->size);
|
}
|
||||||
|
float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
|
||||||
|
|
||||||
//apply mask
|
//apply mask
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
Reference in New Issue
Block a user