mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 04:15:21 +00:00
llama : add llama_sampling API + move grammar in libllama
ggml-ci
This commit is contained in:
@ -2,7 +2,6 @@
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -65,6 +64,15 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
auto sparams = llama_sampling_default_params();
|
||||
|
||||
sparams.seed = params.sparams.seed;
|
||||
sparams.top_k = 40;
|
||||
sparams.top_p = 0.9f;
|
||||
sparams.temp = 0.4f;
|
||||
|
||||
llama_sampling * smpl = llama_sampling_init(model, sparams);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
@ -164,29 +172,17 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
const auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
llama_sampling_set_logits(smpl, logits);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
llama_sampling_top_k(smpl, nullptr);
|
||||
llama_sampling_top_p(smpl, nullptr);
|
||||
llama_sampling_temp (smpl, nullptr);
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr);
|
||||
|
||||
const int top_k = 40;
|
||||
const float top_p = 0.9f;
|
||||
const float temp = 0.4f;
|
||||
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temp (ctx, &candidates_p, temp);
|
||||
|
||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||
|
||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
@ -244,12 +240,13 @@ int main(int argc, char ** argv) {
|
||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, smpl);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_sampling_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
|
Reference in New Issue
Block a user