apply various in places

This commit is contained in:
Xuan Son Nguyen
2025-03-01 20:42:18 +01:00
parent 1d6ba97789
commit 46596caf6d
12 changed files with 142 additions and 133 deletions

View File

@@ -565,6 +565,52 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
struct common_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id;
bool logits;
};
std::vector<batch_token> tokens;
common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_logits_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
llama_batch_ext * get() {
return batch.get();
}
common_batch get_view(int32_t offset, int32_t n_tokens) {
common_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
}
return view;
}
};
//
// Token utils
//