diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 80f1592e9..d55625da1 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -503,3 +503,49 @@ void llama_batch_ext_free(struct llama_batch_ext * batch) { } delete batch; } + +// deprecated +struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens_alloc] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +// deprecated +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +}