diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 485430095..31085f644 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -14,14 +14,290 @@
 // llama_context
 //
 
-llama_context::llama_context(const llama_model & model) :
+llama_context::llama_context(
+        const llama_model & model,
+        const llama_context_params & params) :
     model     (model),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
+    const auto & hparams = model.hparams;
+
+    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
+    cparams.n_threads        = params.n_threads;
+    cparams.n_threads_batch  = params.n_threads_batch;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.defrag_thold     = params.defrag_thold;
+    cparams.embeddings       = params.embeddings;
+    cparams.offload_kqv      = params.offload_kqv;
+    cparams.flash_attn       = params.flash_attn;
+    cparams.no_perf          = params.no_perf;
+    cparams.pooling_type     = params.pooling_type;
+
+    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
+    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
+                                                              hparams.n_ctx_train;
+
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
+    auto rope_scaling_type = params.rope_scaling_type;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
+        rope_scaling_type = hparams.rope_scaling_type_train;
+    }
+
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
+        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+    }
+
+    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
+    }
+
+    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
+
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
+    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
+    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
+
+    if (n_ctx_per_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    }
+
+    logits_all = params.logits_all;
+
+    if (!hparams.vocab_only) {
+        // GPU backends
+        for (auto * dev : model.devices) {
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                throw std::runtime_error("failed to initialize backend");
+            }
+            backends.emplace_back(backend);
+        }
+
+        // add ACCEL backends (such as BLAS)
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                    throw std::runtime_error("failed to initialize backend");
+                }
+                backends.emplace_back(backend);
+            }
+        }
+
+        // add CPU backend
+        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (backend_cpu == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
+            throw std::runtime_error("failed to initialize CPU backend");
+        }
+        backends.emplace_back(backend_cpu);
+
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
+                }
+            }
+        }
+
+        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
+
+        // graph outputs buffer
+        {
+            // resized during inference when a batch uses more outputs
+            if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+                LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
+                throw std::runtime_error("failed to reserve initial output buffer");
+            }
+
+            LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buffer_name    (buf_output.get()),
+                    ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
+        }
+    }
 }
 
 llama_context::~llama_context() = default;
 
+void llama_context::init() {
+    const auto & hparams = model.hparams;
+
+    if (hparams.vocab_only) {
+        LLAMA_LOG_WARN("%s: model is vocab-only -- no computation will be performed\n", __func__);
+        return;
+    }
+
+    // buffer types used for the compute buffer of each backend
+    std::vector<ggml_backend_buffer_type_t> backend_buft;
+    std::vector<ggml_backend_t>             backend_ptrs;
+    for (auto & backend : backends) {
+        auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+        auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+        if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
+            // for the CPU backend, use the host buffer of the first device for faster transfer of the intermediate state
+            auto * dev = model.devices[0];
+            auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+            if (host_buft) {
+                buft = host_buft;
+            }
+        }
+        backend_buft.push_back(buft);
+        backend_ptrs.push_back(backend.get());
+    }
+
+    const size_t max_nodes = model.max_nodes();
+
+    // buffer used to store the computation graph and the tensor meta data
+    // TODO: move to base class
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    // TODO: move these checks to ggml_backend_sched
+    // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+    bool pipeline_parallel =
+        model.n_devices() > 1 &&
+        model.params.n_gpu_layers > (int) model.hparams.n_layer &&
+        model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+        cparams.offload_kqv;
+
+    // pipeline parallelism requires support for async compute and events in all devices
+    if (pipeline_parallel) {
+        for (auto & backend : backends) {
+            auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+            if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                // ignore CPU backend
+                continue;
+            }
+            auto * dev = ggml_backend_get_device(backend.get());
+            ggml_backend_dev_props props;
+            ggml_backend_dev_get_props(dev, &props);
+            if (!props.caps.async || !props.caps.events) {
+                // device does not support async compute or events
+                pipeline_parallel = false;
+                break;
+            }
+        }
+    }
+
+    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+    if (pipeline_parallel) {
+        LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
+    }
+
+    // initialize scheduler with the worst-case graph
+    {
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+
+        // reserve pp graph first so that buffers are only allocated once
+        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
+        if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
+        int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+        int n_nodes_pp  = ggml_graph_n_nodes(gf_pp);
+
+        // reserve with tg graph to get the number of splits and nodes
+        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+        ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
+        if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
+        int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+        int n_nodes_tg  = ggml_graph_n_nodes(gf_tg);
+
+        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+        gf_pp = build_graph(ubatch_pp, true);
+        if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
+            throw std::runtime_error("failed to allocate compute buffers");
+        }
+
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t             backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft    = backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size > 1) {
+                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
+            }
+        }
+
+        if (n_nodes_pp == n_nodes_tg) {
+            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
+        } else {
+            LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+        }
+
+        if (n_splits_pp == n_splits_tg) {
+            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+        } else {
+            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+        }
+    }
+}
+
 const llama_model & llama_context::get_model() const {
     return model;
 }
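Caller-side view of the constructor/init() split above: the constructor resolves cparams and creates the backends, while init() builds the scheduler and reserves the worst-case graphs (and can throw). A minimal sketch of the intended sequence, mirroring the llama.cpp hunk at the end of this patch (not part of the patch; error handling elided):

    // two-phase construction: the ctor wires up params and backends,
    // init() allocates the scheduler and compute buffers and may throw
    llama_context * ctx = new llama_context_kv_self(*model, params);
    ctx->init(); // call once, before any decode()/encode()
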
@@ -161,46 +437,6 @@ int64_t llama_context::n_pos_per_token() const {
     return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
-ggml_context_ptr llama_context::init() {
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    return ggml_context_ptr { ggml_init(params) };
-}
-
-void llama_context::synchronize() {
-    ggml_backend_sched_synchronize(sched.get());
-
-    // FIXME: if multiple single tokens are evaluated without a synchronization,
-    //        the stats will be added to the prompt evaluation stats
-    //        this should only happen when using batch size 1 to evaluate a batch
-
-    // add the evaluation to the stats
-    if (n_queued_tokens == 1) {
-        if (!cparams.no_perf) {
-            t_eval_us += ggml_time_us() - t_compute_start_us;
-        }
-        n_eval++;
-    } else if (n_queued_tokens > 1) {
-        if (!cparams.no_perf) {
-            t_p_eval_us += ggml_time_us() - t_compute_start_us;
-        }
-        n_p_eval += n_queued_tokens;
-    }
-
-    // get a more accurate load time, upon first eval
-    if (n_queued_tokens > 0 && !has_evaluated_once) {
-        t_load_us = ggml_time_us() - t_start_us;
-        has_evaluated_once = true;
-    }
-
-    n_queued_tokens = 0;
-    t_compute_start_us = 0;
-}
-
 void llama_context::attach_threadpool(
         ggml_threadpool_t threadpool,
         ggml_threadpool_t threadpool_batch) {
@@ -269,7 +505,54 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-enum ggml_status llama_context::compute_graph(
+void llama_context::synchronize() {
+    ggml_backend_sched_synchronize(sched.get());
+
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    //        the stats will be added to the prompt evaluation stats
+    //        this should only happen when using batch size 1 to evaluate a batch
+
+    // add the evaluation to the stats
+    if (n_queued_tokens == 1) {
+        if (!cparams.no_perf) {
+            t_eval_us += ggml_time_us() - t_compute_start_us;
+        }
+        n_eval++;
+    } else if (n_queued_tokens > 1) {
+        if (!cparams.no_perf) {
+            t_p_eval_us += ggml_time_us() - t_compute_start_us;
+        }
+        n_p_eval += n_queued_tokens;
+    }
+
+    // get a more accurate load time, upon first eval
+    if (n_queued_tokens > 0 && !has_evaluated_once) {
+        t_load_us = ggml_time_us() - t_start_us;
+        has_evaluated_once = true;
+    }
+
+    n_queued_tokens = 0;
+    t_compute_start_us = 0;
+}
+
+ggml_context_ptr llama_context::graph_init() {
+    inp_tokens  = nullptr;
+    inp_embd    = nullptr;
+    inp_pos     = nullptr;
+    inp_out_ids = nullptr;
+    inp_mean    = nullptr;
+    inp_cls     = nullptr;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    return ggml_context_ptr { ggml_init(params) };
+}
+
+enum ggml_status llama_context::graph_compute(
         ggml_cgraph * graph,
         bool batched) {
     int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
@@ -608,7 +891,7 @@ void llama_context::build_cb(
 }
 
 ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) {
-    return model.build_graph(*this, cparams, ubatch, init(), worst_case);
+    return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case);
 }
 
 llama_perf_context_data llama_context::perf_get_data() const {
@@ -1183,100 +1466,15 @@ void llama_context::perf_reset() {
 
 llama_context_kv_self::llama_context_kv_self(
         const llama_model & model,
-        const llama_context_params & params) : llama_context(model) {
+        const llama_context_params & params) :
+    llama_context(model, params) {
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
-    cparams.n_threads        = params.n_threads;
-    cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
-    cparams.embeddings       = params.embeddings;
-    cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
-    cparams.no_perf          = params.no_perf;
-    cparams.pooling_type     = params.pooling_type;
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
-    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams));
 
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams));
-
-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
-    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
-                                                              hparams.n_ctx_train;
-
-    cparams.cb_eval           = params.cb_eval;
-    cparams.cb_eval_user_data = params.cb_eval_user_data;
-
-    auto rope_scaling_type = params.rope_scaling_type;
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
-        rope_scaling_type = hparams.rope_scaling_type_train;
-    }
-
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
-        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
-    }
-
-    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
-        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
-    }
-
-    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-        } else {
-            cparams.pooling_type = hparams.pooling_type;
-        }
-    }
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
-
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
-    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
-    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
-    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
-
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    logits_all = params.logits_all;
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
     // build worst-case graph for encoder if a model contains encoder
     is_encoding = llama_model_has_encoder(&model); // TODO: model.has_encoder()
@@ -1298,51 +1496,6 @@ llama_context_kv_self::llama_context_kv_self(
 
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
 
     if (!hparams.vocab_only) {
-        // GPU backends
-        for (auto * dev : model.devices) {
-            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                throw std::runtime_error("failed to initialize backend");
-            }
-            backends.emplace_back(backend);
-        }
-
-        // add ACCEL backends (such as BLAS)
-        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                    throw std::runtime_error("failed to initialize backend");
-                }
-                backends.emplace_back(backend);
-            }
-        }
-
-        // add CPU backend
-        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
-        if (backend_cpu == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
-            throw std::runtime_error("failed to initialize CPU backend");
-        }
-        backends.emplace_back(backend_cpu);
-
-        // create a list of the set_n_threads functions in the backends
-        for (auto & backend : backends) {
-            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
-            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-            if (reg) {
-                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-                if (ggml_backend_set_n_threads_fn) {
-                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
-                }
-            }
-        }
-
-        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
-
         if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             throw std::runtime_error("failed to initialize self-attention cache");
@@ -1357,128 +1510,6 @@ llama_context_kv_self::llama_context_kv_self(
             ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
             ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
-
-    // graph outputs buffer
-    {
-        // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
-            LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
-            throw std::runtime_error("failed to reserve initial output buffer");
-        }
-
-        LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
-                ggml_backend_buffer_name    (buf_output.get()),
-                ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
-    }
-
-    // scheduler and compute buffers
-    {
-        // buffer types used for the compute buffer of each backend
-        std::vector<ggml_backend_buffer_type_t> backend_buft;
-        std::vector<ggml_backend_t>             backend_ptrs;
-        for (auto & backend : backends) {
-            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
-            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
-                // use the host buffer of the first device CPU for faster transfer of the intermediate state
-                auto * dev = model.devices[0];
-                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
-                if (host_buft) {
-                    buft = host_buft;
-                }
-            }
-            backend_buft.push_back(buft);
-            backend_ptrs.push_back(backend.get());
-        }
-
-        const size_t max_nodes = model.max_nodes();
-
-        // buffer used to store the computation graph and the tensor meta data
-        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-
-        // TODO: move these checks to ggml_backend_sched
-        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
-        bool pipeline_parallel =
-            model.n_devices() > 1 &&
-            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
-            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-            params.offload_kqv;
-
-        // pipeline parallelism requires support for async compute and events in all devices
-        if (pipeline_parallel) {
-            for (auto & backend : backends) {
-                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
-                    // ignore CPU backend
-                    continue;
-                }
-                auto * dev = ggml_backend_get_device(backend.get());
-                ggml_backend_dev_props props;
-                ggml_backend_dev_get_props(dev, &props);
-                if (!props.caps.async || !props.caps.events) {
-                    // device does not support async compute or events
-                    pipeline_parallel = false;
-                    break;
-                }
-            }
-        }
-
-        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
-
-        if (pipeline_parallel) {
-            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
-        }
-
-        // initialize scheduler with the worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
-        llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_pp = build_graph(ubatch_pp, true);
-
-        // reserve pp graph first so that buffers are only allocated once
-        ggml_backend_sched_reserve(sched.get(), gf_pp);
-        int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
-        int n_nodes_pp  = ggml_graph_n_nodes(gf_pp);
-
-        // reserve with tg graph to get the number of splits and nodes
-        llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_tg = build_graph(ubatch_tg, true);
-        ggml_backend_sched_reserve(sched.get(), gf_tg);
-        int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
-        int n_nodes_tg  = ggml_graph_n_nodes(gf_tg);
-
-        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        gf_pp = build_graph(ubatch_pp, true);
-        if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-            throw std::runtime_error("failed to allocate compute buffers");
-        }
-
-        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-            ggml_backend_t             backend = backend_ptrs[i];
-            ggml_backend_buffer_type_t buft    = backend_buft[i];
-            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-            if (size > 1) {
-                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                        ggml_backend_buft_name(buft),
-                        size / 1024.0 / 1024.0);
-            }
-        }
-
-        if (n_nodes_pp == n_nodes_tg) {
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
-        }
-        if (n_splits_pp == n_splits_tg) {
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
-        }
-    }
     }
 }
 
@@ -1497,15 +1528,7 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
     return &kv_self;
 }
 
-ggml_context_ptr llama_context_kv_self::init() {
-    inp_tokens     = nullptr;
-    inp_embd       = nullptr;
-    inp_pos        = nullptr;
-    inp_out_ids    = nullptr;
-    inp_mean       = nullptr;
-    inp_cls        = nullptr;
-    inp_embd_enc   = nullptr;
-    inp_pos_bucket = nullptr;
+ggml_context_ptr llama_context_kv_self::graph_init() {
     inp_KQ_mask     = nullptr;
     inp_KQ_mask_cnv = nullptr;
     inp_KQ_mask_swa = nullptr;
@@ -1514,8 +1537,10 @@ ggml_context_ptr llama_context_kv_self::init() {
     inp_K_shift = nullptr;
     inp_s_copy  = nullptr;
     inp_s_mask  = nullptr;
+    inp_embd_enc   = nullptr;
+    inp_pos_bucket = nullptr;
 
-    return llama_context::init();
+    return llama_context::graph_init();
 }
 
 struct llama_context_kv_self::batch_manager {
@@ -1817,7 +1842,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
     }
 
-    const auto compute_status = compute_graph(gf, ubatch.n_tokens > 1);
+    const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         bman->restore();
         switch (compute_status) {
@@ -2035,7 +2060,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
         }
     }
 
-    const auto compute_status = compute_graph(gf, n_tokens > 1);
+    const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
             break;
@@ -2422,7 +2447,7 @@ void llama_context_kv_self::kv_self_update() {
     if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
         ggml_backend_sched_reset(sched.get());
 
-        auto ctx = init();
+        auto ctx = graph_init();
         auto ctx0 = ctx.get();
 
         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -2433,7 +2458,7 @@ void llama_context_kv_self::kv_self_update() {
 
         input_set({});
 
-        compute_graph(gf, false);
+        graph_compute(gf, false);
 
         need_reserve = true;
     }
@@ -2451,7 +2476,7 @@ void llama_context_kv_self::kv_self_update() {
     if (kv.do_defrag) {
         ggml_backend_sched_reset(sched.get());
 
-        auto ctx = init();
+        auto ctx = graph_init();
         auto ctx0 = ctx.get();
 
         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -2463,7 +2488,7 @@ void llama_context_kv_self::kv_self_update() {
         // no input
         //input_set({});
 
-        compute_graph(gf, false);
+        graph_compute(gf, false);
 
         kv.do_defrag = false;
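After the renames above, every ad-hoc graph evaluation follows the same graph_init()/graph_compute() shape; the kv_self_update() hunks use it for the K-shift and defrag graphs. A condensed sketch (member names as in llama_context_kv_self; the lines elided between the hunks also populate the graph and allocate it on the scheduler):

    ggml_backend_sched_reset(sched.get());  // drop the previous backend assignment
    auto ctx = graph_init();                // resets input tensors, returns a fresh no-alloc ggml context
    auto ctx0 = ctx.get();
    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
    // ... populate gf (e.g. the K-shift graph) and allocate it on the scheduler ...
    input_set({});                          // upload inputs (the defrag graph needs none)
    graph_compute(gf, /*batched=*/false);   // dispatch via ggml_backend_sched
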
diff --git a/src/llama-context.h b/src/llama-context.h
index f80401382..e70c99f33 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -21,9 +21,16 @@ class llama_io_write_i;
 
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
 
 struct llama_context : public llama_graph_i {
-    llama_context(const llama_model & model);
+    llama_context(
+            const llama_model & model,
+            const llama_context_params & params);
+
     virtual ~llama_context();
 
+    // init scheduler and compute buffers
+    // call once after the context is constructed
+    virtual void init();
+
     const llama_model & get_model() const;
     const llama_cparams & get_cparams() const;
 
@@ -52,10 +59,6 @@ struct llama_context : public llama_graph_i {
 
     virtual int64_t n_pos_per_token() const; // vision
 
-    virtual ggml_context_ptr init();
-
-    virtual void synchronize();
-
     virtual void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch);
@@ -85,8 +88,14 @@ struct llama_context : public llama_graph_i {
             int32_t il_start,
             int32_t il_end);
 
+    ////
+
+    virtual void synchronize();
+
+    virtual ggml_context_ptr graph_init();
+
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    virtual enum ggml_status compute_graph(
+    virtual enum ggml_status graph_compute(
             ggml_cgraph * graph,
             bool batched);
 
@@ -297,7 +306,7 @@ public:
 
     virtual void kv_self_update() override;
 
-    virtual ggml_context_ptr init() override;
+    virtual ggml_context_ptr graph_init() override;
 
     virtual void input_set(const llama_ubatch & ubatch) override;
 
@@ -312,7 +321,7 @@ public:
     // certain implementations could require a padding for the context size
     uint32_t get_ctx_padding(const llama_cparams & cparams) const;
 
-    // === unified KV cache ===
+    // === KV cache ===
 
     llama_kv_cache kv_self;
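With init() and graph_init() both virtual on the base class, a subclass override of graph_init() only resets its own per-graph input tensors and then chains to the base, which resets the shared inputs and creates the ggml context. Condensed from the llama_context_kv_self hunks above:

    ggml_context_ptr llama_context_kv_self::graph_init() {
        inp_KQ_mask    = nullptr;  // subclass-specific graph inputs
        inp_embd_enc   = nullptr;
        inp_pos_bucket = nullptr;
        // ...
        return llama_context::graph_init();  // shared inputs + ggml context
    }
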
diff --git a/src/llama.cpp b/src/llama.cpp
index d20a2a6d5..a677902f0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -328,6 +328,7 @@ struct llama_context * llama_init_from_model(
     try {
         // TODO: add logic which llama_context implementation to construct
        ctx = new llama_context_kv_self(*model, params);
+        ctx->init();
     } catch (const std::exception & e) {
         LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
         return nullptr;
     }
@@ -410,7 +411,6 @@ const char * llama_print_system_info(void) {
     static std::string s;
     s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
-
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
         auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");