mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
sampling: add Top-nσ sampler (#11223)
* initial sampling changes: * completed top nsigma sampler implementation * apply parameter to only llama-cli * updated readme * added tests and fixed nsigma impl * cleaned up pr * format * format * format * removed commented tests * cleanup pr and remove explicit floats * added top-k sampler to improve performance * changed sigma to float * fixed string format to float * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * added llama_sampler_init --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e4376270d9
commit
27e8a23300
@ -181,6 +181,17 @@ static void test_dry(
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
DUMP(&tester.cur_p);
|
||||
tester.apply(llama_sampler_init_top_n_sigma(n));
|
||||
tester.apply(llama_sampler_init_dist (0));
|
||||
DUMP(&tester.cur_p);
|
||||
|
||||
tester.check();
|
||||
}
|
||||
|
||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
sampler_tester tester(n_vocab);
|
||||
@ -348,6 +359,10 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
||||
|
Reference in New Issue
Block a user