fix common_batch missing seq_id

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:38:04 +01:00
parent 47086fa82d
commit 9fb2d81eab

View File

@ -586,6 +586,7 @@ struct common_batch {
llama_batch_ext_ptr batch; llama_batch_ext_ptr batch;
struct batch_token { struct batch_token {
llama_token token; llama_token token;
llama_seq_id seq_id; // only support single seq for now
bool logits; bool logits;
}; };
std::vector<batch_token> tokens; std::vector<batch_token> tokens;
@ -601,14 +602,14 @@ struct common_batch {
} }
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, logits}); tokens.push_back({token, seq_id, logits});
if (logits) { if (logits) {
n_outputs++; n_outputs++;
} }
} }
void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) { void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
tokens.push_back({token, logits}); tokens.push_back({token, seq_ids[0], logits});
if (logits) { if (logits) {
n_outputs++; n_outputs++;
} }