mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 04:15:21 +00:00
llama : add llama_sampling API + move grammar in libllama
ggml-ci
This commit is contained in:
@ -3,12 +3,12 @@
|
||||
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include <chrono>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sparams.seed = 1234;
|
||||
|
||||
if (!gpt_params_parse(argc, argv, params)) {
|
||||
gpt_params_print_usage(argc, argv, params);
|
||||
@ -38,6 +38,11 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_sampling_params sparams = llama_sampling_default_params();
|
||||
sparams.seed = params.sparams.seed;
|
||||
|
||||
llama_sampling * smpl = llama_sampling_init(model, sparams);
|
||||
|
||||
// tokenize prompt
|
||||
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
@ -64,16 +69,11 @@ int main(int argc, char ** argv) {
|
||||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
const auto * logits = llama_get_logits(ctx);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx, &candidates_p);
|
||||
llama_sampling_set_logits(smpl, logits);
|
||||
|
||||
auto next_token = llama_sampling_sample_dist(smpl, nullptr);
|
||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
@ -96,6 +96,8 @@ int main(int argc, char ** argv) {
|
||||
// make new context
|
||||
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampling * smpl2 = llama_sampling_init(model, sparams);
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
@ -124,15 +126,11 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx2);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx2, &candidates_p);
|
||||
const auto * logits = llama_get_logits(ctx2);
|
||||
|
||||
llama_sampling_set_logits(smpl2, logits);
|
||||
|
||||
auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
|
||||
auto next_token_str = llama_token_to_piece(ctx2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
@ -157,7 +155,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// make new context
|
||||
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampling * smpl3 = llama_sampling_init(model, sparams);
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
@ -215,15 +215,11 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// third run with seq 1 instead of 0
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx3);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx3, &candidates_p);
|
||||
const auto * logits = llama_get_logits(ctx3);
|
||||
|
||||
llama_sampling_set_logits(smpl3, logits);
|
||||
|
||||
auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
|
||||
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
@ -240,6 +236,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
printf("\n");
|
||||
|
||||
llama_sampling_free(smpl);
|
||||
llama_sampling_free(smpl2);
|
||||
llama_sampling_free(smpl3);
|
||||
|
||||
llama_free(ctx3);
|
||||
llama_free_model(model);
|
||||
|
||||
|
Reference in New Issue
Block a user