mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-15 15:17:44 +00:00
context : abstract constructor and init
ggml-ci
This commit is contained in:
@ -14,14 +14,290 @@
|
||||
// llama_context
|
||||
//
|
||||
|
||||
llama_context::llama_context(const llama_model & model) :
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params) :
|
||||
model (model),
|
||||
t_start_us(model.t_start_us),
|
||||
t_load_us (model.t_load_us) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
||||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
||||
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
|
||||
hparams.n_ctx_train;
|
||||
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
auto rope_scaling_type = params.rope_scaling_type;
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
||||
rope_scaling_type = hparams.rope_scaling_type_train;
|
||||
}
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
|
||||
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
|
||||
}
|
||||
|
||||
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
|
||||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
||||
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
} else {
|
||||
cparams.pooling_type = hparams.pooling_type;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
|
||||
cparams.causal_attn = hparams.causal_attn;
|
||||
} else {
|
||||
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
logits_all = params.logits_all;
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
// GPU backends
|
||||
for (auto * dev : model.devices) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||
throw std::runtime_error("failed to initialize backend");
|
||||
}
|
||||
backends.emplace_back(backend);
|
||||
}
|
||||
|
||||
// add ACCEL backends (such as BLAS)
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||
throw std::runtime_error("failed to initialize backend");
|
||||
}
|
||||
backends.emplace_back(backend);
|
||||
}
|
||||
}
|
||||
|
||||
// add CPU backend
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
if (backend_cpu == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
|
||||
throw std::runtime_error("failed to initialize CPU backend");
|
||||
}
|
||||
backends.emplace_back(backend_cpu);
|
||||
|
||||
// create a list of the set_n_threads functions in the backends
|
||||
for (auto & backend : backends) {
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
|
||||
|
||||
// graph outputs buffer
|
||||
{
|
||||
// resized during inference when a batch uses more outputs
|
||||
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
|
||||
throw std::runtime_error("failed to reserve initial output buffer");
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buffer_name (buf_output.get()),
|
||||
ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() = default;
|
||||
|
||||
void llama_context::init() {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
if (hparams.vocab_only) {
|
||||
LLAMA_LOG_WARN("%s: model is vocab-only -- no computation will be performed\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// buffer types used for the compute buffer of each backend
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
for (auto & backend : backends) {
|
||||
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
||||
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
||||
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
|
||||
// use the host buffer of the first device CPU for faster transfer of the intermediate state
|
||||
auto * dev = model.devices[0];
|
||||
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
|
||||
if (host_buft) {
|
||||
buft = host_buft;
|
||||
}
|
||||
}
|
||||
backend_buft.push_back(buft);
|
||||
backend_ptrs.push_back(backend.get());
|
||||
}
|
||||
|
||||
const size_t max_nodes = model.max_nodes();
|
||||
|
||||
// buffer used to store the computation graph and the tensor meta data
|
||||
// TODO: move to base class
|
||||
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
|
||||
// TODO: move these checks to ggml_backend_sched
|
||||
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
||||
bool pipeline_parallel =
|
||||
model.n_devices() > 1 &&
|
||||
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
|
||||
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
|
||||
cparams.offload_kqv;
|
||||
|
||||
// pipeline parallelism requires support for async compute and events in all devices
|
||||
if (pipeline_parallel) {
|
||||
for (auto & backend : backends) {
|
||||
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
||||
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
// ignore CPU backend
|
||||
continue;
|
||||
}
|
||||
auto * dev = ggml_backend_get_device(backend.get());
|
||||
ggml_backend_dev_props props;
|
||||
ggml_backend_dev_get_props(dev, &props);
|
||||
if (!props.caps.async || !props.caps.events) {
|
||||
// device does not support async compute or events
|
||||
pipeline_parallel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
||||
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
||||
}
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
{
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
|
||||
throw std::runtime_error("failed to allocate compute buffers");
|
||||
}
|
||||
int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
||||
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
|
||||
throw std::runtime_error("failed to allocate compute buffers");
|
||||
}
|
||||
int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
||||
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
gf_pp = build_graph(ubatch_pp, true);
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
|
||||
throw std::runtime_error("failed to allocate compute buffers");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
if (size > 1) {
|
||||
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buft_name(buft),
|
||||
size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_nodes_pp == n_nodes_tg) {
|
||||
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
||||
}
|
||||
|
||||
if (n_splits_pp == n_splits_tg) {
|
||||
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const llama_model & llama_context::get_model() const {
|
||||
return model;
|
||||
}
|
||||
@ -161,46 +437,6 @@ int64_t llama_context::n_pos_per_token() const {
|
||||
return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1;
|
||||
}
|
||||
|
||||
ggml_context_ptr llama_context::init() {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
return ggml_context_ptr { ggml_init(params) };
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
ggml_backend_sched_synchronize(sched.get());
|
||||
|
||||
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
||||
// the stats will be added to the prompt evaluation stats
|
||||
// this should only happen when using batch size 1 to evaluate a batch
|
||||
|
||||
// add the evaluation to the stats
|
||||
if (n_queued_tokens == 1) {
|
||||
if (!cparams.no_perf) {
|
||||
t_eval_us += ggml_time_us() - t_compute_start_us;
|
||||
}
|
||||
n_eval++;
|
||||
} else if (n_queued_tokens > 1) {
|
||||
if (!cparams.no_perf) {
|
||||
t_p_eval_us += ggml_time_us() - t_compute_start_us;
|
||||
}
|
||||
n_p_eval += n_queued_tokens;
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
if (n_queued_tokens > 0 && !has_evaluated_once) {
|
||||
t_load_us = ggml_time_us() - t_start_us;
|
||||
has_evaluated_once = true;
|
||||
}
|
||||
|
||||
n_queued_tokens = 0;
|
||||
t_compute_start_us = 0;
|
||||
}
|
||||
|
||||
void llama_context::attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch) {
|
||||
@ -269,7 +505,54 @@ bool llama_context::apply_adapter_cvec(
|
||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
enum ggml_status llama_context::compute_graph(
|
||||
void llama_context::synchronize() {
|
||||
ggml_backend_sched_synchronize(sched.get());
|
||||
|
||||
// FIXME: if multiple single tokens are evaluated without a synchronization,
|
||||
// the stats will be added to the prompt evaluation stats
|
||||
// this should only happen when using batch size 1 to evaluate a batch
|
||||
|
||||
// add the evaluation to the stats
|
||||
if (n_queued_tokens == 1) {
|
||||
if (!cparams.no_perf) {
|
||||
t_eval_us += ggml_time_us() - t_compute_start_us;
|
||||
}
|
||||
n_eval++;
|
||||
} else if (n_queued_tokens > 1) {
|
||||
if (!cparams.no_perf) {
|
||||
t_p_eval_us += ggml_time_us() - t_compute_start_us;
|
||||
}
|
||||
n_p_eval += n_queued_tokens;
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
if (n_queued_tokens > 0 && !has_evaluated_once) {
|
||||
t_load_us = ggml_time_us() - t_start_us;
|
||||
has_evaluated_once = true;
|
||||
}
|
||||
|
||||
n_queued_tokens = 0;
|
||||
t_compute_start_us = 0;
|
||||
}
|
||||
|
||||
ggml_context_ptr llama_context::graph_init() {
|
||||
inp_tokens = nullptr;
|
||||
inp_embd = nullptr;
|
||||
inp_pos = nullptr;
|
||||
inp_out_ids = nullptr;
|
||||
inp_mean = nullptr;
|
||||
inp_cls = nullptr;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
return ggml_context_ptr { ggml_init(params) };
|
||||
}
|
||||
|
||||
enum ggml_status llama_context::graph_compute(
|
||||
ggml_cgraph * graph,
|
||||
bool batched) {
|
||||
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
|
||||
@ -608,7 +891,7 @@ void llama_context::build_cb(
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) {
|
||||
return model.build_graph(*this, cparams, ubatch, init(), worst_case);
|
||||
return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case);
|
||||
}
|
||||
|
||||
llama_perf_context_data llama_context::perf_get_data() const {
|
||||
@ -1183,100 +1466,15 @@ void llama_context::perf_reset() {
|
||||
|
||||
llama_context_kv_self::llama_context_kv_self(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params) : llama_context(model) {
|
||||
const llama_context_params & params) :
|
||||
llama_context(model, params) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams));
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams));
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
||||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
|
||||
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
|
||||
hparams.n_ctx_train;
|
||||
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
auto rope_scaling_type = params.rope_scaling_type;
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
||||
rope_scaling_type = hparams.rope_scaling_type_train;
|
||||
}
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
|
||||
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
|
||||
}
|
||||
|
||||
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
|
||||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
|
||||
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
} else {
|
||||
cparams.pooling_type = hparams.pooling_type;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
|
||||
cparams.causal_attn = hparams.causal_attn;
|
||||
} else {
|
||||
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
if (n_ctx_per_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
||||
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
||||
}
|
||||
|
||||
logits_all = params.logits_all;
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
// build worst-case graph for encoder if a model contains encoder
|
||||
is_encoding = llama_model_has_encoder(&model); // TODO: model.has_encoder()
|
||||
@ -1298,51 +1496,6 @@ llama_context_kv_self::llama_context_kv_self(
|
||||
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
// GPU backends
|
||||
for (auto * dev : model.devices) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||
throw std::runtime_error("failed to initialize backend");
|
||||
}
|
||||
backends.emplace_back(backend);
|
||||
}
|
||||
|
||||
// add ACCEL backends (such as BLAS)
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||
throw std::runtime_error("failed to initialize backend");
|
||||
}
|
||||
backends.emplace_back(backend);
|
||||
}
|
||||
}
|
||||
|
||||
// add CPU backend
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
if (backend_cpu == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
|
||||
throw std::runtime_error("failed to initialize CPU backend");
|
||||
}
|
||||
backends.emplace_back(backend_cpu);
|
||||
|
||||
// create a list of the set_n_threads functions in the backends
|
||||
for (auto & backend : backends) {
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
|
||||
|
||||
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
throw std::runtime_error("failed to initialize self-attention cache");
|
||||
@ -1357,128 +1510,6 @@ llama_context_kv_self::llama_context_kv_self(
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
// graph outputs buffer
|
||||
{
|
||||
// resized during inference when a batch uses more outputs
|
||||
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
|
||||
throw std::runtime_error("failed to reserve initial output buffer");
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buffer_name (buf_output.get()),
|
||||
ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
// scheduler and compute buffers
|
||||
{
|
||||
// buffer types used for the compute buffer of each backend
|
||||
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||
std::vector<ggml_backend_t> backend_ptrs;
|
||||
for (auto & backend : backends) {
|
||||
auto * buft = ggml_backend_get_default_buffer_type(backend.get());
|
||||
auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
||||
if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
|
||||
// use the host buffer of the first device CPU for faster transfer of the intermediate state
|
||||
auto * dev = model.devices[0];
|
||||
auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
|
||||
if (host_buft) {
|
||||
buft = host_buft;
|
||||
}
|
||||
}
|
||||
backend_buft.push_back(buft);
|
||||
backend_ptrs.push_back(backend.get());
|
||||
}
|
||||
|
||||
const size_t max_nodes = model.max_nodes();
|
||||
|
||||
// buffer used to store the computation graph and the tensor meta data
|
||||
// TODO: move to base class
|
||||
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
|
||||
// TODO: move these checks to ggml_backend_sched
|
||||
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
||||
bool pipeline_parallel =
|
||||
model.n_devices() > 1 &&
|
||||
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
|
||||
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
|
||||
params.offload_kqv;
|
||||
|
||||
// pipeline parallelism requires support for async compute and events in all devices
|
||||
if (pipeline_parallel) {
|
||||
for (auto & backend : backends) {
|
||||
auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
|
||||
if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
// ignore CPU backend
|
||||
continue;
|
||||
}
|
||||
auto * dev = ggml_backend_get_device(backend.get());
|
||||
ggml_backend_dev_props props;
|
||||
ggml_backend_dev_get_props(dev, &props);
|
||||
if (!props.caps.async || !props.caps.events) {
|
||||
// device does not support async compute or events
|
||||
pipeline_parallel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
||||
|
||||
if (pipeline_parallel) {
|
||||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
||||
}
|
||||
|
||||
// initialize scheduler with the worst-case graph
|
||||
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
|
||||
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
ggml_backend_sched_reserve(sched.get(), gf_pp);
|
||||
int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
|
||||
int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
|
||||
ggml_backend_sched_reserve(sched.get(), gf_tg);
|
||||
int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
||||
int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
gf_pp = build_graph(ubatch_pp, true);
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
throw std::runtime_error("failed to allocate compute buffers");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < backend_ptrs.size(); ++i) {
|
||||
ggml_backend_t backend = backend_ptrs[i];
|
||||
ggml_backend_buffer_type_t buft = backend_buft[i];
|
||||
size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
|
||||
if (size > 1) {
|
||||
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
|
||||
ggml_backend_buft_name(buft),
|
||||
size / 1024.0 / 1024.0);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_nodes_pp == n_nodes_tg) {
|
||||
LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
|
||||
}
|
||||
if (n_splits_pp == n_splits_tg) {
|
||||
LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1497,15 +1528,7 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
|
||||
return &kv_self;
|
||||
}
|
||||
|
||||
ggml_context_ptr llama_context_kv_self::init() {
|
||||
inp_tokens = nullptr;
|
||||
inp_embd = nullptr;
|
||||
inp_pos = nullptr;
|
||||
inp_out_ids = nullptr;
|
||||
inp_mean = nullptr;
|
||||
inp_cls = nullptr;
|
||||
inp_embd_enc = nullptr;
|
||||
inp_pos_bucket = nullptr;
|
||||
ggml_context_ptr llama_context_kv_self::graph_init() {
|
||||
inp_KQ_mask = nullptr;
|
||||
inp_KQ_mask_cnv = nullptr;
|
||||
inp_KQ_mask_swa = nullptr;
|
||||
@ -1514,8 +1537,10 @@ ggml_context_ptr llama_context_kv_self::init() {
|
||||
inp_K_shift = nullptr;
|
||||
inp_s_copy = nullptr;
|
||||
inp_s_mask = nullptr;
|
||||
inp_embd_enc = nullptr;
|
||||
inp_pos_bucket = nullptr;
|
||||
|
||||
return llama_context::init();
|
||||
return llama_context::graph_init();
|
||||
}
|
||||
|
||||
struct llama_context_kv_self::batch_manager {
|
||||
@ -1817,7 +1842,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
|
||||
}
|
||||
|
||||
const auto compute_status = compute_graph(gf, ubatch.n_tokens > 1);
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
bman->restore();
|
||||
switch (compute_status) {
|
||||
@ -2035,7 +2060,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
const auto compute_status = compute_graph(gf, n_tokens > 1);
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
break;
|
||||
@ -2422,7 +2447,7 @@ void llama_context_kv_self::kv_self_update() {
|
||||
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
auto ctx = init();
|
||||
auto ctx = graph_init();
|
||||
auto ctx0 = ctx.get();
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
|
||||
@ -2433,7 +2458,7 @@ void llama_context_kv_self::kv_self_update() {
|
||||
|
||||
input_set({});
|
||||
|
||||
compute_graph(gf, false);
|
||||
graph_compute(gf, false);
|
||||
|
||||
need_reserve = true;
|
||||
}
|
||||
@ -2451,7 +2476,7 @@ void llama_context_kv_self::kv_self_update() {
|
||||
if (kv.do_defrag) {
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
auto ctx = init();
|
||||
auto ctx = graph_init();
|
||||
auto ctx0 = ctx.get();
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
|
||||
@ -2463,7 +2488,7 @@ void llama_context_kv_self::kv_self_update() {
|
||||
// no input
|
||||
//input_set({});
|
||||
|
||||
compute_graph(gf, false);
|
||||
graph_compute(gf, false);
|
||||
|
||||
kv.do_defrag = false;
|
||||
|
||||
|
@ -21,9 +21,16 @@ class llama_io_write_i;
|
||||
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
||||
|
||||
struct llama_context : public llama_graph_i {
|
||||
llama_context(const llama_model & model);
|
||||
llama_context(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params);
|
||||
|
||||
virtual ~llama_context();
|
||||
|
||||
// init scheduler and compute buffers
|
||||
// call once after the context is constructed
|
||||
virtual void init();
|
||||
|
||||
const llama_model & get_model() const;
|
||||
const llama_cparams & get_cparams() const;
|
||||
|
||||
@ -52,10 +59,6 @@ struct llama_context : public llama_graph_i {
|
||||
|
||||
virtual int64_t n_pos_per_token() const; // vision
|
||||
|
||||
virtual ggml_context_ptr init();
|
||||
|
||||
virtual void synchronize();
|
||||
|
||||
virtual void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
@ -85,8 +88,14 @@ struct llama_context : public llama_graph_i {
|
||||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
////
|
||||
|
||||
virtual void synchronize();
|
||||
|
||||
virtual ggml_context_ptr graph_init();
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
virtual enum ggml_status compute_graph(
|
||||
virtual enum ggml_status graph_compute(
|
||||
ggml_cgraph * graph,
|
||||
bool batched);
|
||||
|
||||
@ -297,7 +306,7 @@ public:
|
||||
|
||||
virtual void kv_self_update() override;
|
||||
|
||||
virtual ggml_context_ptr init() override;
|
||||
virtual ggml_context_ptr graph_init() override;
|
||||
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
@ -312,7 +321,7 @@ public:
|
||||
// certain implementations could require a padding for the context size
|
||||
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
||||
|
||||
// === unified KV cache ===
|
||||
// === KV cache ===
|
||||
|
||||
llama_kv_cache kv_self;
|
||||
|
||||
|
@ -328,6 +328,7 @@ struct llama_context * llama_init_from_model(
|
||||
try {
|
||||
// TODO: add logic which llama_context implementation to construct
|
||||
ctx = new llama_context_kv_self(*model, params);
|
||||
ctx->init();
|
||||
} catch (const std::exception & e) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
|
||||
return nullptr;
|
||||
@ -410,7 +411,6 @@ const char * llama_print_system_info(void) {
|
||||
static std::string s;
|
||||
s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
|
||||
|
||||
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto * reg = ggml_backend_reg_get(i);
|
||||
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
|
||||
|
Reference in New Issue
Block a user