llama : add llama_sampling API + move grammar in libllama

ggml-ci
This commit is contained in:
Georgi Gerganov
2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
48 changed files with 2481 additions and 2590 deletions

View File

@ -3,12 +3,12 @@
#include <vector>
#include <cstdio>
#include <chrono>
int main(int argc, char ** argv) {
gpt_params params;
params.prompt = "The quick brown fox";
params.sparams.seed = 1234;
if (!gpt_params_parse(argc, argv, params)) {
gpt_params_print_usage(argc, argv, params);
@ -38,6 +38,11 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampling_params sparams = llama_sampling_default_params();
sparams.seed = params.sparams.seed;
llama_sampling * smpl = llama_sampling_init(model, sparams);
// tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true);
@ -64,16 +69,11 @@ int main(int argc, char ** argv) {
printf("\nfirst run: %s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(model);
const auto * logits = llama_get_logits(ctx);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx, &candidates_p);
llama_sampling_set_logits(smpl, logits);
auto next_token = llama_sampling_sample_dist(smpl, nullptr);
auto next_token_str = llama_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str());
@ -96,6 +96,8 @@ int main(int argc, char ** argv) {
// make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl2 = llama_sampling_init(model, sparams);
printf("\nsecond run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
@ -124,15 +126,11 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx2);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx2, &candidates_p);
const auto * logits = llama_get_logits(ctx2);
llama_sampling_set_logits(smpl2, logits);
auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
auto next_token_str = llama_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str());
@ -157,7 +155,9 @@ int main(int argc, char ** argv) {
}
// make new context
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl3 = llama_sampling_init(model, sparams);
printf("\nsingle seq run: %s", params.prompt.c_str());
@ -215,15 +215,11 @@ int main(int argc, char ** argv) {
// third run with seq 1 instead of 0
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx3);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx3, &candidates_p);
const auto * logits = llama_get_logits(ctx3);
llama_sampling_set_logits(smpl3, logits);
auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
auto next_token_str = llama_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str());
@ -240,6 +236,10 @@ int main(int argc, char ** argv) {
printf("\n");
llama_sampling_free(smpl);
llama_sampling_free(smpl2);
llama_sampling_free(smpl3);
llama_free(ctx3);
llama_free_model(model);