mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-30 12:55:17 +00:00
llama : add llama_sampling API + move grammar in libllama
ggml-ci
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
|
||||
|
||||
// verification n-grams
|
||||
std::vector<ngram_data> ngrams_cur(G);
|
||||
@ -159,9 +158,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// sample first token
|
||||
{
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
||||
id = llama_sampling_sample(smpl, ctx, 0);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(smpl, id, true);
|
||||
|
||||
{
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// sample the next token
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
||||
id = llama_sampling_sample(smpl, ctx, i_batch);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
llama_sampling_accept(smpl, id, true);
|
||||
|
||||
// print
|
||||
{
|
||||
@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
|
||||
if (v == 0) {
|
||||
// sample from the last level
|
||||
for (int i = 0; i < W; i++) {
|
||||
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||
tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < W; i++) {
|
||||
@ -468,10 +467,10 @@ int main(int argc, char ** argv) {
|
||||
LOG_TEE("n_predict = %d\n", n_predict);
|
||||
LOG_TEE("n_accept = %d\n", n_accept);
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_print_timings(ctx, smpl);
|
||||
|
||||
llama_kv_cache_view_free(&kvc_view);
|
||||
llama_sampling_free(ctx_sampling);
|
||||
llama_sampling_free(smpl);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
Reference in New Issue
Block a user