diff --git a/common/common.h b/common/common.h index 94f335206..4bcc4a2c4 100644 --- a/common/common.h +++ b/common/common.h @@ -586,6 +586,7 @@ struct common_batch { llama_batch_ext_ptr batch; struct batch_token { llama_token token; + llama_seq_id seq_id; // only support single seq for now bool logits; }; std::vector tokens; @@ -601,14 +602,14 @@ struct common_batch { } void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, logits}); + tokens.push_back({token, seq_id, logits}); if (logits) { n_outputs++; } } void add_text(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); - tokens.push_back({token, logits}); + tokens.push_back({token, seq_ids[0], logits}); if (logits) { n_outputs++; }