diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 2265db9b2..a9f4a3d4c 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
 
-    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
-    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
     for (auto & cur : seq_cpl) {
-        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+        cur.resize(LLAMA_MAX_SEQ);
     }
 }
 
@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
                     return false;
                 }
             }
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
         pos.resize(batch.n_tokens);
 
         // initialize the starting position for each sequence based on the positions in the memory
-        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos p0[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (!memory) {
                 p0[s] = 0;
             } else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 47c60e960..3a113d1bc 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -29,8 +29,8 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
     cparams.n_threads = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-        llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos pos_min[LLAMA_MAX_SEQ];
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             pos_min[s] = std::numeric_limits<llama_pos>::max();
         }
 
@@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
             pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
         }
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                continue;
            }
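
The two checks above reduce to a pair of rules: every seq_id in a batch must lie in [0, LLAMA_MAX_SEQ), and n_seq_max may not exceed LLAMA_MAX_SEQ. Below is a minimal, self-contained sketch of that validation pattern; toy_batch and validate_seq_ids are illustrative stand-ins, not the actual llama.cpp types.

```cpp
// Sketch of the seq_id bounds check (simplified stand-ins, not llama.cpp code):
// every seq_id must lie in [0, LLAMA_MAX_SEQ).
#include <cstdint>
#include <cstdio>
#include <vector>

#define LLAMA_MAX_SEQ 64

struct toy_batch {
    // one entry per token: the sequences this token belongs to
    std::vector<std::vector<int32_t>> seq_id;
};

static bool validate_seq_ids(const toy_batch & batch) {
    for (size_t i = 0; i < batch.seq_id.size(); ++i) {
        for (int32_t s : batch.seq_id[i]) {
            if (s < 0 || s >= LLAMA_MAX_SEQ) {
                std::fprintf(stderr, "invalid seq_id %d at token %zu (max is %d)\n",
                             s, i, LLAMA_MAX_SEQ - 1);
                return false;
            }
        }
    }
    return true;
}

int main() {
    toy_batch good = { {{0}, {0, 1}} };
    toy_batch bad  = { {{0}, {64}} };   // 64 is out of range: valid ids are 0..63

    std::printf("good batch: %s\n", validate_seq_ids(good) ? "ok" : "rejected");
    std::printf("bad batch:  %s\n", validate_seq_ids(bad)  ? "ok" : "rejected");
    return 0;
}
```
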
diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp
index f7b36590f..a3e7a37ee 100644
--- a/src/llama-cparams.cpp
+++ b/src/llama-cparams.cpp
@@ -1,5 +1,5 @@
 #include "llama-cparams.h"
 
 size_t llama_max_parallel_sequences(void) {
-    return LLAMA_MAX_PARALLEL_SEQUENCES;
+    return LLAMA_MAX_SEQ;
 }
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 51ebe5d17..118615d5b 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -4,8 +4,7 @@
 
 #include <cstdint>
 
-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index d4e92eab3..031070570 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
     }
 
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (cells.seq_pos_min(s) < 0) {
             continue;
         }
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
     //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
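
The comment in apply_ubatch spells out the invariant behind seq_pos_max_rm: for every sequence, all positions in [pos_min, pos_max] must stay present in the cache, so when an ubatch overwrites cells, any smaller positions still cached for that sequence have to be purged as well. The toy example below (illustrative names only, not the cache implementation) shows the effect on a per-sequence position set.

```cpp
// Toy illustration of the invariant from the comment above: if an ubatch
// overwrote cells holding positions up to p_rm for a sequence, the remaining
// smaller positions must be purged so [min, max] stays gap-free at the bottom.
#include <cstdio>
#include <set>

using llama_pos = int;

// purge all cached positions of one sequence that are <= p_rm
static void purge_prefix(std::set<llama_pos> & seq_pos, llama_pos p_rm) {
    seq_pos.erase(seq_pos.begin(), seq_pos.upper_bound(p_rm));
}

int main() {
    std::set<llama_pos> seq_pos = {0, 1, 2, 3, 4, 5, 6, 7};

    // suppose the new ubatch overwrote the cells that held positions up to 3
    const llama_pos p_rm = 3;
    purge_prefix(seq_pos, p_rm);

    // the remaining positions form a contiguous range again: [4, 7]
    std::printf("min = %d, max = %d\n", *seq_pos.begin(), *seq_pos.rbegin());
    return 0;
}
```
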
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index acf30aebe..1d4e70f4d 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -23,7 +23,7 @@ public:
 
         used.clear();
 
-        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
            seq_pos[s].clear();
        }
    }
@@ -240,7 +240,7 @@ public:
     llama_seq_id seq_get(uint32_t i) const {
         assert(seq[i].count() == 1);
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                return s;
            }
@@ -253,7 +253,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
        if (seq_pos[seq_id].empty()) {
            return -1;
@@ -266,7 +266,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_max(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
        if (seq_pos[seq_id].empty()) {
            return -1;
@@ -384,20 +384,20 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<bits_t> seq;
 
     // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
     // remove cell i
     void seq_pos_rm(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].erase(pos[i]);
            }
@@ -406,7 +406,7 @@ private:
 
     // add cell i
     void seq_pos_add(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].insert(pos[i]);
            }
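
The llama-kv-cells.h hunks show the bookkeeping behind seq_pos_min()/seq_pos_max(): each cell carries a std::bitset<LLAMA_MAX_SEQ> of the sequences occupying it, and each sequence keeps a std::set<llama_pos> of its cached positions, so begin() and rbegin() yield the min and max directly. Below is a simplified, self-contained sketch of that pattern; toy_cells is an illustrative stand-in, not the actual llama_kv_cells class.

```cpp
// Simplified sketch of the per-cell / per-sequence bookkeeping pattern:
// a bitset per cell for sequence membership, a sorted set per sequence for
// its cached positions (begin/rbegin give min/max). Illustrative only.
#include <bitset>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

#define LLAMA_MAX_SEQ 64

using llama_pos    = int;
using llama_seq_id = int;
using bits_t       = std::bitset<LLAMA_MAX_SEQ>;

struct toy_cells {
    std::vector<llama_pos> pos;   // position stored in each cell (-1 = empty)
    std::vector<bits_t>    seq;   // which sequences occupy each cell
    std::set<llama_pos>    seq_pos[LLAMA_MAX_SEQ];

    // assign cell i to position p of sequence s
    void set(uint32_t i, llama_pos p, llama_seq_id s) {
        assert(s >= 0 && s < LLAMA_MAX_SEQ);
        pos[i] = p;
        seq[i].set(s);
        seq_pos[s].insert(p);
    }

    // min/max cached position for a sequence, or -1 if it has none
    llama_pos seq_pos_min(llama_seq_id s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].begin();
    }
    llama_pos seq_pos_max(llama_seq_id s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin();
    }
};

int main() {
    toy_cells cells;
    cells.pos.assign(8, -1);
    cells.seq.assign(8, bits_t{});

    cells.set(0, 10, 3);
    cells.set(1, 12, 3);
    cells.set(2, 11, 3);

    std::printf("seq 3: min = %d, max = %d\n", cells.seq_pos_min(3), cells.seq_pos_max(3));
    std::printf("seq 5: min = %d, max = %d\n", cells.seq_pos_min(5), cells.seq_pos_max(5));
    return 0;
}
```
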