llama : add llama_sampling API + move grammar in libllama

ggml-ci
Georgi Gerganov
2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
48 changed files with 2481 additions and 2590 deletions
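
The hunks below show how the Swift batched example migrates to the new API: the RNG seed leaves llama_context_params, a dedicated llama_sampling context is created and freed alongside the model and context, the per-token sampling loop drops its hand-built candidates array in favor of the llama_sampling_* calls, and llama_print_timings gains the sampling context as a second argument.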


@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
     print("Failed to load model")
     exit(1)
 }
 defer {
     llama_free_model(model)
 }
@@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
 let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
 var context_params = llama_context_default_params()
-context_params.seed = 1234
 context_params.n_ctx = n_kv_req
 context_params.n_batch = UInt32(max(n_len, n_parallel))
 context_params.n_threads = 8
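
Note: the removal of context_params.seed is part of the same move; seeding is expected to be handled by the sampling layer introduced in the next hunk (this example simply relies on the defaults).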
@@ -48,11 +46,24 @@ guard context != nil else {
     print("Failed to initialize context")
     exit(1)
 }
 defer {
     llama_free(context)
 }
+var sparams = llama_sampling_params()
+sparams.top_k = 40
+sparams.top_p = 0.9
+sparams.temp = 0.4
+
+let smpl = llama_sampling_init(model, sparams)
+guard smpl != nil else {
+    print("Failed to initialize sampling")
+    exit(1)
+}
+defer {
+    llama_sampling_free(smpl)
+}
 let n_ctx = llama_n_ctx(context)
 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
@@ -125,32 +136,17 @@ while n_cur <= n_len {
         continue
     }
-    var n_vocab = llama_n_vocab(model)
     var logits = llama_get_logits_ith(context, i_batch[i])
-    var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
+    llama_sampling_set_logits(smpl, logits)
-    for token_id in 0 ..< n_vocab {
-        candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-    }
+    llama_sampling_top_k(smpl, nil)
+    llama_sampling_top_p(smpl, nil)
+    llama_sampling_temp (smpl, nil)
-    var candidates_p: llama_token_data_array = .init(
-        data: &candidates,
-        size: candidates.count,
-        sorted: false
-    )
+    let new_token_id = llama_sampling_sample_dist(smpl, nil)
-    let top_k: Int32 = 40
-    let top_p: Float = 0.9
-    let temp: Float = 0.4
-    llama_sample_top_k(context, &candidates_p, top_k, 1)
-    llama_sample_top_p(context, &candidates_p, top_p, 1)
-    llama_sample_temp(context, &candidates_p, temp)
-    let new_token_id = llama_sample_token(context, &candidates_p)
-    // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+    // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
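
The removed lines allocated and filled an n_vocab-sized candidates array on every step and threaded it through the context-level llama_sample_* calls; the new sampling context keeps that buffer internally, so a step reduces to set logits, constrain, sample. A minimal per-step helper assembled from the calls above (the function itself is illustrative, and the nil argument is read here as "use the sampler's internal candidates with the thresholds from sparams"):

// One decode-step sample with the new API, assembled from the hunk above.
// Helper name and shape are illustrative, not part of the commit.
func sampleNextToken(_ smpl: OpaquePointer?, _ context: OpaquePointer?, _ iBatch: Int32) -> llama_token {
    // hand this batch entry's logits to the sampling context
    let logits = llama_get_logits_ith(context, iBatch)
    llama_sampling_set_logits(smpl, logits)

    // apply the constraints configured in llama_sampling_params
    // (top_k = 40, top_p = 0.9, temp = 0.4 in this example)
    llama_sampling_top_k(smpl, nil)
    llama_sampling_top_p(smpl, nil)
    llama_sampling_temp (smpl, nil)

    // draw from the constrained distribution;
    // llama_sampling_sample_greedy(smpl, nil) would take the argmax instead
    return llama_sampling_sample_dist(smpl, nil)
}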
@@ -212,7 +208,7 @@ let t_main_end = ggml_time_us()
 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
-llama_print_timings(context)
+llama_print_timings(context, smpl)
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
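
Finally, llama_print_timings now takes the sampling context as well, so the sampler's timing can be reported together with the context's numbers; the single-argument form is gone from this example.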