cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (#14188)

ggml-ci
Author: Georgi Gerganov
Date: 2025-06-15 10:08:58 +03:00
Committed by: GitHub
Parent: b9912ac570
Commit: c311ac664d

6 changed files with 29 additions and 30 deletions

View File

@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
-    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
-    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
     for (auto & cur : seq_cpl) {
-        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+        cur.resize(LLAMA_MAX_SEQ);
     }
 }
@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
                     return false;
                 }
             }
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
         pos.resize(batch.n_tokens);
         // initialize the starting position for each sequence based on the positions in the memory
-        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        llama_pos p0[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (!memory) {
                p0[s] = 0;
            } else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
    // consistency checks
    //
-    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos[s].empty()) {
            continue;
        }
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
    }
    if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
                if (seq_cpl[s0][s1]) {
                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
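
Note (not part of this diff): the check above is the internal guard; a caller that builds a llama_batch by hand can apply the same bound through the public API before calling llama_decode. A minimal sketch, assuming a valid llama_context * ctx and a batch with explicit per-token seq_id arrays (the helper name is illustrative):

    #include "llama.h"

    // reject any seq_id outside [0, n_seq_max), mirroring llama_batch_allocr::init
    static bool batch_seq_ids_valid(const llama_context * ctx, const llama_batch & batch) {
        const int32_t n_seq_max = (int32_t) llama_n_seq_max(ctx); // never larger than LLAMA_MAX_SEQ
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
                const llama_seq_id id = batch.seq_id[i][s];
                if (id < 0 || id >= n_seq_max) {
                    return false;
                }
            }
        }
        return true;
    }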

View File

@@ -29,8 +29,8 @@ llama_context::llama_context(
    const auto & hparams = model.hparams;
    cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
    }
    cparams.n_threads = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
        if (!res) {
            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                pos_min[s] = std::numeric_limits<llama_pos>::max();
            }
@@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
            }
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                    continue;
                }
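
Note (not part of this diff): the n_seq_max guard at context creation surfaces through the C API as a failed context init. A rough sketch, assuming `model` is an already loaded llama_model *:

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max = 64;   // equal to the current LLAMA_MAX_SEQ cap; larger values are rejected

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        // creation failed, e.g. because n_seq_max exceeded LLAMA_MAX_SEQ
    }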

View File

@@ -1,5 +1,5 @@
 #include "llama-cparams.h"
 
 size_t llama_max_parallel_sequences(void) {
-    return LLAMA_MAX_PARALLEL_SEQUENCES;
+    return LLAMA_MAX_SEQ;
 }
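
Note (not part of this diff): the helper keeps its longer name even though the macro it returns was renamed. Internal code can keep querying the cap through it rather than hard-coding the value; a small sketch, assuming the declaration is visible via llama-cparams.h (the wrapper below is hypothetical):

    #include "llama-cparams.h"

    // range-check a sequence id against the compile-time cap
    static bool seq_id_in_range(llama_seq_id id) {
        return id >= 0 && (size_t) id < llama_max_parallel_sequences();
    }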

View File

@@ -4,8 +4,7 @@
 #include <cstdint>
 
-// TODO: rename to something shorter
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference

View File

@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
        LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
    }
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (cells.seq_pos_min(s) < 0) {
            continue;
        }
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        seq_pos_max_rm[s] = -1;
    }
@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }
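
Note (not part of this diff): the invariant referenced in the comment above is that, per sequence, the cached positions form a gap-free range [pos_min, pos_max]. A standalone sketch of that property, using an ordered set of positions the way the cells do:

    #include <cstdint>
    #include <set>

    using llama_pos = int32_t;

    // true if the positions stored for one sequence are contiguous, i.e. every
    // value in [pos_min, pos_max] is present
    static bool positions_contiguous(const std::set<llama_pos> & seq_pos) {
        if (seq_pos.empty()) {
            return true;
        }
        const llama_pos pos_min = *seq_pos.begin();
        const llama_pos pos_max = *seq_pos.rbegin();
        return (llama_pos) seq_pos.size() == pos_max - pos_min + 1;
    }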

View File

@@ -23,7 +23,7 @@ public:
        used.clear();
-        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
            seq_pos[s].clear();
        }
    }
@@ -240,7 +240,7 @@ public:
    llama_seq_id seq_get(uint32_t i) const {
        assert(seq[i].count() == 1);
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                return s;
            }
@@ -253,7 +253,7 @@ public:
    // return -1 if the sequence is not present
    llama_pos seq_pos_min(llama_seq_id seq_id) const {
        assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
        if (seq_pos[seq_id].empty()) {
            return -1;
@@ -266,7 +266,7 @@ public:
    // return -1 if the sequence is not present
    llama_pos seq_pos_max(llama_seq_id seq_id) const {
        assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
        if (seq_pos[seq_id].empty()) {
            return -1;
@@ -384,20 +384,20 @@ private:
    //
    std::vector<llama_pos> shift;
-    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
    // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
    std::vector<bits_t> seq;
    // the set seq_pos[s] tells us which positions are currently present for sequence s
    // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
    // helper functions for updating `seq_pos`, once cell at a time:
    // remove cell i
    void seq_pos_rm(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].erase(pos[i]);
            }
@@ -406,7 +406,7 @@ private:
    // add cell i
    void seq_pos_add(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].insert(pos[i]);
            }
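
Note (not part of this diff): a stripped-down, self-contained version of the bookkeeping shown above — one bitset per cell recording which sequences occupy it, plus one ordered set of positions per sequence so min/max queries stay cheap. Names are simplified; MAX_SEQ stands in for LLAMA_MAX_SEQ:

    #include <bitset>
    #include <cstdint>
    #include <set>
    #include <vector>

    constexpr int MAX_SEQ = 64;     // stand-in for LLAMA_MAX_SEQ
    using pos_t  = int32_t;         // stand-in for llama_pos
    using bits_t = std::bitset<MAX_SEQ>;

    struct cells_sketch {
        std::vector<pos_t>  pos;              // position stored in each cell
        std::vector<bits_t> seq;              // which sequences occupy each cell
        std::set<pos_t>     seq_pos[MAX_SEQ]; // positions currently present per sequence

        // register cell i with every sequence that occupies it
        void seq_pos_add(uint32_t i) {
            for (int s = 0; s < MAX_SEQ; ++s) {
                if (seq[i].test(s)) {
                    seq_pos[s].insert(pos[i]);
                }
            }
        }

        // max cached position for sequence s, or -1 if the sequence is not present
        pos_t seq_pos_max(int s) const {
            return seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin();
        }
    };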