mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 12:25:03 +00:00
llama_batch_ext_ptr::from_text/embd
This commit is contained in:
@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
|
|||||||
|
|
||||||
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
|
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
|
||||||
llama_kv_self_clear(ctx);
|
llama_kv_self_clear(ctx);
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true);
|
||||||
if (llama_decode_ext(ctx, batch.get())) {
|
if (llama_decode_ext(ctx, batch.get())) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
|
|||||||
|
|
||||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||||
|
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true);
|
||||||
if (llama_decode_ext(ctx, batch.get())) {
|
if (llama_decode_ext(ctx, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||||
|
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true);
|
||||||
if (llama_decode_ext(ctx, batch.get())) {
|
if (llama_decode_ext(ctx, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||||||
for (int i = 1; i < n_tokens; i++) {
|
for (int i = 1; i < n_tokens; i++) {
|
||||||
tokens[i] = std::rand() % n_vocab;
|
tokens[i] = std::rand() % n_vocab;
|
||||||
}
|
}
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
|
||||||
llama_decode_ext(ctx, batch.get());
|
llama_decode_ext(ctx, batch.get());
|
||||||
n_processed += n_tokens;
|
n_processed += n_tokens;
|
||||||
}
|
}
|
||||||
@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
|
|||||||
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
|
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
|
||||||
|
|
||||||
for (int i = 0; i < n_gen; i++) {
|
for (int i = 0; i < n_gen; i++) {
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(&token, 1, n_past + i, 0, true);
|
||||||
llama_decode_ext(ctx, batch.get());
|
llama_decode_ext(ctx, batch.get());
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
token = std::rand() % n_vocab;
|
token = std::rand() % n_vocab;
|
||||||
|
@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
|||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
|
||||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||||
return false;
|
return false;
|
||||||
|
@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
|||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
float * embd = image_embed->embed+i*n_embd;
|
float * embd = image_embed->embed+i*n_embd;
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
|
auto batch = llama_batch_ext_ptr::from_embd(embd, n_eval, n_embd, 0, 0);
|
||||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
|||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
|
||||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||||
return false;
|
return false;
|
||||||
|
@ -67,7 +67,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
|||||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||||
|
|
||||||
float * batch_embd = image_embed->embed+i*n_embd;
|
float * batch_embd = image_embed->embed+i*n_embd;
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0));
|
auto batch = llama_batch_ext_ptr::from_embd(batch_embd, n_eval, n_embd, 0, 0);
|
||||||
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
|
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
|
||||||
|
|
||||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
|
@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
|
|||||||
int enc_input_size = embd_inp.size();
|
int enc_input_size = embd_inp.size();
|
||||||
llama_token * enc_input_buf = embd_inp.data();
|
llama_token * enc_input_buf = embd_inp.data();
|
||||||
|
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(enc_input_buf, enc_input_size, 0, 0, true);
|
||||||
if (llama_decode_ext(ctx, batch.get())) {
|
if (llama_decode_ext(ctx, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
@ -669,7 +669,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||||
|
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true);
|
||||||
llama_batch_ext_set_output_last(batch.get());
|
llama_batch_ext_set_output_last(batch.get());
|
||||||
if (llama_decode_ext(ctx, batch.get())) {
|
if (llama_decode_ext(ctx, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
|
@ -947,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepare a batch for the prompt
|
// prepare a batch for the prompt
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
|
||||||
llama_token new_token_id;
|
llama_token new_token_id;
|
||||||
while (true) {
|
while (true) {
|
||||||
check_context_size(llama_data.context, batch);
|
check_context_size(llama_data.context, batch);
|
||||||
|
@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
|
|||||||
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
|
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
|
||||||
|
|
||||||
// eval the prompt
|
// eval the prompt
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
|
auto batch = llama_batch_ext_ptr::from_text(inp.data(), inp.size() - 1, 0, 0, true);
|
||||||
llama_decode_ext(ctx_tgt, batch.get());
|
llama_decode_ext(ctx_tgt, batch.get());
|
||||||
|
|
||||||
// note: keep the last token separate!
|
// note: keep the last token separate!
|
||||||
|
@ -32,4 +32,24 @@ typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
|||||||
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
||||||
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
||||||
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
||||||
typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;
|
|
||||||
|
struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
|
||||||
|
llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
|
||||||
|
|
||||||
|
// convience function to create a batch from text tokens, without worrying about manually freeing it
|
||||||
|
static llama_batch_ext_ptr from_text(llama_token * tokens,
|
||||||
|
int32_t n_tokens,
|
||||||
|
int32_t pos0,
|
||||||
|
int32_t seq_id,
|
||||||
|
bool output_last) {
|
||||||
|
return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last));
|
||||||
|
}
|
||||||
|
|
||||||
|
static llama_batch_ext_ptr from_embd(float * embd,
|
||||||
|
size_t n_tokens,
|
||||||
|
size_t n_embd,
|
||||||
|
int32_t pos0,
|
||||||
|
int32_t seq_id) {
|
||||||
|
return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Reference in New Issue
Block a user