context : abstract input

ggml-ci
Georgi Gerganov
2025-02-13 15:53:15 +02:00
parent 107d1e2c32
commit ed3cb55abe
2 changed files with 323 additions and 316 deletions


@@ -269,6 +269,309 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
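// dispatch the graph for computation on the backend scheduler
// 'batched' selects the number of threads and the threadpool (batch vs. single-token processing)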
enum ggml_status llama_context::compute_graph(
ggml_cgraph * graph,
bool batched) {
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
if (backend_cpu != nullptr) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
set_threadpool_fn(backend_cpu, tp);
}
// set the number of threads for all the backends
for (const auto & set_n_threads_fn : set_n_threads_fns) {
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
}
auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
}
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
return status;
}
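// copy the data of the current ubatch into the graph input tensors
// (token ids, embeddings, positions, output ids and pooling inputs, when present)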
void llama_context::input_set(const llama_ubatch & ubatch) {
const llama_hparams & hparams = model.hparams;
if (ubatch.token) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens));
}
if (ubatch.embd) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd));
}
if (ubatch.pos && inp_pos) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos));
}
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
//GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs");
if (!inp_out_ids) {
LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__);
} else {
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer));
int32_t * data = (int32_t *) inp_out_ids->data;
if (n_outputs == n_tokens) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
} else if (ubatch.output) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (ubatch.output[i]) {
data[n_outputs++] = i;
}
}
// the graph needs to have been passed the correct number of outputs
GGML_ASSERT(this->n_outputs == n_outputs);
} else if (n_outputs == 1) {
// only keep last output
data[0] = n_tokens - 1;
} else {
GGML_ASSERT(n_outputs == 0);
}
}
}
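// mean pooling: inp_mean is an [n_tokens, n_tokens] matrix whose row for each sequence
// holds 1/(number of tokens in that sequence) at the columns of that sequence's tokens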
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_mean);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer));
float * data = (float *) inp_mean->data;
memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += ubatch.n_seq_tokens;
}
std::vector<float> div(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
}
}
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
for (int i = 0; i < n_seq_tokens; ++i) {
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
}
}
}
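// CLS/RANK pooling: for each sequence, inp_cls stores the batch index of the token at position 0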
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
uint32_t * data = (uint32_t *) inp_cls->data;
memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
}
}
}
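// LAST pooling: for each sequence, inp_cls stores the batch index of the token with the largest position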
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
uint32_t * data = (uint32_t *) inp_cls->data;
memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
}
}
}
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
GGML_ASSERT(
// (!a || b) is a logical implication (a -> b)
// !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
"causal attention is not supported by this model"
);
}
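// (re)allocate the host output buffer that holds the logits and/or embeddings,
// sized for at least n_outputs rows (and never fewer than cparams.n_seq_max)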
size_t llama_context::output_reserve(size_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd*n_outputs_max : 0;
if (output_ids.empty()) {
// init, never resized afterwards
output_ids.resize(n_batch);
}
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size = (logits_size + embd_size) * sizeof(float);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
if (!buf_output || prev_size < new_size) {
if (buf_output) {
#ifndef NDEBUG
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
buf_output = nullptr;
logits = nullptr;
embd = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
auto * output_dev = model.dev_output();
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
if (output_dev_host_buft) {
buft = output_dev_host_buft;
}
buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
if (buf_output == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
return 0;
}
}
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
logits = has_logits ? output_base : nullptr;
embd = has_embd ? output_base + logits_size : nullptr;
output_size = n_outputs_max;
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
ggml_backend_buffer_clear(buf_output.get(), 0);
n_outputs = 0;
return n_outputs_max;
}
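// reorder the rows of the logits/embeddings buffers according to sbatch.out_ids
// so that the outputs are returned in the order of the original batch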
void llama_context::output_reorder() {
std::vector<size_t> & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
out_ids.clear();
}
}
void llama_context::build_cb(
ggml_tensor * cur,
const char * name,
@@ -1489,7 +1792,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
set_inputs(ubatch);
input_set(ubatch);
// the output is always the last tensor in the graph
struct ggml_tensor * t_logits = ggml_graph_node(gf, -1);
@@ -1710,7 +2013,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
set_inputs(ubatch);
input_set(ubatch);
// the output embeddings after the final encoder normalization
struct ggml_tensor * t_embd = nullptr;
@@ -1829,13 +2132,9 @@ uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) c
// llama input
void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
const llama_hparams & hparams = model.hparams;
//
// set input data
//
if (inp_K_shift) {
assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
@@ -1849,64 +2148,8 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
return;
}
if (ubatch.token) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens));
}
if (ubatch.embd) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd));
}
if (ubatch.pos && inp_pos) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos));
}
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
//GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs");
if (!inp_out_ids) {
LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__);
} else {
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer));
int32_t * data = (int32_t *) inp_out_ids->data;
if (n_outputs == n_tokens) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
} else if (ubatch.output) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (ubatch.output[i]) {
data[n_outputs++] = i;
}
}
// the graph needs to have been passed the correct number of outputs
GGML_ASSERT(n_outputs == n_outputs);
} else if (n_outputs == 1) {
// only keep last output
data[0] = n_tokens - 1;
} else {
GGML_ASSERT(n_outputs == 0);
}
}
}
GGML_ASSERT(
// (!a || b) is a logical implication (a -> b)
// !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
"causal attention is not supported by this model"
);
// call base functionality
llama_context::input_set(ubatch);
if (inp_KQ_mask || inp_KQ_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
@@ -2029,111 +2272,6 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_mean);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer));
float * data = (float *) inp_mean->data;
memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += ubatch.n_seq_tokens;
}
std::vector<float> div(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
}
}
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
for (int i = 0; i < n_seq_tokens; ++i) {
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
}
}
}
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
uint32_t * data = (uint32_t *) inp_cls->data;
memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
}
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer));
uint32_t * data = (uint32_t *) inp_cls->data;
memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
}
}
}
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;
@@ -2293,7 +2431,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_alloc_graph(sched.get(), gf);
set_inputs({});
input_set({});
compute_graph(gf, false);
@@ -2323,7 +2461,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_alloc_graph(sched.get(), gf);
// no input
//set_inputs({});
//input_set({});
compute_graph(gf, false);
@@ -3624,140 +3762,6 @@ int32_t llama_apply_adapter_cvec(
return res ? 0 : -1;
}
enum ggml_status llama_context::compute_graph(
ggml_cgraph * graph,
bool batched) {
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
if (backend_cpu != nullptr) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
set_threadpool_fn(backend_cpu, tp);
}
// set the number of threads for all the backends
for (const auto & set_n_threads_fn : set_n_threads_fns) {
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
}
auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
}
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
return status;
}
size_t llama_context::output_reserve(size_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd*n_outputs_max : 0;
if (output_ids.empty()) {
// init, never resized afterwards
output_ids.resize(n_batch);
}
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size = (logits_size + embd_size) * sizeof(float);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
if (!buf_output || prev_size < new_size) {
if (buf_output) {
#ifndef NDEBUG
// This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
buf_output = nullptr;
logits = nullptr;
embd = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
auto * output_dev = model.dev_output();
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
if (output_dev_host_buft) {
buft = output_dev_host_buft;
}
buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
if (buf_output == nullptr) {
LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
return 0;
}
}
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
logits = has_logits ? output_base : nullptr;
embd = has_embd ? output_base + logits_size : nullptr;
output_size = n_outputs_max;
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
ggml_backend_buffer_clear(buf_output.get(), 0);
n_outputs = 0;
return n_outputs_max;
}
void llama_context::output_reorder() {
std::vector<size_t> & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
out_ids.clear();
}
}
//
// kv cache view
//


@@ -90,6 +90,8 @@ struct llama_context : public llama_graph_i {
ggml_cgraph * graph,
bool batched);
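// copy the ubatch data into the graph input tensors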
virtual void input_set(const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual size_t output_reserve(size_t n_outputs);
@@ -204,6 +206,15 @@ protected:
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
// members
const llama_model & model;
@@ -288,6 +299,8 @@ public:
virtual ggml_context_ptr init() override;
virtual void input_set(const llama_ubatch & ubatch) override;
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;
@@ -299,16 +312,6 @@ public:
// certain implementations could require a padding for the context size
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
void set_inputs(const llama_ubatch & ubatch);
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
// === unified KV cache ===
llama_kv_cache kv_self;