mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
swift : adapt to new API
This commit is contained in:
@ -5,35 +5,19 @@ enum LlamaError: Error {
|
||||
case couldNotInitializeContext
|
||||
}
|
||||
|
||||
func llama_batch_clear(_ batch: inout llama_batch) {
|
||||
batch.n_tokens = 0
|
||||
}
|
||||
|
||||
func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
|
||||
batch.token [Int(batch.n_tokens)] = id
|
||||
batch.pos [Int(batch.n_tokens)] = pos
|
||||
batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
|
||||
for i in 0..<seq_ids.count {
|
||||
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
|
||||
}
|
||||
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
|
||||
|
||||
batch.n_tokens += 1
|
||||
}
|
||||
|
||||
actor LlamaContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var vocab: OpaquePointer
|
||||
private var sampling: UnsafeMutablePointer<llama_sampler>
|
||||
private var batch: llama_batch
|
||||
private var batch: OpaquePointer
|
||||
private var tokens_list: [llama_token]
|
||||
var is_done: Bool = false
|
||||
|
||||
/// This variable is used to store temporarily invalid cchars
|
||||
private var temporary_invalid_cchars: [CChar]
|
||||
|
||||
var n_len: Int32 = 1024
|
||||
var n_len: Int32 = 128
|
||||
var n_cur: Int32 = 0
|
||||
|
||||
var n_decode: Int32 = 0
|
||||
@ -42,7 +26,7 @@ actor LlamaContext {
|
||||
self.model = model
|
||||
self.context = context
|
||||
self.tokens_list = []
|
||||
self.batch = llama_batch_init(512, 0, 1)
|
||||
self.batch = llama_batch_ext_init(512, 1)
|
||||
self.temporary_invalid_cchars = []
|
||||
let sparams = llama_sampler_chain_default_params()
|
||||
self.sampling = llama_sampler_chain_init(sparams)
|
||||
@ -53,7 +37,7 @@ actor LlamaContext {
|
||||
|
||||
deinit {
|
||||
llama_sampler_free(sampling)
|
||||
llama_batch_free(batch)
|
||||
llama_batch_ext_free(batch)
|
||||
llama_model_free(model)
|
||||
llama_free(context)
|
||||
llama_backend_free()
|
||||
@ -111,7 +95,7 @@ actor LlamaContext {
|
||||
}
|
||||
|
||||
func get_n_tokens() -> Int32 {
|
||||
return batch.n_tokens;
|
||||
return llama_batch_ext_get_n_tokens(batch)
|
||||
}
|
||||
|
||||
func completion_init(text: String) {
|
||||
@ -133,25 +117,25 @@ actor LlamaContext {
|
||||
print(String(cString: token_to_piece(token: id) + [0]))
|
||||
}
|
||||
|
||||
llama_batch_clear(&batch)
|
||||
llama_batch_ext_clear(batch)
|
||||
|
||||
for i1 in 0..<tokens_list.count {
|
||||
let i = Int(i1)
|
||||
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
|
||||
llama_batch_ext_add_text(batch, tokens_list[i], Int32(i), [llama_seq_id(0)], 1, false)
|
||||
}
|
||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
||||
llama_batch_ext_set_output_last(batch)
|
||||
|
||||
if llama_decode(context, batch) != 0 {
|
||||
print("llama_decode() failed")
|
||||
if llama_decode_ext(context, batch) != 0 {
|
||||
print("llama_decode_ext() failed")
|
||||
}
|
||||
|
||||
n_cur = batch.n_tokens
|
||||
n_cur = llama_batch_ext_get_n_tokens(batch)
|
||||
}
|
||||
|
||||
func completion_loop() -> String {
|
||||
var new_token_id: llama_token = 0
|
||||
|
||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||
new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)
|
||||
|
||||
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||
print("\n")
|
||||
@ -178,13 +162,13 @@ actor LlamaContext {
|
||||
print(new_token_str)
|
||||
// tokens_list.append(new_token_id)
|
||||
|
||||
llama_batch_clear(&batch)
|
||||
llama_batch_add(&batch, new_token_id, n_cur, [0], true)
|
||||
llama_batch_ext_clear(batch)
|
||||
llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
|
||||
|
||||
n_decode += 1
|
||||
n_cur += 1
|
||||
|
||||
if llama_decode(context, batch) != 0 {
|
||||
if llama_decode_ext(context, batch) != 0 {
|
||||
print("failed to evaluate llama!")
|
||||
}
|
||||
|
||||
@ -201,21 +185,21 @@ actor LlamaContext {
|
||||
for _ in 0..<nr {
|
||||
// bench prompt processing
|
||||
|
||||
llama_batch_clear(&batch)
|
||||
llama_batch_ext_clear(batch)
|
||||
|
||||
let n_tokens = pp
|
||||
|
||||
for i in 0..<n_tokens {
|
||||
llama_batch_add(&batch, 0, Int32(i), [0], false)
|
||||
llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
|
||||
}
|
||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
||||
llama_batch_ext_set_output_last(batch)
|
||||
|
||||
llama_kv_self_clear(context)
|
||||
|
||||
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
|
||||
|
||||
if llama_decode(context, batch) != 0 {
|
||||
print("llama_decode() failed during prompt")
|
||||
if llama_decode_ext(context, batch) != 0 {
|
||||
print("llama_decode_ext() failed during prompt")
|
||||
}
|
||||
llama_synchronize(context)
|
||||
|
||||
@ -228,14 +212,14 @@ actor LlamaContext {
|
||||
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
|
||||
|
||||
for i in 0..<tg {
|
||||
llama_batch_clear(&batch)
|
||||
llama_batch_ext_clear(batch)
|
||||
|
||||
for j in 0..<pl {
|
||||
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
|
||||
llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
|
||||
}
|
||||
|
||||
if llama_decode(context, batch) != 0 {
|
||||
print("llama_decode() failed during text generation")
|
||||
if llama_decode_ext(context, batch) != 0 {
|
||||
print("llama_decode_ext() failed during text generation")
|
||||
}
|
||||
llama_synchronize(context)
|
||||
}
|
||||
|
Reference in New Issue
Block a user