llama : add llama_sampling API + move grammar in libllama

ggml-ci
This commit is contained in:
Georgi Gerganov
2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
48 changed files with 2481 additions and 2590 deletions

View File

@ -1,7 +1,6 @@
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
// verification n-grams
std::vector<ngram_data> ngrams_cur(G);
@ -159,9 +158,9 @@ int main(int argc, char ** argv) {
// sample first token
{
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
id = llama_sampling_sample(smpl, ctx, 0);
llama_sampling_accept(ctx_sampling, ctx, id, true);
llama_sampling_accept(smpl, id, true);
{
const std::string token_str = llama_token_to_piece(ctx, id);
@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
}
// sample the next token
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
id = llama_sampling_sample(smpl, ctx, i_batch);
llama_sampling_accept(ctx_sampling, ctx, id, true);
llama_sampling_accept(smpl, id, true);
// print
{
@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
if (v == 0) {
// sample from the last level
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
}
} else {
for (int i = 0; i < W; i++) {
@ -468,10 +467,10 @@ int main(int argc, char ** argv) {
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_accept = %d\n", n_accept);
llama_print_timings(ctx);
llama_print_timings(ctx, smpl);
llama_kv_cache_view_free(&kvc_view);
llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
llama_batch_free(batch);