swift : adapt to new API

Author: Georgi Gerganov
Date:   2025-03-19 10:48:42 +02:00
commit 96ca6e8d23 (parent b0db7fc2c6)

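The diff below adapts the llama.swiftui example from the struct-based `llama_batch` API to the new opaque `llama_batch_ext` handle. As a reading aid, here is the complete new-API call sequence assembled from the call sites in this commit; the C signatures themselves are not shown in the diff, so the parameter roles in the comments are inferred:

    // Sketch of the llama_batch_ext lifecycle, inferred from the call sites below.
    let batch = llama_batch_ext_init(512, 1)             // token capacity, max sequences (inferred)
    llama_batch_ext_clear(batch)                         // reset before refilling
    llama_batch_ext_add_text(batch, token, pos, [llama_seq_id(0)], 1, false)
    llama_batch_ext_set_output_last(batch)               // request logits for the last queued token
    if llama_decode_ext(context, batch) != 0 { /* handle failure */ }
    let n = llama_batch_ext_get_n_tokens(batch)          // replaces reading batch.n_tokens
    llama_batch_ext_free(batch)                          // pairs with llama_batch_ext_init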

@@ -5,35 +5,19 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
-func llama_batch_clear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-}
-
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
-    batch.token   [Int(batch.n_tokens)] = id
-    batch.pos     [Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
-
-    batch.n_tokens += 1
-}
-
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var vocab: OpaquePointer
     private var sampling: UnsafeMutablePointer<llama_sampler>
-    private var batch: llama_batch
+    private var batch: OpaquePointer
     private var tokens_list: [llama_token]
     var is_done: Bool = false
 
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 1024
+    var n_len: Int32 = 128
     var n_cur: Int32 = 0
 
     var n_decode: Int32 = 0
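The hand-rolled `llama_batch_clear`/`llama_batch_add` helpers are deleted rather than ported: with `batch` now an `OpaquePointer`, Swift can no longer reach into the struct's `token`/`pos`/`seq_id`/`logits` arrays, so every mutation has to go through a C entry point. A call-site-compatible shim, were one still wanted, would collapse to a single call (hypothetical, not part of the commit):

    // Hypothetical single-sequence shim mirroring the removed helper; the
    // literal 1 is the sequence-id count, as at every call site in this diff.
    func batch_add_single(_ batch: OpaquePointer, _ id: llama_token, _ pos: llama_pos, _ logits: Bool) {
        llama_batch_ext_add_text(batch, id, pos, [llama_seq_id(0)], 1, logits)
    }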
@@ -42,7 +26,7 @@ actor LlamaContext {
         self.model = model
         self.context = context
         self.tokens_list = []
-        self.batch = llama_batch_init(512, 0, 1)
+        self.batch = llama_batch_ext_init(512, 1)
         self.temporary_invalid_cchars = []
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
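The initializer also loses its middle argument: `llama_batch_init(512, 0, 1)` passed an embedding size of 0 between the token capacity and the sequence maximum, while `llama_batch_ext_init(512, 1)` appears to take only the latter two (roles inferred from the old call; the ext header is not part of this diff):

    // Assumed parameter roles, based on the old (n_tokens, embd, n_seq_max) call:
    self.batch = llama_batch_ext_init(512, 1)   // 512 token slots, 1 sequence
    // ... balanced in deinit by the teardown seen in the next hunk:
    // llama_batch_ext_free(batch)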
@@ -53,7 +37,7 @@ actor LlamaContext {
 
     deinit {
         llama_sampler_free(sampling)
-        llama_batch_free(batch)
+        llama_batch_ext_free(batch)
         llama_model_free(model)
         llama_free(context)
         llama_backend_free()
@@ -111,7 +95,7 @@ actor LlamaContext {
     }
 
     func get_n_tokens() -> Int32 {
-        return batch.n_tokens;
+        return llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_init(text: String) {
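Because the handle is opaque, even the plain `n_tokens` read becomes a function call:

    // Field reads on the old value type become accessor calls on the handle:
    let n_tokens = llama_batch_ext_get_n_tokens(batch)   // was: batch.n_tokens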
@@ -133,25 +117,25 @@ actor LlamaContext {
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        llama_batch_clear(&batch)
+        llama_batch_ext_clear(batch)
 
         for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+            llama_batch_ext_add_text(batch, tokens_list[i], Int32(i), [llama_seq_id(0)], 1, false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        llama_batch_ext_set_output_last(batch)
 
-        if llama_decode(context, batch) != 0 {
-            print("llama_decode() failed")
+        if llama_decode_ext(context, batch) != 0 {
+            print("llama_decode_ext() failed")
         }
 
-        n_cur = batch.n_tokens
+        n_cur = llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_loop() -> String {
         var new_token_id: llama_token = 0
 
-        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+        new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)
 
         if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
             print("\n")
@@ -178,13 +162,13 @@ actor LlamaContext {
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        llama_batch_clear(&batch)
-        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
+        llama_batch_ext_clear(batch)
+        llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
 
         n_decode += 1
         n_cur    += 1
 
-        if llama_decode(context, batch) != 0 {
+        if llama_decode_ext(context, batch) != 0 {
             print("failed to evaluate llama!")
         }
 
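The per-token generation step keeps the same shape, but since the single appended token requests logits directly (the trailing `true`), no `set_output_last` call is needed:

    // Resolved "after" state of the decode step in completion_loop:
    llama_batch_ext_clear(batch)
    llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
    if llama_decode_ext(context, batch) != 0 {
        print("failed to evaluate llama!")
    }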
@@ -201,21 +185,21 @@ actor LlamaContext {
         for _ in 0..<nr {
             // bench prompt processing
 
-            llama_batch_clear(&batch)
+            llama_batch_ext_clear(batch)
 
             let n_tokens = pp
 
             for i in 0..<n_tokens {
-                llama_batch_add(&batch, 0, Int32(i), [0], false)
+                llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
             }
-            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+            llama_batch_ext_set_output_last(batch)
 
             llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            if llama_decode(context, batch) != 0 {
-                print("llama_decode() failed during prompt")
+            if llama_decode_ext(context, batch) != 0 {
+                print("llama_decode_ext() failed during prompt")
             }
 
             llama_synchronize(context)
@@ -228,14 +212,14 @@ actor LlamaContext {
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
             for i in 0..<tg {
-                llama_batch_clear(&batch)
+                llama_batch_ext_clear(batch)
 
                 for j in 0..<pl {
-                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                    llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
                 }
 
-                if llama_decode(context, batch) != 0 {
-                    print("llama_decode() failed during text generation")
+                if llama_decode_ext(context, batch) != 0 {
+                    print("llama_decode_ext() failed during text generation")
                 }
                 llama_synchronize(context)
             }
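The text-generation benchmark is the one place that exercises multiple sequences: each step queues one token per parallel sequence `j`, all flagged for logits, so the sequence-id array is built from `j` rather than the constant 0:

    // Resolved "after" state of the tg benchmark inner loop:
    for j in 0..<pl {
        llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
    }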