@@ -362,29 +362,31 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
-
+llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
-    while (sbatch.n_tokens > 0) {
+    while (true) {
         llama_ubatch ubatch;
 
         if (embd_all) {
             // if all tokens are output, split by sequence
-            ubatch = sbatch.split_seq(n_ubatch);
+            ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = sbatch.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch);
         }
 
-        ubatches.push_back(ubatch);
+        if (ubatch.n_tokens == 0) {
+            break;
+        }
+
+        ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
     if (!prepare(ubatches)) {
         return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_memory_recurrent::init_full() {
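Taken together, this first hunk replaces the per-call `llama_sbatch` with a caller-provided `llama_batch_allocr` and inverts the loop's termination: splitting now continues until the allocator hands back an empty ubatch, instead of tracking a remaining-token counter on the splitter. A minimal standalone sketch of that consumer pattern, assuming only that the splitter signals exhaustion with an empty result (the `Splitter`/`Ubatch` types below are illustrative stand-ins, not the llama.cpp API):

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative stand-ins for llama_batch_allocr / llama_ubatch.
struct Ubatch {
    std::vector<int> tokens;
};

struct Splitter {
    std::vector<int> pending;

    // Hand out up to n tokens per call; an empty Ubatch signals exhaustion.
    Ubatch split(std::size_t n) {
        Ubatch ub;
        const std::size_t take = std::min(n, pending.size());
        ub.tokens.assign(pending.end() - take, pending.end());
        pending.resize(pending.size() - take);
        return ub;
    }
};

std::vector<Ubatch> collect(Splitter & s, std::size_t n_ubatch) {
    std::vector<Ubatch> out;
    while (true) {
        Ubatch ub = s.split(n_ubatch);
        if (ub.tokens.empty()) {
            break; // mirrors the new `ubatch.n_tokens == 0` exit above
        }
        out.push_back(std::move(ub));
    }
    return out;
}
```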
@@ -423,9 +425,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
 }
 
 bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_seqs = ubatch.n_seqs;
-
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
@@ -445,9 +446,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // everything should fit if all seq_ids are smaller than the max
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t n_seq_id = ubatch.n_seq_id[s];
+        const uint32_t i = s*n_seq_tokens; // first token of sequence set s
+        const uint32_t n_seq_id = ubatch.n_seq_id[i];
+
         for (uint32_t j = 0; j < n_seq_id; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            const llama_seq_id seq_id = ubatch.seq_id[i][j];
 
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
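The repeated `s` → `i` substitutions in this and the following hunks all stem from one layout change: the ubatch metadata arrays (`n_seq_id`, `seq_id`, `pos`) are now indexed per token rather than per sequence set, so the metadata for set `s` is read at its first token, `i = s*n_seq_tokens`. A hedged sketch of walking such a flattened layout (field names are modeled on the diff, not on the real `llama_ubatch` definition):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified token-indexed metadata, mirroring the indexing in the diff.
struct FlatBatch {
    uint32_t n_seq_tokens = 0;       // tokens per sequence set
    uint32_t n_seqs       = 0;       // number of sequence sets
    std::vector<int32_t> n_seq_id;   // one entry per token
    std::vector<int64_t> pos;        // one entry per token
};

void walk(const FlatBatch & ub) {
    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        const uint32_t i = s*ub.n_seq_tokens; // first token of sequence set s

        // per-set metadata is read at the first token's index...
        std::printf("set %u: n_seq_id=%d first_pos=%lld\n",
                    s, ub.n_seq_id[i], (long long) ub.pos[i]);

        // ...and the set's last position sits n_seq_tokens - 1 tokens later,
        // exactly the `pos[i + n_seq_tokens - 1]` idiom used further down.
        std::printf("set %u: last_pos=%lld\n",
                    s, (long long) ub.pos[i + ub.n_seq_tokens - 1]);
    }
}
```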
@@ -506,7 +509,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // find usable cell range
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        const uint32_t i = s*n_seq_tokens;
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
         auto & seq_meta = cells[seq_id];
         bool has_cell = false;
         if (seq_meta.tail >= 0) {
@@ -530,7 +534,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                for (uint32_t i = 0; i < size; ++i) {
+                for (uint32_t j = 0; j < size; ++j) {
                     next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     auto & cell = cells[next_empty_cell];
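The counter rename here (and in the tail-swap loop two hunks down) is forced, not cosmetic: with `i` now bound in the enclosing scope to the first-token index of sequence set `s`, an inner `for (uint32_t i = ...)` would shadow that binding. A tiny self-contained illustration of the hazard the rename avoids (the values are arbitrary):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_seqs = 2, n_seq_tokens = 4, size = 3;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        const uint32_t i = s*n_seq_tokens;    // first token of sequence set s

        for (uint32_t i = 0; i < size; ++i) { // compiles, but hides the outer i
            std::printf("  inner i=%u (outer i is inaccessible here)\n", i);
        }
        std::printf("outer i=%u\n", i);
    }
    return 0;
}
```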
@@ -544,8 +548,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
+        const uint32_t i = s*n_seq_tokens;
         const int32_t dst_id = s + min;
-        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
         if (dst_id != src_id) {
             auto & dst_cell = cells[dst_id];
             auto & src_cell = cells[src_id];
@@ -555,8 +560,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
             // swap tails
-            for (uint32_t i = 0; i < size; ++i) {
-                int32_t & tail = cells[i].tail;
+            for (uint32_t j = 0; j < size; ++j) {
+                int32_t & tail = cells[j].tail;
                 if (tail == src_id) {
                     tail = dst_id;
                 } else if (tail == dst_id) {
@@ -568,7 +573,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+        const uint32_t i = s*n_seq_tokens;
+        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
         auto & cell = cells[cell_id];
 
@@ -576,12 +582,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
             // What should happen when the pos backtracks or skips a value?
             // Clearing the state mid-batch would require special-casing which isn't done.
             LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
+                __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
         }
         cell.pos = last_pos;
         cell.seq_id.clear();
-        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+        for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[i][j];
             cell.seq_id.insert(seq_id);
             cells[seq_id].tail = cell_id;
         }
@@ -827,12 +833,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -846,12 +849,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
                 return false;
             }
 
-            batch.pos[i] = pos;
+            ubatch.pos[i] = pos;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
+        ubatch.n_seq_id[0] = 1;
+        ubatch.seq_id[0] = &dest_seq_id;
 
-        if (!find_slot(batch)) {
+        if (!find_slot(ubatch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
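The state-restore path follows the same migration: rather than default-constructing an `llama_sbatch` and sizing a ubatch through `reserve_ubatch` plus manual `n_tokens`/`n_seq_tokens`/`n_seqs` assignments, the code now reserves a single-sequence ubatch of `cell_count` tokens from the allocator and fills only `pos`, `n_seq_id`, and `seq_id` before calling `find_slot`. A standalone sketch of that fill-then-place flow, under the assumption that the reserved ubatch exposes per-token `pos` and per-set `n_seq_id`/`seq_id` arrays (the types below are simplified stand-ins):

```cpp
#include <cstdint>
#include <vector>

using pos_t    = int32_t; // stand-in for llama_pos
using seq_id_t = int32_t; // stand-in for llama_seq_id

// Simplified stand-in for a reserved single-sequence ubatch.
struct MiniUbatch {
    std::vector<pos_t>      pos;      // one position per token
    std::vector<int32_t>    n_seq_id; // per sequence set
    std::vector<seq_id_t *> seq_id;   // per sequence set
};

// Analogue of balloc.ubatch_reserve(cell_count, 1): one set, cell_count tokens.
MiniUbatch reserve_single_seq(uint32_t cell_count) {
    MiniUbatch ub;
    ub.pos.resize(cell_count);
    ub.n_seq_id.resize(1);
    ub.seq_id.resize(1);
    return ub;
}

// Mirror of the restore loop: copy saved positions, then bind the destination
// sequence id. Note seq_id[0] points at caller-owned storage, as in the diff,
// so dest_seq_id must outlive the ubatch's use.
void fill(MiniUbatch & ub, const std::vector<pos_t> & saved, seq_id_t & dest_seq_id) {
    for (uint32_t i = 0; i < (uint32_t) saved.size(); ++i) {
        ub.pos[i] = saved[i];
    }
    ub.n_seq_id[0] = 1;
    ub.seq_id[0]   = &dest_seq_id;
}
```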
@@ -859,8 +862,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
         GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
     } else {
@@ -1048,8 +1051,7 @@ llama_memory_recurrent_state::llama_memory_recurrent_state(
 
 llama_memory_recurrent_state::llama_memory_recurrent_state(
         llama_memory_recurrent * mem,
-        llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
 
@@ -1071,12 +1073,6 @@ bool llama_memory_recurrent_state::apply() {
     return true;
 }
 
-std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_memory_recurrent_state::get_status() const {
     return status;
 }
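With the `llama_sbatch` member gone from the state object, the constructor loses its `sbatch` parameter, and the `out_ids()` accessor, whose only job was to expose `sbatch.out_ids`, goes with it. A rough sketch of the slimmed-down state shape implied by these last two hunks (names abridged; the real class carries more members and methods):

```cpp
#include <utility>
#include <vector>

// Abridged stand-ins for the llama.cpp types.
enum Status { STATUS_SUCCESS, STATUS_FAILED_PREPARE };
struct Ubatch { /* token/pos/seq-id views, elided */ };
struct Memory; // stands in for llama_memory_recurrent

class RecurrentState {
public:
    explicit RecurrentState(Status status) : status(status) {}

    // No sbatch parameter anymore: the state owns only the split ubatches.
    RecurrentState(Memory * mem, std::vector<Ubatch> ubatches)
        : status(STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

    Status get_status() const { return status; }
    // note: no out_ids() accessor; there is no sbatch left to read them from

private:
    Status status;
    Memory * mem = nullptr;
    std::vector<Ubatch> ubatches;
};
```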