Mirror of https://github.com/ggml-org/llama.cpp.git
fix llama_batch_ext_init_from_text
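In short: llama_batch_ext_init_from_text gains an explicit output_last flag, and callers are updated to pass it (and, where needed, a correct starting position instead of 0). Before/after, as declared in the header hunk further down in this diff:

    // before
    struct llama_batch_ext * llama_batch_ext_init_from_text(
            llama_token * tokens,
            int32_t n_tokens,
            int32_t pos0,
            int32_t seq_id);

    // after: if output_last is true, the last token will have output set
    struct llama_batch_ext * llama_batch_ext_init_from_text(
            llama_token * tokens,
            int32_t n_tokens,
            int32_t pos0,
            int32_t seq_id,
            bool output_last);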
@@ -1014,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
         if (llama_model_has_encoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
             llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1024,7 +1024,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
             llama_decode_ext(lctx, batch.get());
         }
         llama_kv_self_clear(lctx);
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    llama_batch_ext_set_output_last(batch.get());
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;
@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
     LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return 1;
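The position argument matters in this hunk: the prompt is evaluated in chunks, and each chunk has to start at the current n_past rather than at 0, otherwise every chunk would be placed at the beginning of the sequence in the KV cache. A minimal sketch of the intended pattern, assuming ctx, embd, n_batch and n_past exist as in the surrounding code:

    // evaluate pending tokens in chunks of at most n_batch tokens,
    // advancing n_past so each chunk lands after the previous one
    for (int i = 0; i < (int) embd.size(); i += n_batch) {
        int n_eval = std::min((int) embd.size() - i, n_batch);
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
        if (llama_decode_ext(ctx, batch.get())) {
            return 1; // failed to eval
        }
        n_past += n_eval;
    }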
@@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
     for (int i = 1; i < n_tokens; i++) {
         tokens[i] = std::rand() % n_vocab;
     }
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+    llama_batch_ext_set_output_last(batch.get());
     llama_decode_ext(ctx, batch.get());
     n_processed += n_tokens;
 }
@@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
 llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 for (int i = 0; i < n_gen; i++) {
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
     llama_decode_ext(ctx, batch.get());
     llama_synchronize(ctx);
     token = std::rand() % n_vocab;
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
     if (n_eval > n_batch) {
         n_eval = n_batch;
     }
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
     if (llama_decode_ext(ctx_llama, batch.get())) {
         LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
         return false;
@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
     if (n_eval > n_batch) {
         n_eval = n_batch;
     }
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
     if (llama_decode_ext(ctx_llama, batch.get())) {
         LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
         return false;
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
     const auto t_enc_start = ggml_time_us();
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
     int enc_input_size = embd_inp.size();
     llama_token * enc_input_buf = embd_inp.data();
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return 1;
@@ -669,7 +669,8 @@ int main(int argc, char ** argv) {
     LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
+    llama_batch_ext_set_output_last(batch.get());
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return 1;
@@ -946,7 +946,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }
     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);
         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true));
     }
     printf(LOG_COL_DEFAULT);
@@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
     // prepare the batch
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
     // evaluate prompt
     llama_decode_ext(ctx, batch);
@@ -108,8 +108,11 @@ int main(int argc, char ** argv) {
     }
     // prepare a batch for the prompt
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+    llama_pos n_past = 0;
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
+    llama_batch_ext_set_output_last(batch);
+    n_past += llama_batch_ext_get_n_tokens(batch);
     llama_token new_token_id;
     while (true) {
         // check if we have enough space in the context to evaluate this batch
@@ -147,7 +150,8 @@ int main(int argc, char ** argv) {
         // prepare the next batch with the sampled token
         llama_batch_ext_clear(batch);
         llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
+        n_past++;
     }
     llama_batch_ext_free(batch);
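Together with the previous hunk, this makes the example track absolute positions explicitly: the prompt batch starts at n_past = 0, and each sampled token is appended at the current n_past before being decoded. A condensed sketch of the resulting flow, with ctx, smpl, vocab and prompt_tokens assumed to be set up as elsewhere in this diff:

    llama_pos n_past = 0;

    // prompt: positions 0 .. n-1, output requested for the last token only
    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
    llama_batch_ext_set_output_last(batch);
    n_past += llama_batch_ext_get_n_tokens(batch);

    llama_token new_token_id;
    while (true) {
        if (llama_decode_ext(ctx, batch)) {
            break; // failed to decode
        }
        new_token_id = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, new_token_id)) {
            break;
        }
        // next batch: a single token placed at the current position
        llama_batch_ext_clear(batch);
        llama_seq_id seq_id = 0;
        llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
        n_past++;
    }
    llama_batch_ext_free(batch);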
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
     // prepare a batch for the prompt
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
     llama_batch_ext_set_output_last(batch);
     // main loop
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
     // eval the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
     llama_decode_ext(ctx_tgt, batch.get());
     // note: keep the last token separate!
@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
     llama_decode_ext(ctx_tgt, batch0);
     llama_decode_ext(ctx_tgt, batch1);
     llama_decode_ext(ctx_dft, batch2);
@@ -928,12 +928,14 @@ extern "C" {
     // Same with llama_batch_init, but initializes the batch with the provided text tokens
     // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
+    // If output_last is true, the last token will have output set
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
             llama_token * tokens,
             int32_t n_tokens,
             int32_t pos0,
-            int32_t seq_id);
+            int32_t seq_id,
+            bool output_last);
 
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // First token will be at position pos0
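Per the updated comment, the helper places tokens[0..n_tokens-1] at positions pos0..pos0+n_tokens-1, pins them to the single sequence seq_id, and, when output_last is true, marks the last token for output so the following decode produces logits for it. A minimal usage sketch, assuming a tokenized prompt and an initialized context as in the examples above:

    std::vector<llama_token> tokens = common_tokenize(ctx, prompt, true);

    // one sequence (seq_id 0), starting at position 0, output for the last token only
    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
    if (llama_decode_ext(ctx, batch)) {
        LOG_ERR("%s : failed to eval\n", __func__);
    }
    llama_batch_ext_free(batch); // the batch has to be freed with llama_batch_ext_free()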
@@ -341,11 +341,15 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
         llama_token * tokens,
         int32_t n_tokens,
         int32_t pos0,
-        int32_t seq_id) {
+        int32_t seq_id,
+        bool output_last) {
     llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
     for (int32_t i = 0; i < n_tokens; i++) {
         llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
     }
+    if (output_last) {
+        llama_batch_ext_set_output_last(batch);
+    }
     return batch;
 }
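Requesting output only for the last position matches the common case of sampling the next token and keeps the logits buffer small. A caller that needs logits for every position (for example, to score an entire sequence) could still assemble the batch by hand from the same primitives; an illustrative sketch using the names from the implementation above:

    // build a batch where every token requests output (not what the helper does)
    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
    for (int32_t i = 0; i < n_tokens; i++) {
        llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, true);
    }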