llama : add llama_sampling API + move grammar in libllama

ggml-ci
This commit is contained in:
Georgi Gerganov
2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
48 changed files with 2481 additions and 2590 deletions

View File

@ -3,13 +3,11 @@
#include "common.h"
#include "ngram-cache.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include <unordered_map>
int main(int argc, char ** argv){
gpt_params params;
@ -106,7 +104,7 @@ int main(int argc, char ** argv){
bool has_eos = false;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
std::vector<llama_token> draft;
@ -130,9 +128,9 @@ int main(int argc, char ** argv){
int i_dft = 0;
while (true) {
// sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
llama_token id = llama_sampling_sample(smpl, ctx, i_dft);
llama_sampling_accept(ctx_sampling, ctx, id, true);
llama_sampling_accept(smpl, id, true);
const std::string token_str = llama_token_to_piece(ctx, id);
@ -241,9 +239,9 @@ int main(int argc, char ** argv){
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx);
llama_print_timings(ctx, smpl);
llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
llama_batch_free(batch_tgt);
llama_free(ctx);