#include "llama-batch.h" #include #include llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { // clear empty sequences // the previous ubatch is assumed to be gone, // so nothing should refer to values in these sequences anymore. for (size_t i = seq.size(); i-- > 0;) { if (seq[i].length == 0) { seq.pop_back(); } else { break; } } ubatch_token.resize(!has_embd ? n_ubatch : 0); ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); ubatch_pos.resize(n_ubatch); ubatch_n_seq_id.resize(n_ubatch); ubatch_seq_id.resize(n_ubatch); ubatch_output.resize(n_ubatch); llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, /*token =*/ !has_embd ? ubatch_token.data() : nullptr, /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, /*pos =*/ ubatch_pos.data(), /*n_seq_id =*/ ubatch_n_seq_id.data(), /*seq_id =*/ ubatch_seq_id.data(), /*output =*/ ubatch_output.data(), }; return ubatch; } void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { GGML_ASSERT(batch != nullptr); GGML_ASSERT(length <= seq.length); // Can only add sequences of equal lengths to a batch, // otherwise it isn't clear to which sequence a token belongs GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); // NOTE: loops are separated for cache-friendliness if (batch->token) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; } } else { // simple split ubatch.token = batch->token + seq.offset; } } else { ubatch.token = nullptr; } if (batch->embd) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { memcpy( ubatch.embd + (n_embd * (ubatch.n_tokens + i)), batch->embd + (n_embd * ids[seq.offset + i]), n_embd * sizeof(float) ); } } else { // simple split ubatch.embd = batch->embd + (n_embd * seq.offset); } } else { ubatch.embd = nullptr; } if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; } } else { // simple split ubatch.pos = batch->pos + seq.offset; } if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; } } else { // simple split if (batch->n_seq_id) { ubatch.n_seq_id = batch->n_seq_id + seq.offset; } else { for (size_t i = 0; i < length; ++i) { ubatch.n_seq_id[ubatch.n_seqs + i] = 1; } } if (batch->seq_id) { ubatch.seq_id = batch->seq_id + seq.offset; } } if (logits_all) { for (size_t i = 0; i < length; ++i) { ubatch.output[ubatch.n_tokens + i] = 1; out_ids.push_back(ids[seq.offset + i]); } } else if (batch->logits) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; int8_t is_output = batch->logits[id]; ubatch.output[ubatch.n_tokens + i] = is_output; if (is_output) { out_ids.push_back(id); } } } else { // simple split ubatch.output = batch->logits + seq.offset; for (size_t i = 0; i < length; ++i) { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } } } } else { // only get last output for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; int8_t is_last = id == ids.size() - 1; ubatch.output[ubatch.n_tokens + i] = is_last; if (is_last) { out_ids.push_back(id); } } } if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { ubatch.n_seq_tokens = ubatch.equal_seqs ? 
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[0];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        size_t length = 0;
        size_t n_tokens_in_ubatch = 0;
        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
        // smallest first, because it's easier to split this way;
        // starting from the end to pop in constant time.
        for (size_t i = seq.size(); i-- > 0;) {
            llama_sbatch_seq & s = seq[i];
            GGML_ASSERT(s.length > 0);
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            add_seq_to_ubatch(ubatch, s, length);
            n_tokens_in_ubatch += length;
            // shared prompts can't be mixed with any of their sequences,
            // so it's safer to compute them in their own ubatch
            if (s.n_seq_id > 1) {
                break;
            }
            // stop when there isn't enough space for another sequence
            if (length + n_tokens_in_ubatch > n_ubatch) {
                break;
            }
        }
    }
    return ubatch;
}

llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[seq.size() - 1];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}
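// Usage sketch (illustrative, not part of the original file): the three split
// strategies above are all driven the same way, by draining the sbatch until
// no tokens remain. Assuming an already-populated llama_sbatch `sbatch`:
//
//   while (sbatch.n_tokens > 0) {
//       llama_ubatch ubatch = sbatch.split_equal(n_ubatch);
//       // ... run the model on ubatch ...
//   }
//
// split_simple yields contiguous views into the original batch, while
// split_equal and split_seq copy tokens so that every ubatch holds sequences
// of equal length.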
void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch = &batch;
    this->n_embd = n_embd;
    this->logits_all = logits_all;
    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq
    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }
    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id = nullptr;
        s.offset = 0;
        s.length = n_tokens;
        return;
    }
    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // sort by seq_id, then by pos
            if (n_seq_a == n_seq_b) {
                if (batch.seq_id) {
                    for (int32_t i = 0; i < n_seq_a; ++i) {
                        llama_seq_id seq_id_a = batch.seq_id[a][i];
                        llama_seq_id seq_id_b = batch.seq_id[b][i];
                        // smaller seq_ids go first
                        if (seq_id_a != seq_id_b) {
                            return seq_id_a < seq_id_b;
                        }
                    }
                }
                // when all else is equal, sort by pos
                if (batch.pos) {
                    return batch.pos[a] < batch.pos[b];
                }
                // no pos, sort by id
                return a < b;
            }
            // shared prompts go first
            return n_seq_a > n_seq_b;
        }
    );
    // init seq
    llama_sbatch_seq * last_seq = nullptr;
    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        last_seq = &seq.back();
    }
    // keep shared prompts first at the end, then sort by length descending.
    std::sort(seq.begin(), seq.end(),
        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
            if (a.n_seq_id == b.n_seq_id) {
                return a.length > b.length;
            }
            return a.n_seq_id < b.n_seq_id;
        }
    );
}

llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
    batch = new llama_batch_ext{
        /*n_tokens   =*/ in_batch.n_tokens,
        /*max_tokens =*/ in_batch.n_tokens,
        /*is_view    =*/ false,
        /*tokens     =*/ in_batch.token,
        /*embd       =*/ in_batch.embd,
        /*pos        =*/ in_batch.pos,
        /*n_seq_id   =*/ in_batch.n_seq_id,
        /*seq_id     =*/ in_batch.seq_id,
        /*logits     =*/ in_batch.logits,
    };
    GGML_ASSERT(batch->n_tokens > 0);
    if (!in_batch.pos) {
        pos.resize(batch->n_tokens);
        for (int32_t i = 0; i < batch->n_tokens; i++) {
            pos[i] = i + p0;
        }
        batch->pos = pos.data();
    }
    if (!batch->n_seq_id) {
        n_seq_id.resize(batch->n_tokens);
        for (int32_t i = 0; i < batch->n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch->n_seq_id = n_seq_id.data();
    }
    if (!batch->seq_id) {
        seq_id.resize(batch->n_tokens + 1);
        seq_id[batch->n_tokens] = nullptr;
        for (int32_t i = 0; i < batch->n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch->seq_id = seq_id.data();
    }
    if (!batch->logits) {
        // by default, request output only for the last token
        logits.resize(batch->n_tokens);
        logits[logits.size() - 1] = true;
        batch->logits = logits.data();
    }
}

llama_batch_allocr::~llama_batch_allocr() {
    delete batch;
}

//
// interface implementation
//

struct llama_batch llama_batch_get_one(
             llama_token * tokens,
                 int32_t   n_tokens) {
    return llama_batch{
        /*n_tokens =*/ n_tokens,
        /*tokens   =*/ tokens,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
    };
}

struct llama_batch_ext * llama_batch_ext_init_from_text(
             llama_token * tokens,
                 int32_t   n_tokens,
                 int32_t   pos0,
                 int32_t   seq_id) {
    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
    for (int32_t i = 0; i < n_tokens; i++) {
        llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
    }
    return batch;
}
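// Usage sketch (illustrative, not part of the original file): building a
// single-sequence text batch and requesting logits for its last token.
// The token values are placeholders.
//
//   llama_token tokens[3] = { 10, 20, 30 };
//   llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens, 3, /* pos0 */ 0, /* seq_id */ 0);
//   llama_batch_ext_set_logits_last(batch);
//   // ... decode ...
//   llama_batch_ext_free(batch);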
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch_ext * batch = new llama_batch_ext{
        /*n_tokens   =*/ 0,
        /*max_tokens =*/ n_tokens_alloc,
        /*is_view    =*/ false,
        /*tokens     =*/ nullptr,
        /*embd       =*/ nullptr,
        /*pos        =*/ nullptr,
        /*n_seq_id   =*/ nullptr,
        /*seq_id     =*/ nullptr,
        /*logits     =*/ nullptr,
    };
    if (embd) {
        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }
    batch->pos      = (llama_pos *)      malloc(sizeof(llama_pos)      * n_tokens_alloc);
    batch->n_seq_id = (int32_t *)        malloc(sizeof(int32_t)        * n_tokens_alloc);
    batch->seq_id   = (llama_seq_id **)  malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    // nullptr sentinel marks the end of the seq_id array for llama_batch_ext_free
    batch->seq_id[n_tokens_alloc] = nullptr;
    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
    return batch;
}

struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
    return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
}

struct llama_batch_ext * llama_batch_ext_init_from_embd(
                   float * embd,
                  size_t   n_embd,
                 int32_t   pos0,
                 int32_t   seq_id) {
    // the input is a single embedding vector of size n_embd,
    // stored as one token in the batch
    struct llama_batch_ext * batch = llama_batch_ext_init_impl(1, n_embd, 1);
    memcpy(batch->embd, embd, n_embd * sizeof(float));
    batch->pos     [0]    = pos0;
    batch->n_seq_id[0]    = 1;
    batch->seq_id  [0][0] = seq_id;
    batch->n_tokens       = 1;
    return batch;
}

int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
    return batch->n_tokens;
}

int32_t llama_batch_ext_add_text(
        struct llama_batch_ext * batch,
                   llama_token   token,
                     llama_pos   pos,
            const llama_seq_id * seq_ids,
                        size_t   n_seq_ids,
                         float   logits) {
    if (batch->n_tokens + 1 > batch->max_tokens) {
        return -1; // llama_batch size exceeded
    }
    if (batch->embd) {
        return -2; // embd is already set, cannot add text tokens
    }
    batch->token   [batch->n_tokens] = token;
    batch->pos     [batch->n_tokens] = pos;
    batch->n_seq_id[batch->n_tokens] = n_seq_ids;
    for (size_t j = 0; j < n_seq_ids; j++) {
        batch->seq_id[batch->n_tokens][j] = seq_ids[j];
    }
    batch->logits  [batch->n_tokens] = logits;
    batch->n_tokens++;
    return 0;
}

int32_t llama_batch_ext_set_logits(
        struct llama_batch_ext * batch,
                     llama_pos   pos,
                  llama_seq_id   seq_id) {
    for (int32_t i = 0; i < batch->n_tokens; i++) {
        // find the token having seq_id
        for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
            if (batch->seq_id[i][j] == seq_id) {
                // found the sequence
                if (pos == -1 || pos == batch->pos[i]) {
                    batch->logits[i] = true;
                    return 0;
                }
            }
        }
    }
    return -1; // not found
}

int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
    if (batch->n_tokens == 0) {
        return -1;
    }
    batch->logits[batch->n_tokens - 1] = true;
    return 0;
}

void llama_batch_ext_clear(struct llama_batch_ext * batch) {
    batch->n_tokens = 0;
}

struct llama_batch_ext * llama_batch_ext_get_view(
        struct llama_batch_ext * batch,
                       int32_t   offset,
                       int32_t   n_tokens) {
    if (batch->embd) {
        return nullptr; // not yet supported
    }
    llama_batch_ext * batch_view = new llama_batch_ext{
        /*n_tokens   =*/ n_tokens,
        /*max_tokens =*/ n_tokens,
        /*is_view    =*/ true,
        /*tokens     =*/ batch->token + offset,
        /*embd       =*/ nullptr,
        /*pos        =*/ batch->pos + offset,
        /*n_seq_id   =*/ batch->n_seq_id + offset,
        /*seq_id     =*/ batch->seq_id + offset,
        /*logits     =*/ batch->logits + offset,
    };
    return batch_view;
}

struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
        struct llama_batch_ext * batch,
                       int32_t   i) {
    GGML_ASSERT(i >= 0 && i < batch->n_tokens);
    return llama_batch_ext_token_info{
        /*token    =*/ batch->token   [i],
        /*pos      =*/ batch->pos     [i],
        /*n_seq_id =*/ batch->n_seq_id[i],
        /*seq_id   =*/ batch->seq_id  [i],
        /*logits   =*/ batch->logits  [i],
    };
}
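// Usage sketch (illustrative, not part of the original file): slicing a large
// batch into fixed-size views for chunked decoding. A view shares the parent's
// buffers, so it must not outlive the parent batch.
//
//   for (int32_t off = 0; off < llama_batch_ext_get_n_tokens(batch); off += n_chunk) {
//       int32_t n = std::min(n_chunk, llama_batch_ext_get_n_tokens(batch) - off);
//       llama_batch_ext * view = llama_batch_ext_get_view(batch, off, n);
//       // ... decode view ...
//       llama_batch_ext_free(view); // frees only the view struct, not the buffers
//   }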
void llama_batch_ext_free(struct llama_batch_ext * batch) {
    // do not free the members if it's a view
    if (!batch->is_view) {
        if (batch->token)    free(batch->token);
        if (batch->embd)     free(batch->embd);
        if (batch->pos)      free(batch->pos);
        if (batch->n_seq_id) free(batch->n_seq_id);
        if (batch->seq_id) {
            // the per-token seq_id arrays are terminated by a nullptr sentinel
            for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
                free(batch->seq_id[i]);
            }
            free(batch->seq_id);
        }
        if (batch->logits)   free(batch->logits);
    }
    delete batch;
}
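// Usage sketch (illustrative, not part of the original file): constructing a
// batch token-by-token and checking the llama_batch_ext_add_text error codes
// (-1: capacity exceeded, -2: embeddings batch cannot take text tokens).
// `prompt` and `n_prompt_tokens` are placeholders.
//
//   llama_batch_ext * batch = llama_batch_ext_init(/* n_tokens_alloc */ 512, /* n_seq_max */ 1);
//   llama_seq_id seq = 0;
//   for (int32_t i = 0; i < n_prompt_tokens; i++) {
//       if (llama_batch_ext_add_text(batch, prompt[i], i, &seq, 1, false) != 0) {
//           break; // batch full
//       }
//   }
//   llama_batch_ext_set_logits_last(batch);
//   // ... decode, optionally reuse via llama_batch_ext_clear(batch), and finally:
//   llama_batch_ext_free(batch);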