diff --git a/include/llama.h b/include/llama.h
index 015a57898..d5e4cef68 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -243,14 +243,14 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    // Input data for llama_decode
+    // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
     //
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index bdbf76626..2265db9b2 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }
 
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();
 
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
 
-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //
 
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }
 
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //
 
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
@@ -349,6 +351,32 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }
 
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
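The auto-generated positions above are tracked per sequence: each token takes the next position of the first sequence it belongs to, and the running counter of every member sequence is advanced past it. A standalone sketch of just this rule (not part of the patch; simplified types, a plain `mem_pos_max` array stands in for `memory->seq_pos_max()`, and each token is assumed to belong to a single sequence):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

int main() {
    // pretend the memory already holds positions 0..2 of sequence 0;
    // sequence 1 is empty (seq_pos_max() convention: -1 when empty)
    std::vector<llama_pos> mem_pos_max = { 2, -1 };

    // a 3-token batch: two tokens for seq 0, then one token for seq 1
    std::vector<llama_seq_id> seq = { 0, 0, 1 };

    // seed p0[s] = seq_pos_max(s) + 1, as init() does when memory is provided
    std::vector<llama_pos> p0(mem_pos_max.size());
    for (size_t s = 0; s < p0.size(); ++s) {
        p0[s] = mem_pos_max[s] + 1;
    }

    // assign each token the next position of its sequence
    std::vector<llama_pos> pos(seq.size());
    for (size_t i = 0; i < seq.size(); ++i) {
        pos[i] = p0[seq[i]];
        p0[seq[i]] = pos[i] + 1;
    }

    for (size_t i = 0; i < pos.size(); ++i) {
        printf("token %zu: seq %d -> pos %d\n", i, seq[i], pos[i]); // 3, 4, 0
    }
}
```

The patched code additionally handles tokens that belong to several sequences: the position is taken from the token's first sequence, and the counters of all member sequences are advanced together.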
@@ -356,13 +384,36 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.logits = output.data();
     }
 
+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) batch.pos);
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                     batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq       = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }
 
@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
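The "positions are not continuous" check works because `seq_pos[s]` is a `std::set`: for a gapless run of positions, the span `seq_pos_max - seq_pos_min + 1` equals the number of distinct positions. A minimal sketch of that invariant (hypothetical helper name, same comparison as in `init()`):

```cpp
#include <cassert>
#include <cstdint>
#include <set>

using llama_pos = int32_t;

// mirrors the check in init(): a sequence is continuous when the span of
// its positions does not exceed the number of distinct positions it holds
static bool positions_continuous(const std::set<llama_pos> & sp) {
    if (sp.empty()) {
        return true;
    }
    return *sp.rbegin() - *sp.begin() + 1 <= (llama_pos) sp.size();
}

int main() {
    assert( positions_continuous({5, 6, 7})); // gapless run
    assert(!positions_continuous({5, 7}));    // position 6 missing -> rejected
}
```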
@@ -426,6 +537,14 @@
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }
 
 //
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 1e0be8ac2..04501ce5d 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -77,18 +78,25 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_outputs() const;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
 private:
     void clear();
 
@@ -103,5 +111,8 @@
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t>         output;
 
+    std::vector<std::set<llama_pos>>    seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>>      seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
     int debug;
 };
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ec1e1189b..47c60e960 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -895,8 +894,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 2871031ef..51ebe5d17 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -4,6 +4,7 @@
 
 #include <cstdint>
 
+// TODO: rename to something shorter
 #define LLAMA_MAX_PARALLEL_SEQUENCES 64
 
 struct llama_cparams {
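On the caller side, what this change enables (a hypothetical usage sketch, not part of the patch): a batch that leaves `pos` NULL while setting `seq_id` explicitly was previously rejected by the now-removed "pos == NULL, but seq_id != NULL" check; after the change the allocr derives the positions per sequence from the memory state:

```cpp
#include <vector>

#include "llama.h"

// decode n tokens (n > 0) into sequence `seq`, leaving batch.pos NULL so
// that llama_decode() derives the positions from the memory state
static int32_t decode_auto_pos(llama_context * ctx, llama_token * tokens,
                               int32_t n, llama_seq_id seq) {
    std::vector<int32_t>        n_seq_id(n, 1);
    std::vector<llama_seq_id>   seq_ids (n, seq);
    std::vector<llama_seq_id *> seq_ptrs(n);
    std::vector<int8_t>         logits  (n, 0);

    for (int32_t i = 0; i < n; ++i) {
        seq_ptrs[i] = &seq_ids[i];
    }
    logits[n - 1] = 1; // output only for the last token

    llama_batch batch = {
        /*.n_tokens =*/ n,
        /*.token    =*/ tokens,
        /*.embd     =*/ nullptr,
        /*.pos      =*/ nullptr, // auto-generated per sequence by the allocr
        /*.n_seq_id =*/ n_seq_id.data(),
        /*.seq_id   =*/ seq_ptrs.data(),
        /*.logits   =*/ logits.data(),
    };

    return llama_decode(ctx, batch);
}
```

Before this change, such a caller either had to fill `pos` manually or drop `seq_id` entirely; the single `p0` passed to `init()` also meant the fallback only continued from `memory->seq_pos_max(0) + 1`, i.e. it tracked sequence 0 alone, which the per-sequence `p0[]` seeding removes.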