llama : add llama_sampling API + move grammar in libllama

ggml-ci
This commit is contained in:
Georgi Gerganov
2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
48 changed files with 2481 additions and 2590 deletions

View File

@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var sampling: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
@ -42,9 +43,11 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
self.sampling = llama_sampling_init(context, llama_sampling_default_params())
}
deinit {
llama_sampling_free(sampling)
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
@ -69,7 +72,6 @@ actor LlamaContext {
print("Using \(n_threads) threads")
var ctx_params = llama_context_default_params()
ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = Int32(n_threads)
ctx_params.n_threads_batch = Int32(n_threads)
@ -147,17 +149,9 @@ actor LlamaContext {
let n_vocab = llama_n_vocab(model)
let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
var candidates = Array<llama_token_data>()
candidates.reserveCapacity(Int(n_vocab))
llama_sampling_set_logits(sampling, logits);
for token_id in 0..<n_vocab {
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
}
candidates.withUnsafeMutableBufferPointer() { buffer in
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
new_token_id = llama_sample_token_greedy(context, &candidates_p)
}
new_token_id = llama_sampling_sample_greedy(sampling, nil)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")