batch : auto-gen positions + verify multi-sequence input (#14177)

* batch : verify multi-sequence input batches

ggml-ci

* cont : auto-gen positions + verify multi-seq input

ggml-ci

* cont : first print debug info, then perform validation

ggml-ci

* cont : fix position auto-gen + add comments

ggml-ci
commit b9912ac570
parent 00ba772610
Author: Georgi Gerganov
Date:   2025-06-15 09:18:37 +03:00 (committed via GitHub)

5 changed files with 155 additions and 26 deletions

src/llama-batch.h

@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -77,18 +78,25 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_outputs() const;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
 private:
     void clear();
@@ -103,5 +111,8 @@ private:
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t>         output;
 
+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
     int debug;
 };
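
To illustrate what "auto-gen positions + verify multi-sequence input" means in practice, here is a minimal, self-contained C++ sketch of the idea — not the actual llama.cpp implementation; the batch, init_batch, and seq_pos_max names below are hypothetical stand-ins. When the input batch carries no positions, each token continues its sequence right after the last position known to memory; when positions are given, each sequence is checked for continuity with memory and for gaps, mirroring the role of the new seq_pos member and the optional llama_memory_i argument.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <vector>

    using llama_pos    = int32_t;
    using llama_seq_id = int32_t;

    // simplified stand-in for llama_batch: one sequence per token
    struct batch {
        std::vector<int32_t>      token;
        std::vector<llama_pos>    pos;    // empty -> positions are auto-generated
        std::vector<llama_seq_id> seq_id;
    };

    // seq_pos_max[s] is the last position of sequence s in memory (-1 if empty),
    // playing the role of the optional memory argument
    static bool init_batch(batch & b, std::vector<llama_pos> & seq_pos_max) {
        const size_t n = b.token.size();

        if (b.pos.empty()) {
            // auto-gen: continue each sequence right after its last known position
            b.pos.resize(n);
            for (size_t i = 0; i < n; ++i) {
                b.pos[i] = ++seq_pos_max[b.seq_id[i]];
            }
            return true;
        }

        // collect the set of positions per sequence (cf. the seq_pos member)
        std::vector<std::set<llama_pos>> seq_pos(seq_pos_max.size());
        for (size_t i = 0; i < n; ++i) {
            seq_pos[b.seq_id[i]].insert(b.pos[i]);
        }

        // verify: each sequence must continue where memory ends and have no gaps
        for (size_t s = 0; s < seq_pos.size(); ++s) {
            if (seq_pos[s].empty()) {
                continue;
            }
            if (*seq_pos[s].begin() != seq_pos_max[s] + 1) {
                fprintf(stderr, "sequence %zu is not continuous with memory\n", s);
                return false;
            }
            const llama_pos n_pos = *seq_pos[s].rbegin() - *seq_pos[s].begin() + 1;
            if ((size_t) n_pos != seq_pos[s].size()) {
                fprintf(stderr, "sequence %zu has position gaps\n", s);
                return false;
            }
            seq_pos_max[s] = *seq_pos[s].rbegin();
        }

        return true;
    }

    int main() {
        std::vector<llama_pos> seq_pos_max = {-1, 2}; // seq 0 empty, seq 1 ends at pos 2

        batch b;
        b.token  = {101, 102, 103};
        b.seq_id = {0, 1, 1}; // two sequences, no positions provided

        assert(init_batch(b, seq_pos_max));
        printf("pos: %d %d %d\n", b.pos[0], b.pos[1], b.pos[2]); // prints: pos: 0 3 4
    }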
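
The new seq_cpl matrix captures sequence coupling: a token that carries several seq_ids is shared by those sequences (e.g. a common prompt prefix), and every additional sequence becomes coupled to the first one. Below is a hedged sketch of how such a matrix could be filled; make_seq_cpl and token_seq_ids are illustrative names, not the llama.cpp API.

    #include <cstdint>
    #include <vector>

    using llama_seq_id = int32_t;

    // build seq_cpl[s0][s1] from the per-token sequence id lists;
    // names and batch layout are assumptions for illustration
    std::vector<std::vector<bool>> make_seq_cpl(
            const std::vector<std::vector<llama_seq_id>> & token_seq_ids,
            size_t n_seq_max) {
        std::vector<std::vector<bool>> seq_cpl(n_seq_max, std::vector<bool>(n_seq_max, false));

        for (const auto & ids : token_seq_ids) {
            for (size_t s = 1; s < ids.size(); ++s) {
                // every extra sequence of the token is coupled to the first one
                seq_cpl[ids[s]][ids[0]] = true;
            }
        }

        return seq_cpl;
    }

A validator could then use seq_cpl together with seq_pos to require that coupled sequences agree on their shared positions, which is what makes multi-sequence input batches checkable before decoding.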