mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (#14188)
ggml-ci
This commit is contained in:
@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
|
|||||||
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
||||||
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
||||||
|
|
||||||
seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
seq_pos.resize(LLAMA_MAX_SEQ);
|
||||||
seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
seq_cpl.resize(LLAMA_MAX_SEQ);
|
||||||
for (auto & cur : seq_cpl) {
|
for (auto & cur : seq_cpl) {
|
||||||
cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
|
cur.resize(LLAMA_MAX_SEQ);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
|
|||||||
if (batch.seq_id) {
|
if (batch.seq_id) {
|
||||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
|
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
|
|||||||
pos.resize(batch.n_tokens);
|
pos.resize(batch.n_tokens);
|
||||||
|
|
||||||
// initialize the starting position for each sequence based on the positions in the memory
|
// initialize the starting position for each sequence based on the positions in the memory
|
||||||
llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
|
llama_pos p0[LLAMA_MAX_SEQ];
|
||||||
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (!memory) {
|
if (!memory) {
|
||||||
p0[s] = 0;
|
p0[s] = 0;
|
||||||
} else {
|
} else {
|
||||||
@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
|
|||||||
// consistency checks
|
// consistency checks
|
||||||
//
|
//
|
||||||
|
|
||||||
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (seq_pos[s].empty()) {
|
if (seq_pos[s].empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (memory) {
|
if (memory) {
|
||||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
|
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
|
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||||
if (seq_cpl[s0][s1]) {
|
if (seq_cpl[s0][s1]) {
|
||||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||||
|
@ -29,8 +29,8 @@ llama_context::llama_context(
|
|||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||||
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
|
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
||||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
|
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
||||||
}
|
}
|
||||||
|
|
||||||
cparams.n_threads = params.n_threads;
|
cparams.n_threads = params.n_threads;
|
||||||
@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||||
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
|
llama_pos pos_min[LLAMA_MAX_SEQ];
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
|
|
||||||
size_t llama_max_parallel_sequences(void) {
|
size_t llama_max_parallel_sequences(void) {
|
||||||
return LLAMA_MAX_PARALLEL_SEQUENCES;
|
return LLAMA_MAX_SEQ;
|
||||||
}
|
}
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
// TODO: rename to something shorter
|
#define LLAMA_MAX_SEQ 64
|
||||||
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
|
|
||||||
|
|
||||||
struct llama_cparams {
|
struct llama_cparams {
|
||||||
uint32_t n_ctx; // context size used during inference
|
uint32_t n_ctx; // context size used during inference
|
||||||
|
@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|||||||
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (cells.seq_pos_min(s) < 0) {
|
if (cells.seq_pos_min(s) < 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|||||||
|
|
||||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||||
// for non-SWA cache, this would be always empty
|
// for non-SWA cache, this would be always empty
|
||||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
|
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
seq_pos_max_rm[s] = -1;
|
seq_pos_max_rm[s] = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
|||||||
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
||||||
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (seq_pos_max_rm[s] == -1) {
|
if (seq_pos_max_rm[s] == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ public:
|
|||||||
|
|
||||||
used.clear();
|
used.clear();
|
||||||
|
|
||||||
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
seq_pos[s].clear();
|
seq_pos[s].clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -240,7 +240,7 @@ public:
|
|||||||
llama_seq_id seq_get(uint32_t i) const {
|
llama_seq_id seq_get(uint32_t i) const {
|
||||||
assert(seq[i].count() == 1);
|
assert(seq[i].count() == 1);
|
||||||
|
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (seq[i].test(s)) {
|
if (seq[i].test(s)) {
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
@ -253,7 +253,7 @@ public:
|
|||||||
// return -1 if the sequence is not present
|
// return -1 if the sequence is not present
|
||||||
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
||||||
assert(seq_id >= 0);
|
assert(seq_id >= 0);
|
||||||
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
assert(seq_id < LLAMA_MAX_SEQ);
|
||||||
|
|
||||||
if (seq_pos[seq_id].empty()) {
|
if (seq_pos[seq_id].empty()) {
|
||||||
return -1;
|
return -1;
|
||||||
@ -266,7 +266,7 @@ public:
|
|||||||
// return -1 if the sequence is not present
|
// return -1 if the sequence is not present
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
||||||
assert(seq_id >= 0);
|
assert(seq_id >= 0);
|
||||||
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
|
assert(seq_id < LLAMA_MAX_SEQ);
|
||||||
|
|
||||||
if (seq_pos[seq_id].empty()) {
|
if (seq_pos[seq_id].empty()) {
|
||||||
return -1;
|
return -1;
|
||||||
@ -384,20 +384,20 @@ private:
|
|||||||
//
|
//
|
||||||
std::vector<llama_pos> shift;
|
std::vector<llama_pos> shift;
|
||||||
|
|
||||||
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
|
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||||
|
|
||||||
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||||
std::vector<bits_t> seq;
|
std::vector<bits_t> seq;
|
||||||
|
|
||||||
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
||||||
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||||
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
|
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
||||||
|
|
||||||
// helper functions for updating `seq_pos`, once cell at a time:
|
// helper functions for updating `seq_pos`, once cell at a time:
|
||||||
|
|
||||||
// remove cell i
|
// remove cell i
|
||||||
void seq_pos_rm(uint32_t i) {
|
void seq_pos_rm(uint32_t i) {
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (seq[i].test(s)) {
|
if (seq[i].test(s)) {
|
||||||
seq_pos[s].erase(pos[i]);
|
seq_pos[s].erase(pos[i]);
|
||||||
}
|
}
|
||||||
@ -406,7 +406,7 @@ private:
|
|||||||
|
|
||||||
// add cell i
|
// add cell i
|
||||||
void seq_pos_add(uint32_t i) {
|
void seq_pos_add(uint32_t i) {
|
||||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||||
if (seq[i].test(s)) {
|
if (seq[i].test(s)) {
|
||||||
seq_pos[s].insert(pos[i]);
|
seq_pos[s].insert(pos[i]);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user