llama : add llama_sampling API + move grammar in libllama

ggml-ci
Author: Georgi Gerganov
Date:   2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee

48 changed files with 2481 additions and 2590 deletions

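The hunks below migrate the parallel example from the old llama_sampling_context to the new llama_sampling object. Stitched together only from the calls that appear in this diff (model, ctx and params come from the example's existing setup), the per-client lifecycle is roughly:

    // sketch of the new per-client sampling lifecycle, assembled from the calls in this diff
    struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);

    // ... after llama_decode() has produced logits, for each generated token:
    // (i_batch here stands for the token's index in the decoded batch, client.i_batch - i in the example)
    const llama_token id = llama_sampling_sample(smpl, ctx, i_batch);
    llama_sampling_accept(smpl, id, true);

    // when the client starts a new request, reuse the same sampler
    llama_sampling_reset(smpl);

    // teardown
    llama_sampling_free(smpl);
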
@@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {
 struct client {
     ~client() {
-        if (ctx_sampling) {
-            llama_sampling_free(ctx_sampling);
+        if (smpl) {
+            llama_sampling_free(smpl);
         }
     }
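
The example keeps the raw pointer and frees it manually in the destructor. Purely as a sketch of an alternative (not what this commit does), the same cleanup could be tied to the struct's lifetime with an RAII wrapper around the new handle:

    // sketch only: run llama_sampling_free automatically when the owner goes away
    #include <memory>

    struct llama_sampling_deleter {
        void operator()(struct llama_sampling * s) const { llama_sampling_free(s); }
    };

    using llama_sampling_ptr = std::unique_ptr<struct llama_sampling, llama_sampling_deleter>;

    // hypothetical member:  llama_sampling_ptr smpl;
    // initialized with:     smpl.reset(llama_sampling_init(model, params.sparams));
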
@@ -72,7 +72,7 @@ struct client {
     std::string prompt;
     std::string response;

-    struct llama_sampling_context * ctx_sampling = nullptr;
+    struct llama_sampling * smpl = nullptr;
 };

 static void print_date_time() {
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.smpl = llama_sampling_init(model, params.sparams);
     }

     std::vector<llama_token> tokens_system;
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
                     client.prompt   = client.input + "\nAssistant:";
                     client.response = "";

-                    llama_sampling_reset(client.ctx_sampling);
+                    llama_sampling_reset(client.smpl);

                     // do not prepend BOS because we have a system prompt!
                     std::vector<llama_token> tokens_prompt;
@@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
                 //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                 //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

-                const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+                const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);

-                llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+                llama_sampling_accept(client.smpl, id, true);

                 if (client.n_decoded == 1) {
                     // start measuring generation time after the first token to make sure all concurrent clients
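
Two signature changes land in this hunk: llama_sampling_sample drops the NULL placeholder argument of the old call, and llama_sampling_accept no longer takes the llama_context at all. A minimal per-token step with the new calls, assuming llama_decode() has already produced logits for this client's batch index, would look roughly like:

    // one generation step under the new API (model, ctx and client come from the surrounding example)
    const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);

    llama_sampling_accept(client.smpl, id, true); // the flag mirrors the old "apply grammar" parameter

    if (llama_token_is_eog(model, id)) {
        // this client's response is finished
    }
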
@@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
                     }

                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+                    llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
                     llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);

                     const auto t_main_end = ggml_time_us();
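
The two kv-cache calls above implement the comment that precedes them: when a client finishes, only its generated tokens are dropped, while the shared system prompt stays cached in sequence 0 and is copied back onto the client's sequence. Restated with the argument meanings spelled out (negative position bounds select the whole range):

    // recycle a finished client: each client uses sequence client.id + 1, sequence 0 holds the system prompt
    llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);    // remove all of the client's tokens
    llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); // re-attach the cached system prompt
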
@@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
     LOG_TEE("\n");

-    llama_print_timings(ctx);
+    // TODO: print sampling/grammar timings for all clients
+    llama_print_timings(ctx, nullptr);

     llama_batch_free(batch);
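
llama_print_timings now takes a second argument, passed as nullptr here because each client owns its own sampler (hence the TODO). Assuming that parameter is the llama_sampling handle whose timings get folded into the report, a single-sampler program would presumably pass it directly:

    // hypothetical single-sampler call; the meaning of the second parameter is
    // inferred from the nullptr and the TODO above, not spelled out in this diff
    llama_print_timings(ctx, smpl);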