mirror of https://github.com/ggml-org/llama.cpp.git

commit: apply to the rest
@@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens)
     return buf.str();
 }

-/*
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
-    std::stringstream buf;
-
-    buf << "[ ";
-
-    bool first = true;
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        if (!first) {
-            buf << ", ";
-        } else {
-            first = false;
-        }
-
-        auto detokenized = common_token_to_piece(ctx, batch.token[i]);
-
-        detokenized.erase(
-            std::remove_if(
-                detokenized.begin(),
-                detokenized.end(),
-                [](const unsigned char c) { return !std::isprint(c); }),
-            detokenized.end());
-
-        buf << "\n" << std::to_string(i)
-            << ", token '" << detokenized << "'"
-            << ", pos " << std::to_string(batch.pos[i])
-            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ", seq_id " << std::to_string(batch.seq_id[i][0])
-            << ", logits " << std::to_string(batch.logits[i]);
-    }
-
-    buf << " ]";
-
-    return buf.str();
-}
-*/
-
 void string_process_escapes(std::string & input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;
@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
 std::string string_from(bool value);
 std::string string_from(const std::vector<int> & values);
 std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);

 //
 // Filesystem utils
@@ -587,10 +586,10 @@ struct common_batch {
     llama_batch_ext_ptr batch;
     struct batch_token {
         llama_token token;
-        llama_seq_id seq_id;
         bool logits;
     };
     std::vector<batch_token> tokens;
+    int n_outputs = 0;
     common_batch() = default;
     common_batch(int32_t n_tokens, int32_t n_seq_max) {
         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
@@ -602,7 +601,17 @@ struct common_batch {
     }
     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
+    }
+    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
     }
     void set_logits_last() {
         if (!tokens.empty()) {
@@ -622,6 +631,9 @@ struct common_batch {
         view.tokens.reserve(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             view.tokens.push_back(tokens[offset + i]);
+            if (tokens[offset + i].logits) {
+                view.n_outputs++;
+            }
         }
         return view;
     }
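Illustrative usage (not part of the commit): a minimal sketch of how the extended common_batch helper above can be driven, assuming an already-initialized llama_context; the function name is hypothetical, and only members and calls that appear in this diff are used (add_text, get(), llama_batch_ext_set_output_last, llama_decode_ext).

    // sketch: evaluate a tokenized prompt on sequence 0, requesting logits only for the last token
    static int eval_prompt_sketch(llama_context * ctx, const std::vector<llama_token> & tokens) {
        common_batch batch((int32_t) tokens.size(), /*n_seq_max =*/ 1);

        for (size_t i = 0; i < tokens.size(); ++i) {
            batch.add_text(tokens[i], (llama_pos) i, /*seq_id =*/ 0, /*logits =*/ false);
        }
        llama_batch_ext_set_output_last(batch.get()); // mark the last token as an output

        if (llama_decode_ext(ctx, batch.get()) != 0) {
            return 1; // decode failed, e.g. the context is full
        }
        return 0;
    }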
@@ -5,6 +5,7 @@
 #include "clip.h"
 #include "stb_image.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "ggml.h"
 #include "console.h"

@@ -63,7 +64,7 @@ struct gemma3_context {
     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
-    llama_batch batch;
+    llama_batch_ext_ptr batch;

     int n_threads = 1;
     llama_pos n_past = 0;
@@ -73,7 +74,7 @@ struct gemma3_context {
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch.reset(llama_batch_ext_init(params.n_batch, 1));
         init_clip_model(params);
     }

@@ -87,50 +88,18 @@ struct gemma3_context {
     }
 };

-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
+    llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
     }
     if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(ctx.batch.get());
     }
     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("Failed to decode text\n");
         return 1;
     }
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
         fflush(stdout);

         // eval the token
-        common_batch_clear(ctx.batch);
-        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
-        if (llama_decode(ctx.lctx, ctx.batch)) {
+        llama_batch_ext_clear(ctx.batch.get());
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+        if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
             LOG_ERR("failed to decode token\n");
             return 1;
         }
@@ -2,6 +2,7 @@
 #include "llava.h"

 #include "llama.h"
+#include "llama-cpp.h"

 #include <algorithm>
 #include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }

-struct llava_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));

@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
-        if (llama_decode(ctx_llama, llava_batch.batch)) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }
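For context, an illustrative sketch (not from the commit) of the embedding path that replaces the hand-built decode_embd_batch / llava_embd_batch structs above; the wrapper name is hypothetical, and it assumes llama_batch_ext_ptr frees the batch when it goes out of scope, which is how the gemma3 code above relies on it.

    // sketch: feed n_tokens precomputed embeddings into sequence seq_id starting at position pos_0
    static bool eval_embd_sketch(llama_context * ctx, float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_tokens, pos_0, seq_id));

        if (llama_decode_ext(ctx, batch.get())) {
            return false; // decode failed
        }
        return true;      // no explicit llama_batch_ext_free() here - the smart pointer owns the batch
    }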
@@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

+        // TODO: move this to llama_batch_ext API
         llama_batch batch = {
             int32_t(n_eval),  // n_tokens
             nullptr,          // token
@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     // seq_id == 0 : the current input token
     // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
     // seq_id [W + 1, W + G] : verification n-grams
-    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
+    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);

     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@@ -204,10 +204,10 @@ int main(int argc, char ** argv) {
        // V V V V V V
        // id
        {
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);

            // current token - first token of the first level
-            common_batch_add(batch, id, n_past, seq_id_all, true);
+            llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);

            // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
            {
@@ -230,9 +230,10 @@ int main(int argc, char ** argv) {
                    const llama_token t = ngrams_observed.tokens[idx + j];

                    ngrams_cur[g].tokens [j + 1] = t;
-                    ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+                    ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);

-                    common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                    llama_seq_id seq_id = W + 1 + g;
+                    llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
                }
            }
        }
@@ -244,18 +245,20 @@ int main(int argc, char ** argv) {
                seq_id_look[j] = i + j + 1;
            }

-            common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+            llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
+                    seq_id_look.data(), seq_id_look.size(), false);
        }

        // fill the rest of the levels
        for (int j = 1; j < N - 1; j++) {
            for (int i = 0; i < W; i++) {
-                common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                llama_seq_id seq_id = i + 1;
+                llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
            }
        }
    }

-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode_ext(ctx, batch) != 0) {
        LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
        return 1;
    }
@@ -475,7 +478,7 @@ int main(int argc, char ** argv) {

    llama_kv_cache_view_free(&kvc_view);

-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);

    llama_backend_free();

@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {

    // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
    // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);

    int32_t n_total_prompt = 0;
    int32_t n_total_gen = 0;
@@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
        LOG_INF("%s: Evaluating the system prompt ...\n", __func__);

        for (int32_t i = 0; i < n_tokens_system; ++i) {
-            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
        }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
            common_kv_cache_dump_view_seqs(kvc_view, 40);
        }

-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);

        // decode any currently ongoing sequences
        for (auto & client : clients) {
@@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
                continue;
            }

-            client.i_batch = batch.n_tokens;
+            client.i_batch = llama_batch_ext_get_n_tokens(batch);

-            common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
+            llama_seq_id seq_id = client.id + 1;
+            llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);

            client.n_decoded += 1;
        }

-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
                llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
        }

        // insert new sequences for decoding
-        if (cont_batching || batch.n_tokens == 0) {
+        if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
            for (auto & client : clients) {
                if (client.seq_id == -1 && g_seq_id < n_seq) {
                    client.seq_id = g_seq_id;
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
                    tokens_prompt = common_tokenize(ctx, client.prompt, false);

                    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                        common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
+                        llama_seq_id seq_id = client.id + 1;
+                        llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
                    }

                    // extract the logits only for the last token
-                    if (batch.n_tokens > 0) {
-                        batch.logits[batch.n_tokens - 1] = true;
+                    if (llama_batch_ext_get_n_tokens(batch) > 0) {
+                        llama_batch_ext_set_output_last(batch);
                    }

                    client.n_prompt = tokens_prompt.size();
                    client.n_decoded = 0;
-                    client.i_batch = batch.n_tokens - 1;
+                    client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1;

                    LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);

@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
            }
        }

-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
            break;
        }

        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;

-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
+        for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
            // experiment: process in powers of 2
            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
            //    n_batch /= 2;
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
            //    continue;
            //}

-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));

-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
+            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
+            const int ret = llama_decode_ext(ctx, batch_view);
+            llama_batch_ext_free(batch_view);
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
    // TODO: print sampling/grammar timings for all clients
    llama_perf_context_print(ctx);

-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);

    llama_backend_free();

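A condensed sketch of the chunked-decode pattern introduced above (hypothetical helper, not part of the commit): a large llama_batch_ext is decoded in windows of at most n_chunk tokens via llama_batch_ext_get_view, and each view is freed separately from its parent batch, mirroring the parallel example; std::min needs <algorithm>.

    // sketch: decode `batch` in chunks of at most n_chunk tokens
    static bool decode_in_chunks_sketch(llama_context * ctx, llama_batch_ext * batch, int32_t n_chunk) {
        const int32_t n_tokens = llama_batch_ext_get_n_tokens(batch);

        for (int32_t i = 0; i < n_tokens; i += n_chunk) {
            const int32_t n_view = std::min(n_chunk, n_tokens - i);

            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_view);
            const int ret = llama_decode_ext(ctx, batch_view);
            llama_batch_ext_free(batch_view); // the view is freed independently of the parent batch

            if (ret != 0) {
                return false;
            }
        }
        return true;
    }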
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"

 #include <cmath>
 #include <cstdio>
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
    LOG_INF("prompt tokens: %d\n", n_tokens_all);
    //LOG_INF("prompt: %s\n", params.prompt.c_str());

-    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));

    int n_past = 0;

@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
        }

-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
        }

        if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
        }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
            LOG_INF("%s: llama_decode() failed\n", __func__);
            return 1;
        }
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {

        n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;

-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
        }

        if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
        }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
    while (n_cur <= n_len) {
        // sample the next token
        {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
            n_decode += 1;

            // prepare the next batch
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch.get());

            // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
        }

        n_cur += 1;

        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {

    llama_sampler_free(smpl);

-    llama_batch_free(batch);
-
    llama_free(ctx);
    llama_model_free(model);

@@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
    // clear the KV cache
    llama_kv_self_clear(ctx);

-    llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    common_batch batch(n_batch, 1);

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = start + j * n_batch;
        const int batch_size = std::min(end - batch_start, n_batch);

-        common_batch_clear(batch);
+        batch.clear();
        for (int i = 0; i < batch_size; i++) {
-            common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
        }

        //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
            //LOG_ERR("%s : failed to eval\n", __func__);
-            llama_batch_free(batch);
            return {tokens, -1, logit_history, prob_history};
        }

@@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
        }
    }

-    llama_batch_free(batch);
-
    const auto t_end = std::chrono::high_resolution_clock::now();

    if (i == 0) {
@@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);

-    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
+    common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);

    std::vector<float> logits;
    if (num_batches > 1) {
@@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &

        int n_outputs = 0;

-        batch.n_tokens = 0;
+        batch.clear();
        for (int seq = 0; seq < n_seq_batch; seq++) {
            int seq_start = batch_start + seq*n_ctx;

@@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &

            for (int k = 0; k < batch_size; ++k) {
                const int idx = seq*n_ctx + k;
-                batch.token   [idx]    = tokens[seq_start + k];
-                batch.pos     [idx]    = j*n_batch + k;
-                batch.n_seq_id[idx]    = 1;
-                batch.seq_id  [idx][0] = seq;
-                batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
+                const llama_pos pos = j*n_batch + k;
+                bool output = pos >= first;
+                batch.add_text(tokens[seq_start + k], pos, seq, output);

-                n_outputs += batch.logits[idx] != 0;
+                n_outputs += output ? 1 : 0;
            }
-            batch.n_tokens += batch_size;

            // restore the original token in case it was set to BOS
            tokens[seq_start] = token_org;
        }

-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
            LOG_INF("%s : failed to eval\n", __func__);
            return {tokens, -1, logit_history, prob_history};
        }
@@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
        LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
    }

-    llama_batch_free(batch);
-
    return {tokens, ppl, logit_history, prob_history};
 }

-static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
+static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
    int prev_outputs = 0;
-    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
-        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
+    for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
+        const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);

-        llama_batch batch_view = {
-            n_tokens,
-            batch.token    + i,
-            nullptr,
-            batch.pos      + i,
-            batch.n_seq_id + i,
-            batch.seq_id   + i,
-            batch.logits   + i,
-        };
-
-        const int ret = llama_decode(ctx, batch_view);
+        common_batch batch_view = batch.get_view(i, n_tokens);
+        const int ret = llama_decode_ext(ctx, batch_view.get());
        if (ret != 0) {
            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
            return false;
        }

-        int n_outputs = 0;
-        for (int i = 0; i < n_tokens; ++i) {
-            n_outputs += batch_view.logits[i] != 0;
-        }
+        int n_outputs = batch_view.n_outputs;

        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

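A hypothetical caller of the reworked decode_helper above (illustration only, not in the commit): because common_batch now tracks requested outputs in n_outputs, the logits buffer can be sized from that field instead of re-scanning the batch; the function name and parameters are assumptions.

    // sketch: n_ctx, n_batch and n_vocab are assumed to come from the surrounding code
    static bool eval_all_logits_sketch(llama_context * ctx, const std::vector<llama_token> & tokens,
                                       int n_ctx, int n_batch, int n_vocab) {
        common_batch batch(n_ctx, 1);
        for (size_t i = 0; i < tokens.size(); ++i) {
            batch.add_text(tokens[i], (llama_pos) i, 0, /*logits =*/ true);
        }

        std::vector<float> batch_logits(size_t(batch.n_outputs) * n_vocab); // one row per requested output
        return decode_helper(ctx, batch, batch_logits, n_batch, n_vocab);
    }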
@@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
    const int max_tasks_per_batch = 32;
    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, 4);
+    common_batch batch(n_ctx, 4);

    std::vector<float> tok_logits(n_vocab);
    // TODO: this could be made smaller; it's currently the worst-case size
@@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-        common_batch_clear(batch);
+        batch.clear();

        // batch as much tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
@@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
            }

            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
            }
-            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            llama_batch_ext_set_output_last(batch.get());
            n_logits += 1;

            for (int s = 0; s < 4; ++s) {
@@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
-                    common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                    n_logits += needs_logits;
                }
            }
@@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
        i0 = i1 - 1;
    }

-    llama_batch_free(batch);
-
    LOG("\n");
 }

@@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
    const int max_tasks_per_batch = 128;
    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, 2);
+    common_batch batch(n_ctx, 2);

    std::vector<float> tok_logits(n_vocab);
    // TODO: this could be made smaller; it's currently the worst-case size
@@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
        size_t i1 = i0;
        size_t i_logits = 0;

-        common_batch_clear(batch);
+        batch.clear();

        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
            int n_logits = 0;
@@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
            }

            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
            }
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
            n_logits += 1;

            for (int s = 0; s < 2; ++s) {
                // TODO: end before the last token, no need to predict past the end of the sequences
                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                    n_logits += 1;
                }
            }
@@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
    const int max_tasks_per_batch = 32;
    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+    common_batch batch(n_ctx, max_seq);

    std::vector<float> tok_logits(n_vocab);
    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-        common_batch_clear(batch);
+        batch.clear();

        // batch as much tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
@@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par

            for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
            }
-            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+            llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
            n_logits += 1;

            for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
-                    common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                    n_logits += needs_logits;
                }
            }
@@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
        i0 = i1 - 1;
    }

-    llama_batch_free(batch);
-
    if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;

    float p = 1.f*n_correct/n_done;
@@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
    // clear the KV cache
    llama_kv_self_clear(ctx);

-    llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    common_batch batch(n_batch, 1);

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = start + j * n_batch;
@@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
            tokens[batch_start] = llama_vocab_bos(vocab);
        }

-        common_batch_clear(batch);
+        batch.clear();
        for (int i = 0; i < batch_size; i++) {
-            common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
        }

-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
-            llama_batch_free(batch);
            return;
        }

@@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
        }
    }

-    llama_batch_free(batch);
-
    const auto t_end = std::chrono::high_resolution_clock::now();

    if (i == 0) {
|
@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
|
|||||||
return chunks;
|
return chunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||||
size_t n_tokens = tokens.size();
|
size_t n_tokens = tokens.size();
|
||||||
for (size_t i = 0; i < n_tokens; i++) {
|
for (size_t i = 0; i < n_tokens; i++) {
|
||||||
common_batch_add(batch, tokens[i], i, { seq_id }, true);
|
batch.add_text(tokens[i], i, seq_id, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
|
||||||
|
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||||
|
const struct llama_model * model = llama_get_model(ctx);
|
||||||
|
|
||||||
// clear previous kv_cache values (irrelevant for embeddings)
|
// clear previous kv_cache values (irrelevant for embeddings)
|
||||||
llama_kv_self_clear(ctx);
|
llama_kv_self_clear(ctx);
|
||||||
|
|
||||||
// run model
|
// run model
|
||||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
|
||||||
if (llama_decode(ctx, batch) < 0) {
|
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
|
||||||
LOG_ERR("%s : failed to decode\n", __func__);
|
// encoder-only model
|
||||||
|
if (llama_encode_ext(ctx, batch.get()) < 0) {
|
||||||
|
LOG_ERR("%s : failed to encode\n", __func__);
|
||||||
|
}
|
||||||
|
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||||
|
// decoder-only model
|
||||||
|
if (llama_decode_ext(ctx, batch.get()) < 0) {
|
||||||
|
LOG_ERR("%s : failed to decode\n", __func__);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
|
||||||
if (!batch.logits[i]) {
|
if (!batch.tokens[i].logits) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
const float * embd = nullptr;
|
||||||
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
int embd_pos = 0;
|
||||||
if (embd == NULL) {
|
|
||||||
|
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||||
|
// try to get token embeddings
|
||||||
embd = llama_get_embeddings_ith(ctx, i);
|
embd = llama_get_embeddings_ith(ctx, i);
|
||||||
if (embd == NULL) {
|
embd_pos = i;
|
||||||
LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i);
|
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
|
||||||
continue;
|
} else {
|
||||||
}
|
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||||
|
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
|
||||||
|
embd_pos = batch.tokens[i].seq_id;
|
||||||
|
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||||
}
|
}
|
||||||
|
|
||||||
float * out = output + batch.seq_id[i][0] * n_embd;
|
float * out = output + embd_pos * n_embd;
|
||||||
common_embd_normalize(embd, out, n_embd, 2);
|
common_embd_normalize(embd, out, n_embd, embd_norm);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,7 +230,7 @@ int main(int argc, char ** argv) {

    // initialize batch
    const int n_chunks = chunks.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    struct common_batch batch = common_batch(n_batch, 1);

    // allocate output
    const int n_embd = llama_model_n_embd(model);
@@ -231,10 +247,10 @@ int main(int argc, char ** argv) {
        const uint64_t n_toks = inp.size();

        // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
            float * out = emb + p * n_embd;
            batch_decode(ctx, batch, out, s, n_embd);
-            common_batch_clear(batch);
+            batch.clear();
            p += s;
            s = 0;
        }
@@ -255,7 +271,7 @@ int main(int argc, char ** argv) {
        chunks[i].tokens.clear();
    }

-    struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
+    struct common_batch query_batch = common_batch(n_batch, 1);

    // start loop, receive query and return top k similar chunks based on cosine similarity
    std::string query;
@@ -269,7 +285,7 @@ int main(int argc, char ** argv) {
        std::vector<float> query_emb(n_embd, 0);
        batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);

-        common_batch_clear(query_batch);
+        query_batch.clear();

        // compute cosine similarities
        {
@@ -299,6 +315,5 @@ int main(int argc, char ** argv) {
    llama_perf_context_print(ctx);

    // clean up
-    llama_batch_free(query_batch);
    llama_backend_free();
 }
@@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
 }

 // Check if we have enough space in the context to evaluate this batch
-static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
+static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
    const int n_ctx = llama_n_ctx(ctx.get());
    const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
-    if (n_ctx_used + batch.n_tokens > n_ctx) {
+    if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
        printf(LOG_COL_DEFAULT "\n");
        printe("context size exceeded\n");
        return 1;
@@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
    }

    // prepare a batch for the prompt
-    llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
    llama_token new_token_id;
    while (true) {
        check_context_size(llama_data.context, batch);
-        if (llama_decode(llama_data.context.get(), batch)) {
+        if (llama_decode_ext(llama_data.context.get(), batch.get())) {
            printe("failed to decode\n");
            return 1;
        }
@@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
        print_word_and_concatenate_to_response(piece, response);

        // prepare the next batch with the sampled token
-        batch = llama_batch_get_one(&new_token_id, 1);
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
    }

    printf(LOG_COL_DEFAULT);
@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
|
|||||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
// prepare the batch
|
// prepare the batch
|
||||||
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
|
||||||
for (size_t i = 0; i < tokens.size(); i++) {
|
|
||||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
|
||||||
}
|
|
||||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
|
||||||
|
|
||||||
// evaluate prompt
|
// evaluate prompt
|
||||||
llama_decode(ctx, batch);
|
llama_decode_ext(ctx, batch);
|
||||||
n_past += batch.n_tokens;
|
n_past += llama_batch_ext_get_n_tokens(batch);
|
||||||
|
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result0 += next_token_str;
|
result0 += next_token_str;
|
||||||
|
|
||||||
common_batch_clear(batch);
|
llama_batch_ext_clear(batch);
|
||||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||||
|
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode_ext(ctx, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_batch_free(batch);
|
llama_batch_ext_free(batch);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
n_past += 1;
|
n_past += 1;
|
||||||
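The per-token update above follows one pattern throughout this change: clear the existing batch, append the sampled token with an explicit sequence id, then decode. A hedged sketch of that step, with a hypothetical helper name and position handling assumed to match the calls shown in this hunk:

    #include "llama.h"

    // minimal sketch: feed one sampled token back through the model, reusing the batch
    static int eval_one_token(llama_context * ctx, llama_batch_ext * batch,
                              llama_token token, llama_pos pos, llama_seq_id seq_id) {
        llama_batch_ext_clear(batch);                                   // drop previous contents
        llama_batch_ext_add_text(batch, token, pos, &seq_id, 1, true);  // request logits for sampling
        return llama_decode_ext(ctx, batch);                            // 0 on success
    }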
@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result1 += next_token_str;
|
result1 += next_token_str;
|
||||||
|
|
||||||
common_batch_clear(batch);
|
llama_batch_ext_clear(batch);
|
||||||
common_batch_add(batch, next_token, n_past, {0}, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||||
|
|
||||||
if (llama_decode(ctx2, batch)) {
|
if (llama_decode_ext(ctx2, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_batch_free(batch);
|
llama_batch_ext_free(batch);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
n_past += 1;
|
n_past += 1;
|
||||||
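On the error path the _ext batch is a raw pointer, so it has to be released explicitly before returning, as the hunk above does. A compact sketch of that cleanup, with a hypothetical wrapper name:

    #include <cstdio>
    #include "llama.h"

    // minimal sketch: the raw llama_batch_ext* has no RAII, so it must be freed
    // before every early return
    static int decode_or_cleanup(llama_context * ctx, llama_batch_ext * batch) {
        if (llama_decode_ext(ctx, batch) != 0) {
            std::fprintf(stderr, "failed to evaluate\n");
            llama_batch_ext_free(batch); // avoid leaking the batch
            return 1;
        }
        return 0;
    }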
@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
|
|||||||
printf("%s", next_token_str.c_str());
|
printf("%s", next_token_str.c_str());
|
||||||
result2 += next_token_str;
|
result2 += next_token_str;
|
||||||
|
|
||||||
common_batch_clear(batch);
|
llama_batch_ext_clear(batch);
|
||||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
llama_seq_id seq_id = 1;
|
||||||
|
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
|
||||||
|
|
||||||
if (llama_decode(ctx3, batch)) {
|
if (llama_decode_ext(ctx3, batch)) {
|
||||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
llama_batch_free(batch);
|
llama_batch_ext_free(batch);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
n_past += 1;
|
n_past += 1;
|
||||||
@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_sampler_free(smpl2);
|
llama_sampler_free(smpl2);
|
||||||
llama_sampler_free(smpl3);
|
llama_sampler_free(smpl3);
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_ext_free(batch);
|
||||||
|
|
||||||
if (result0 != result2) {
|
if (result0 != result2) {
|
||||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||||
|
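As a design note: the common helpers in this change also expose llama_batch_ext_ptr (used in check_context_size above), so code that owns a batch for a whole scope could drop the manual llama_batch_ext_free calls. A sketch under the assumption that llama_batch_ext_ptr is a unique_ptr-style wrapper whose deleter calls llama_batch_ext_free:

    #include "llama.h" // plus the C++ wrapper header that defines llama_batch_ext_ptr (assumed)

    // minimal sketch: the smart pointer releases the batch on every return path
    static int decode_one_raii(llama_context * ctx, llama_token token, llama_seq_id seq) {
        llama_batch_ext_ptr batch(llama_batch_ext_init(/*n_tokens*/ 1, /*n_seq_max*/ 1));
        llama_batch_ext_add_text(batch.get(), token, /*pos*/ 0, &seq, 1, /*logits*/ true);
        return llama_decode_ext(ctx, batch.get()); // no explicit free needed
    }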
@ -108,19 +108,20 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepare a batch for the prompt
|
// prepare a batch for the prompt
|
||||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
|
||||||
|
llama_batch_ext_set_output_last(batch);
|
||||||
llama_token new_token_id;
|
llama_token new_token_id;
|
||||||
while (true) {
|
while (true) {
|
||||||
// check if we have enough space in the context to evaluate this batch
|
// check if we have enough space in the context to evaluate this batch
|
||||||
int n_ctx = llama_n_ctx(ctx);
|
int n_ctx = llama_n_ctx(ctx);
|
||||||
int n_ctx_used = llama_kv_self_used_cells(ctx);
|
int n_ctx_used = llama_kv_self_used_cells(ctx);
|
||||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
|
||||||
printf("\033[0m\n");
|
printf("\033[0m\n");
|
||||||
fprintf(stderr, "context size exceeded\n");
|
fprintf(stderr, "context size exceeded\n");
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode_ext(ctx, batch)) {
|
||||||
GGML_ABORT("failed to decode\n");
|
GGML_ABORT("failed to decode\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,9 +145,13 @@ int main(int argc, char ** argv) {
|
|||||||
response += piece;
|
response += piece;
|
||||||
|
|
||||||
// prepare the next batch with the sampled token
|
// prepare the next batch with the sampled token
|
||||||
batch = llama_batch_get_one(&new_token_id, 1);
|
llama_batch_ext_clear(batch);
|
||||||
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_batch_ext_free(batch);
|
||||||
|
|
||||||
return response;
|
return response;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
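The simple-chat path above replaces llama_batch_get_one() with a single llama_batch_ext that is cleared and refilled each iteration instead of being re-initialized, so one allocation serves the whole generation loop. A sketch of the resulting loop shape; the sampler and vocab helpers (llama_sampler_sample, llama_vocab_is_eog) are from the current public API rather than this diff, and the position argument follows the calls shown above:

    #include "llama.h"

    // minimal sketch: one reusable batch for the whole generation loop
    static void generate_loop(llama_context * ctx, llama_sampler * smpl,
                              llama_batch_ext * batch, const llama_vocab * vocab) {
        while (true) {
            // stop if the next decode would overflow the context
            if (llama_kv_self_used_cells(ctx) + llama_batch_ext_get_n_tokens(batch) > (int) llama_n_ctx(ctx)) {
                break;
            }
            if (llama_decode_ext(ctx, batch) != 0) {
                break; // decode failed
            }
            const llama_token tok = llama_sampler_sample(smpl, ctx, -1);
            if (llama_vocab_is_eog(vocab, tok)) {
                break; // end of generation
            }
            llama_batch_ext_clear(batch); // reuse the same allocation
            llama_seq_id seq = 0;
            llama_batch_ext_add_text(batch, tok, 0, &seq, 1, true);
        }
    }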
@ -143,7 +143,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// prepare a batch for the prompt
|
// prepare a batch for the prompt
|
||||||
|
|
||||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
|
||||||
|
llama_batch_ext_set_output_last(batch);
|
||||||
|
|
||||||
// main loop
|
// main loop
|
||||||
|
|
||||||
@ -151,14 +152,14 @@ int main(int argc, char ** argv) {
|
|||||||
int n_decode = 0;
|
int n_decode = 0;
|
||||||
llama_token new_token_id;
|
llama_token new_token_id;
|
||||||
|
|
||||||
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
|
for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
|
||||||
// evaluate the current batch with the transformer model
|
// evaluate the current batch with the transformer model
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode_ext(ctx, batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
n_pos += batch.n_tokens;
|
n_pos += llama_batch_ext_get_n_tokens(batch);
|
||||||
|
|
||||||
// sample the next token
|
// sample the next token
|
||||||
{
|
{
|
||||||
@ -180,7 +181,9 @@ int main(int argc, char ** argv) {
|
|||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
// prepare the next batch with the sampled token
|
// prepare the next batch with the sampled token
|
||||||
batch = llama_batch_get_one(&new_token_id, 1);
|
llama_batch_ext_clear(batch);
|
||||||
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
|
||||||
|
|
||||||
n_decode += 1;
|
n_decode += 1;
|
||||||
}
|
}
|
||||||
@ -198,6 +201,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_perf_context_print(ctx);
|
llama_perf_context_print(ctx);
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
llama_batch_ext_free(batch);
|
||||||
llama_sampler_free(smpl);
|
llama_sampler_free(smpl);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_model_free(model);
|
llama_model_free(model);
|
||||||
|
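Since llama_batch_ext hides its token count behind an accessor, the position bookkeeping in simple.cpp moves from `batch.n_tokens` to llama_batch_ext_get_n_tokens(), as above. A small sketch of that accounting, assuming the accessor returns the number of tokens currently queued; the helper name is hypothetical:

    #include "llama.h"

    // minimal sketch: decode one batch and advance the running position by its token count
    static int decode_and_advance(llama_context * ctx, llama_batch_ext * batch, int & n_pos) {
        if (llama_decode_ext(ctx, batch) != 0) {
            return 1;                                 // eval failed
        }
        n_pos += llama_batch_ext_get_n_tokens(batch); // was: n_pos += batch.n_tokens
        return 0;
    }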
@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
struct common_speculative * spec = common_speculative_init(ctx_dft);
|
struct common_speculative * spec = common_speculative_init(ctx_dft);
|
||||||
|
|
||||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
|
||||||
|
|
||||||
const auto t_enc_end = ggml_time_us();
|
const auto t_enc_end = ggml_time_us();
|
||||||
|
|
||||||
@ -151,8 +151,9 @@ int main(int argc, char ** argv) {
|
|||||||
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
||||||
|
|
||||||
// always have a token to evaluate from before - id_last
|
// always have a token to evaluate from before - id_last
|
||||||
common_batch_clear(batch_tgt);
|
llama_batch_ext_clear(batch_tgt);
|
||||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
|
||||||
|
|
||||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||||
{
|
{
|
||||||
@ -162,12 +163,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < draft.size(); ++i) {
|
for (size_t i = 0; i < draft.size(); ++i) {
|
||||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||||
|
|
||||||
llama_decode(ctx_tgt, batch_tgt);
|
llama_decode_ext(ctx_tgt, batch_tgt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample from the full target batch and return the accepted tokens based on the target sampler
|
// sample from the full target batch and return the accepted tokens based on the target sampler
|
||||||
@ -253,6 +254,7 @@ int main(int argc, char ** argv) {
|
|||||||
common_sampler_free(smpl);
|
common_sampler_free(smpl);
|
||||||
common_speculative_free(spec);
|
common_speculative_free(spec);
|
||||||
|
|
||||||
|
llama_batch_ext_free(batch_tgt);
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
LOG("\n\n");
|
LOG("\n\n");
|
||||||
|
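In the speculative-simple example the target batch now carries id_last plus the drafted tokens, each added with an explicit seq_id pointer and logits enabled so every position can be verified by the target sampler. A hedged sketch of that batch construction (helper name hypothetical, positions laid out as in the hunk above):

    #include "llama.h"
    #include <vector>

    // minimal sketch: build the verification batch for [id_last, draft0, ..., draftN-1]
    static void fill_target_batch(llama_batch_ext * batch_tgt, llama_token id_last,
                                  const std::vector<llama_token> & draft, llama_pos n_past) {
        llama_seq_id seq_id = 0;
        llama_batch_ext_clear(batch_tgt);
        llama_batch_ext_add_text(batch_tgt, id_last, n_past, &seq_id, 1, true);
        for (size_t i = 0; i < draft.size(); ++i) {
            // logits for every drafted position, so acceptance can be checked per token
            llama_batch_ext_add_text(batch_tgt, draft[i], n_past + 1 + (llama_pos) i, &seq_id, 1, true);
        }
    }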
@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
common_init();
|
common_init();
|
||||||
|
#if 0
|
||||||
if (params.speculative.model.empty()) {
|
if (params.speculative.model.empty()) {
|
||||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
@ -199,8 +199,8 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
|
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
|
||||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
|
||||||
|
|
||||||
const auto t_dec_start = ggml_time_us();
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
@ -441,12 +441,13 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
||||||
drafts[0].i_batch_tgt.push_back(0);
|
drafts[0].i_batch_tgt.push_back(0);
|
||||||
|
|
||||||
common_batch_clear(batch_dft);
|
llama_batch_ext_clear(batch_dft);
|
||||||
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
|
||||||
|
|
||||||
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||||
llama_decode(ctx_dft, batch_dft);
|
llama_decode_ext(ctx_dft, batch_dft);
|
||||||
|
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
}
|
}
|
||||||
@ -471,8 +472,9 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[0].drafting = true;
|
drafts[0].drafting = true;
|
||||||
drafts[0].i_batch_dft = 0;
|
drafts[0].i_batch_dft = 0;
|
||||||
|
|
||||||
common_batch_clear(batch_tgt);
|
llama_batch_ext_clear(batch_tgt);
|
||||||
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
|
||||||
|
|
||||||
// sample n_draft tokens from the draft model using tree-based sampling
|
// sample n_draft tokens from the draft model using tree-based sampling
|
||||||
for (int i = 0; i < n_draft; ++i) {
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
@ -640,5 +642,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG("\n\n");
|
LOG("\n\n");
|
||||||
|
|
||||||
|
#endif
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
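The tree-based speculative example is fenced out with a preprocessor guard in this commit rather than ported; when it is re-enabled, its two batches map onto llama_batch_ext_init with different sequence capacities. A sketch of just the initialization, under that assumption and with a hypothetical helper name:

    #include "llama.h"

    // minimal sketch: draft batch uses one sequence, target batch needs n_seq_dft sequences
    // (capacities mirror the llama_batch_ext_init calls in the disabled hunk above)
    static void init_spec_batches(llama_context * ctx_dft, llama_context * ctx_tgt, int n_seq_dft,
                                  llama_batch_ext *& batch_dft, llama_batch_ext *& batch_tgt) {
        batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
        batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
    }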
@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
|
|
||||||
// create a llama_batch
|
// create a llama_batch
|
||||||
// we use this object to submit token data for decoding
|
// we use this object to submit token data for decoding
|
||||||
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
|
llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel);
|
||||||
|
|
||||||
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
|
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
|
||||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||||
@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
for (size_t i = 0; i < prompt_inp.size(); ++i) {
|
for (size_t i = 0; i < prompt_inp.size(); ++i) {
|
||||||
common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
|
llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
|
||||||
}
|
}
|
||||||
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
|
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size());
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
llama_batch_ext_set_output_last(batch);
|
||||||
|
|
||||||
if (llama_decode(ctx_ttc, batch) != 0) {
|
if (llama_decode_ext(ctx_ttc, batch) != 0) {
|
||||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
|
|
||||||
// remember the batch index of the last token for each parallel sequence
|
// remember the batch index of the last token for each parallel sequence
|
||||||
// we need this to determine which logits to sample from
|
// we need this to determine which logits to sample from
|
||||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
|
||||||
|
|
||||||
int n_past = batch.n_tokens;
|
int n_past = llama_batch_ext_get_n_tokens(batch);
|
||||||
int n_decode = 0;
|
int n_decode = 0;
|
||||||
|
|
||||||
bool next_token_uses_guide_token = true;
|
bool next_token_uses_guide_token = true;
|
||||||
|
|
||||||
while (n_decode <= n_predict) {
|
while (n_decode <= n_predict) {
|
||||||
// prepare the next batch
|
// prepare the next batch
|
||||||
common_batch_clear(batch);
|
llama_batch_ext_clear(batch);
|
||||||
|
|
||||||
// sample the next token for each parallel sequence / stream
|
// sample the next token for each parallel sequence / stream
|
||||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||||
@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
//LOG_CNT("%d", i);
|
//LOG_CNT("%d", i);
|
||||||
}
|
}
|
||||||
|
|
||||||
i_batch[i] = batch.n_tokens;
|
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
|
||||||
|
|
||||||
// push this new token for next evaluation
|
// push this new token for next evaluation
|
||||||
common_batch_add(batch, new_token_id, n_past, { i }, true);
|
llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// all streams are finished
|
// all streams are finished
|
||||||
if (batch.n_tokens == 0) {
|
if (llama_batch_ext_get_n_tokens(batch) == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
n_past += 1;
|
n_past += 1;
|
||||||
|
|
||||||
// evaluate the current batch with the transformer model
|
// evaluate the current batch with the transformer model
|
||||||
if (llama_decode(ctx_ttc, batch)) {
|
if (llama_decode_ext(ctx_ttc, batch)) {
|
||||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_ext_free(batch);
|
||||||
|
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
|
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
|
||||||
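The TTS example broadcasts each prompt token to all parallel sequences by passing a whole array of sequence ids to llama_batch_ext_add_text, then tracks per-stream logit indices with llama_batch_ext_get_n_tokens. A sketch of the prompt broadcast, assuming the (token, pos, seq_ids, n_seq_ids, logits) argument order shown above; the helper name is hypothetical:

    #include "llama.h"
    #include <vector>

    // minimal sketch: share the prompt across n_parallel sequences (seq ids 0..n_parallel-1)
    static void add_shared_prompt(llama_batch_ext * batch,
                                  const std::vector<llama_token> & prompt, int n_parallel) {
        std::vector<llama_seq_id> seq_ids(n_parallel);
        for (int i = 0; i < n_parallel; ++i) {
            seq_ids[i] = i;
        }
        for (size_t i = 0; i < prompt.size(); ++i) {
            // no logits for prompt tokens; only the last one is marked below
            llama_batch_ext_add_text(batch, prompt[i], (llama_pos) i, seq_ids.data(), seq_ids.size(), false);
        }
        llama_batch_ext_set_output_last(batch);
    }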
@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
|
|
||||||
const int n_codes = codes.size();
|
const int n_codes = codes.size();
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_codes, 0, 1);
|
llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1);
|
||||||
|
|
||||||
for (size_t i = 0; i < codes.size(); ++i) {
|
for (size_t i = 0; i < codes.size(); ++i) {
|
||||||
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits?
|
||||||
}
|
}
|
||||||
GGML_ASSERT(batch.n_tokens == n_codes);
|
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes);
|
||||||
|
|
||||||
if (llama_decode(ctx_cts, batch) != 0) {
|
if (llama_decode_ext(ctx_cts, batch) != 0) {
|
||||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
|
|
||||||
LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());
|
LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
llama_batch_ext_free(batch);
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
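For the vocoder stage every code token requests logits (the "// TODO: all logits?" note above), presumably because the decoder needs output for every position rather than just the last one. A sketch of that batch with the seq id hoisted out of the loop; the function name is illustrative:

    #include "llama.h"
    #include <vector>

    // minimal sketch: one sequence, logits requested for every code token
    static llama_batch_ext * make_codes_batch(const std::vector<llama_token> & codes) {
        llama_batch_ext * batch = llama_batch_ext_init((int32_t) codes.size(), 1);
        llama_seq_id seq_id = 0;
        for (size_t i = 0; i < codes.size(); ++i) {
            llama_batch_ext_add_text(batch, codes[i], (llama_pos) i, &seq_id, 1, true);
        }
        return batch; // caller decodes with llama_decode_ext() and frees it
    }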
|
@ -995,9 +995,9 @@ extern "C" {
|
|||||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||||
// 0 - success
|
// 0 - success
|
||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// < 0 - error. the KV cache state is restored to the state before this call
|
||||||
DEPRECATED(LLAMA_API int32_t llama_encode(
|
LLAMA_API int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch), "use llama_batch_ext API instead");
|
struct llama_batch batch);
|
||||||
|
|
||||||
LLAMA_API int32_t llama_encode_ext(
|
LLAMA_API int32_t llama_encode_ext(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
@ -1007,9 +1007,9 @@ extern "C" {
|
|||||||
// 0 - success
|
// 0 - success
|
||||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||||
// < 0 - error. the KV cache state is restored to the state before this call
|
// < 0 - error. the KV cache state is restored to the state before this call
|
||||||
DEPRECATED(LLAMA_API int32_t llama_decode(
|
LLAMA_API int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch), "use llama_batch_ext API instead");
|
struct llama_batch batch);
|
||||||
|
|
||||||
LLAMA_API int32_t llama_decode_ext(
|
LLAMA_API int32_t llama_decode_ext(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
|
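With this commit the legacy llama_decode/llama_encode declarations lose their DEPRECATED wrapper and coexist with the _ext variants, so callers can migrate incrementally. A minimal sketch of the two call styles side by side, assuming the signatures declared above; the wrapper names are hypothetical:

    #include "llama.h"

    // legacy path: llama_batch passed by value
    static int decode_legacy(llama_context * ctx, llama_batch batch) {
        return llama_decode(ctx, batch);
    }

    // new path: llama_batch_ext passed by pointer
    static int decode_new(llama_context * ctx, llama_batch_ext * batch) {
        return llama_decode_ext(ctx, batch);
    }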