From 548c230dff1060820b7ef66653896accee3772cc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 21 Feb 2025 12:10:57 +0200
Subject: [PATCH] graph : remove worst_case from the API

ggml-ci
---
 src/llama-context.cpp  | 1432 ++++++++++++++++++++--------------------
 src/llama-context.h    |  274 +++++---
 src/llama-graph.cpp    |   44 +-
 src/llama-graph.h      |   39 +-
 src/llama-kv-cache.cpp |    1 +
 src/llama-model.cpp    |  132 ++--
 src/llama-model.h      |    3 +-
 7 files changed, 958 insertions(+), 967 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4ce54b0d6..dc1eb70b8 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -17,11 +17,12 @@ llama_context::llama_context(
         const llama_model & model,
         const llama_context_params & params) :
-    model (model),
-    t_start_us(model.t_start_us),
-    t_load_us (model.t_load_us) {
+    model (model) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

+    t_start_us = model.t_start_us;
+    t_load_us = model.t_load_us;
+
     const auto & hparams = model.hparams;

     cparams.n_seq_max = std::max(1u, params.n_seq_max);

@@ -186,135 +187,173 @@ void llama_context::init() {
         return;
     }

-    // buffer types used for the compute buffer of each backend
-    std::vector<ggml_backend_buffer_type_t> backend_buft;
-    std::vector<ggml_backend_t> backend_ptrs;
-    for (auto & backend : backends) {
-        auto * buft = ggml_backend_get_default_buffer_type(backend.get());
-        auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-        if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
-            // use the host buffer of the first device CPU for faster transfer of the intermediate state
-            auto * dev = model.devices[0];
-            auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
-            if (host_buft) {
-                buft = host_buft;
-            }
-        }
-        backend_buft.push_back(buft);
-        backend_ptrs.push_back(backend.get());
-    }
-
-    const size_t max_nodes = this->max_nodes();
-
-    // buffer used to store the computation graph and the tensor meta data
-    // TODO: move to base class
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-
-    // TODO: move these checks to ggml_backend_sched
-    // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
-    bool pipeline_parallel =
-        model.n_devices() > 1 &&
-        model.params.n_gpu_layers > (int) model.hparams.n_layer &&
-        model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-        cparams.offload_kqv;
-
-    // pipeline parallelism requires support for async compute and events in all devices
-    if (pipeline_parallel) {
-        for (auto & backend : backends) {
-            auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-            if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
-                // ignore CPU backend
-                continue;
-            }
-            auto * dev = ggml_backend_get_device(backend.get());
-            ggml_backend_dev_props props;
-            ggml_backend_dev_get_props(dev, &props);
-            if (!props.caps.async || !props.caps.events) {
-                // device does not support async compute or events
-                pipeline_parallel = false;
-                break;
-            }
-        }
-    }
-
-    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
-
-    if (pipeline_parallel) {
-        LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
-    }
-
-    // initialize scheduler with the worst-case graph
     {
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        // buffer types used for the compute buffer of each backend
+        backend_buft.clear();
+        backend_ptrs.clear();

-        int n_splits_pp = -1;
-        int n_nodes_pp = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg = -1;
-
-        // reserve pp graph first so that buffers are only allocated once
-        {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, true);
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
-                throw std::runtime_error("failed to allocate compute buffers");
+        for (auto & backend : backends) {
+            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
+                // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                auto * dev = model.devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                if (host_buft) {
+                    buft = host_buft;
+                }
             }
-
-            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_pp = ggml_graph_n_nodes(gf);
+            backend_buft.push_back(buft);
+            backend_ptrs.push_back(backend.get());
         }

-        // reserve with tg graph to get the number of splits and nodes
-        {
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_tg, true);
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
-                throw std::runtime_error("failed to allocate compute buffers");
-            }
-            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_tg = ggml_graph_n_nodes(gf);
-        }
+        const size_t max_nodes = this->max_nodes();

-        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        {
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            auto * gf = graph_init();
-            graph_build(ctx_compute.get(), gf, ubatch_pp, true);
-            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
-                throw std::runtime_error("failed to allocate compute buffers");
+        // buffer used to store the computation graph and the tensor meta data
+        // TODO: move to base class
+        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+        // TODO: move these checks to ggml_backend_sched
+        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+        bool pipeline_parallel =
+            model.n_devices() > 1 &&
+            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
+            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+            cparams.offload_kqv;
+
+        // pipeline parallelism requires support for async compute and events in all devices
+        if (pipeline_parallel) {
+            for (auto & backend : backends) {
+                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                    // ignore CPU backend
+                    continue;
+                }
+                auto * dev = ggml_backend_get_device(backend.get());
+
ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.events) { + // device does not support async compute or events + pipeline_parallel = false; + break; + } } } - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); - if (size > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - size / 1024.0 / 1024.0); - } + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); + + if (pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + } + } + + reserve(); +} + +void llama_context::synchronize() { + ggml_backend_sched_synchronize(sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (n_queued_tokens == 1) { + if (!cparams.no_perf) { + t_eval_us += ggml_time_us() - t_compute_start_us; + } + n_eval++; + } else if (n_queued_tokens > 1) { + if (!cparams.no_perf) { + t_p_eval_us += ggml_time_us() - t_compute_start_us; + } + n_p_eval += n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (n_queued_tokens > 0 && !has_evaluated_once) { + t_load_us = ggml_time_us() - t_start_us; + has_evaluated_once = true; + } + + n_queued_tokens = 0; + t_compute_start_us = 0; +} + +void llama_context::reserve() { + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // max number of outputs + n_outputs = n_tokens; + + // reserve pp graph first so that buffers are only allocated once + { + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_pp); + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); } - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); - } + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } - if (n_splits_pp == n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + // reserve with tg graph to get the number of splits and nodes + { + llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_tg); + if 
(!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); } + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch_pp); + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); + } + } + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + size / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } } @@ -547,200 +586,6 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -void llama_context::synchronize() { - ggml_backend_sched_synchronize(sched.get()); - - // FIXME: if multiple single tokens are evaluated without a synchronization, - // the stats will be added to the prompt evaluation stats - // this should only happen when using batch size 1 to evaluate a batch - - // add the evaluation to the stats - if (n_queued_tokens == 1) { - if (!cparams.no_perf) { - t_eval_us += ggml_time_us() - t_compute_start_us; - } - n_eval++; - } else if (n_queued_tokens > 1) { - if (!cparams.no_perf) { - t_p_eval_us += ggml_time_us() - t_compute_start_us; - } - n_p_eval += n_queued_tokens; - } - - // get a more accurate load time, upon first eval - if (n_queued_tokens > 0 && !has_evaluated_once) { - t_load_us = ggml_time_us() - t_start_us; - has_evaluated_once = true; - } - - n_queued_tokens = 0; - t_compute_start_us = 0; -} - -ggml_cgraph * llama_context::graph_init() { - inp_tokens = nullptr; - inp_embd = nullptr; - inp_pos = nullptr; - inp_out_ids = nullptr; - inp_mean = nullptr; - inp_cls = nullptr; - - inp_kq_mask = nullptr; - inp_kq_mask_cnv = nullptr; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ctx_compute.reset(ggml_init(params)); - - return ggml_new_graph_custom(ctx_compute.get(), max_nodes(), false); -} - -llama_graph_result llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - bool worst_case) { - return model.build_graph(ctx, gf, this, cparams, ubatch, worst_case); -} - -enum ggml_status llama_context::graph_compute( - ggml_cgraph * gf, - bool batched) { - int n_threads = batched ? 
cparams.n_threads_batch : cparams.n_threads; - ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; - - if (backend_cpu != nullptr) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); - auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); - set_threadpool_fn(backend_cpu, tp); - } - - // set the number of threads for all the backends - for (const auto & set_n_threads_fn : set_n_threads_fns) { - set_n_threads_fn.second(set_n_threads_fn.first, n_threads); - } - - auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); - if (status != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); - } - - // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); - - return status; -} - -int32_t llama_context::output_reserve(int32_t n_outputs) { - const auto & hparams = model.hparams; - const auto & vocab = model.vocab; - - const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); - - const auto n_batch = cparams.n_batch; - const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; - - // TODO: use a per-batch flag for logits presence instead - const bool has_logits = !cparams.embeddings; - const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); - - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; - - if (output_ids.empty()) { - // init, never resized afterwards - output_ids.resize(n_batch); - } - - const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); - - // alloc only when more than the current capacity is required - // TODO: also consider shrinking the buffer - if (!buf_output || prev_size < new_size) { - if (buf_output) { -#ifndef NDEBUG - // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) - LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); -#endif - buf_output = nullptr; - logits = nullptr; - embd = nullptr; - } - - auto * buft = ggml_backend_cpu_buffer_type(); - // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = model.dev_output(); - auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; - if (output_dev_host_buft) { - buft = output_dev_host_buft; - } - buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); - if (buf_output == nullptr) { - LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); - return 0; - } - } - - float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - - logits = has_logits ? output_base : nullptr; - embd = has_embd ? 
output_base + logits_size : nullptr; - - output_size = n_outputs_max; - - // set all ids as invalid (negative) - std::fill(output_ids.begin(), output_ids.end(), -1); - - ggml_backend_buffer_clear(buf_output.get(), 0); - - n_outputs = 0; - - return n_outputs_max; -} - -void llama_context::output_reorder() { - auto & out_ids = sbatch.out_ids; - if (!out_ids.empty()) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } - } - std::fill(output_ids.begin(), output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - int llama_context::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -794,13 +639,11 @@ int llama_context::encode(llama_batch & inp_batch) { n_outputs = n_tokens; - GGML_ASSERT(need_reserve == false); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + auto res = graph_build(ctx_compute.get(), gf, ubatch); ggml_backend_sched_alloc_graph(sched.get(), gf); @@ -950,13 +793,11 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_all; - GGML_ASSERT(need_reserve == false); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + auto res = graph_build(ctx_compute.get(), gf, ubatch); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -1083,6 +924,438 @@ int llama_context::decode(llama_batch & inp_batch) { return 0; } +// +// input +// + +void llama_context::input_set(const llama_ubatch & ubatch) { + const llama_hparams & hparams = model.hparams; + + if (ubatch.token) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens)); + } + + if (ubatch.embd) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd)); + } + + if (ubatch.pos && inp_pos) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos)); + } + + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs"); + + if (!inp_out_ids) { + LLAMA_LOG_WARN("%s: 'inp_out_ids' is not 
created\n", __func__); + } else { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer)); + int32_t * data = (int32_t *) inp_out_ids->data; + + if (n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch.output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch.output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(n_outputs == n_outputs); + } else if (n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(n_outputs == 0); + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_mean); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer)); + + float * data = (float *) inp_mean->data; + memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean)); + + std::vector sum(n_tokens, 0); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch.n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } + } + } + + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); + + uint32_t * data = (uint32_t *) inp_cls->data; + memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); + + uint32_t * data = (uint32_t *) inp_cls->data; + memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const 
llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + + if (inp_kq_mask) { + if (cparams.causal_attn) { + const int64_t n_kv = ubatch.n_tokens; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); + float * data = (float *) inp_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) { + if (hparams.use_alibi) { + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; + } + } + } + } + } + } else { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_stride = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); + + float * data = (float *) inp_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } + + GGML_ASSERT( + // (!a || b) is a logical implication (a -> b) + // !hparams.causal_attn -> !cparams.causal_attn + (hparams.causal_attn || !cparams.causal_attn) && + "causal attention is not supported by this model" + ); +} + +// +// output +// + +int32_t llama_context::output_reserve(int32_t n_outputs) { + const auto & hparams = model.hparams; + const auto & vocab = model.vocab; + + const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; + + // TODO: use a per-batch flag for logits presence instead + const bool has_logits = !cparams.embeddings; + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + + logits_size = has_logits ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (output_ids.empty()) { + // init, never resized afterwards + output_ids.resize(n_batch); + } + + const size_t prev_size = buf_output ? 
ggml_backend_buffer_get_size(buf_output.get()) : 0; + const size_t new_size = (logits_size + embd_size) * sizeof(float); + + // alloc only when more than the current capacity is required + // TODO: also consider shrinking the buffer + if (!buf_output || prev_size < new_size) { + if (buf_output) { +#ifndef NDEBUG + // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) + LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + buf_output = nullptr; + logits = nullptr; + embd = nullptr; + } + + auto * buft = ggml_backend_cpu_buffer_type(); + // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory + auto * output_dev = model.dev_output(); + auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; + if (output_dev_host_buft) { + buft = output_dev_host_buft; + } + buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); + if (buf_output == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); + return 0; + } + } + + float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); + + logits = has_logits ? output_base : nullptr; + embd = has_embd ? output_base + logits_size : nullptr; + + output_size = n_outputs_max; + + // set all ids as invalid (negative) + std::fill(output_ids.begin(), output_ids.end(), -1); + + ggml_backend_buffer_clear(buf_output.get(), 0); + + n_outputs = 0; + + return n_outputs_max; +} + +void llama_context::output_reorder() { + auto & out_ids = sbatch.out_ids; + if (!out_ids.empty()) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint32_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? 
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + std::fill(output_ids.begin(), output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + +// +// graph +// + +ggml_cgraph * llama_context::graph_init() { + inp_tokens = nullptr; + inp_embd = nullptr; + inp_pos = nullptr; + inp_out_ids = nullptr; + inp_mean = nullptr; + inp_cls = nullptr; + + inp_kq_mask = nullptr; + inp_kq_mask_cnv = nullptr; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_compute.reset(ggml_init(params)); + + return ggml_new_graph_custom(ctx_compute.get(), max_nodes(), false); +} + +llama_graph_result llama_context::graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch) { + return model.build_graph(ctx, gf, this, cparams, ubatch); +} + +enum ggml_status llama_context::graph_compute( + ggml_cgraph * gf, + bool batched) { + int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; + ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; + + if (backend_cpu != nullptr) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); + auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); + set_threadpool_fn(backend_cpu, tp); + } + + // set the number of threads for all the backends + for (const auto & set_n_threads_fn : set_n_threads_fns) { + set_n_threads_fn.second(set_n_threads_fn.first, n_threads); + } + + auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); + } + + // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); + + return status; +} + +// +// graph build API +// + void llama_context::build_cb( ggml_tensor * cur, const char * name, @@ -1307,10 +1580,8 @@ ggml_tensor * llama_context::build_inp_pos( } ggml_tensor * llama_context::build_inp_out_ids( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) { - const int32_t n_out_ids = worst_case ? 
n_tokens : n_outputs; + ggml_context * ctx0) { + const int32_t n_out_ids = n_outputs; inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids); ggml_set_input(inp_out_ids); @@ -1336,6 +1607,22 @@ ggml_tensor * llama_context::build_inp_cls( return inp_cls; } +void llama_context::build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa) { + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch + GGML_UNUSED(causal); + GGML_UNUSED(swa); + + inp_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp_kq_mask, "KQ_mask", -1); + ggml_set_input(inp_kq_mask); + + inp_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_kq_mask, GGML_TYPE_F16) : inp_kq_mask; +} + ggml_tensor * llama_context::build_attn( ggml_context * ctx0, ggml_cgraph * gf, @@ -1346,8 +1633,7 @@ ggml_tensor * llama_context::build_attn( ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; const auto & n_ctx = cparams.n_ctx; @@ -1364,7 +1650,6 @@ ggml_tensor * llama_context::build_attn( const auto & n_embd_head_v = hparams.n_embd_head_v; // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - GGML_UNUSED(worst_case); const auto n_kv = n_tokens; struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); @@ -1455,24 +1740,6 @@ ggml_tensor * llama_context::build_attn( return cur; } -void llama_context::build_attn_inp( - ggml_context * ctx0, - int32_t n_tokens, - bool causal, - bool swa, - bool worst_case) { - // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - GGML_UNUSED(causal); - GGML_UNUSED(swa); - GGML_UNUSED(worst_case); - - inp_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_kq_mask, "KQ_mask", -1); - ggml_set_input(inp_kq_mask); - - inp_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_kq_mask, GGML_TYPE_F16) : inp_kq_mask; -} - // // perf // @@ -1497,7 +1764,7 @@ void llama_context::perf_reset() { } // -// state +// state save/load // class llama_io_write_dummy : public llama_io_write_i { @@ -1963,263 +2230,6 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_ return io.n_bytes(); } -// -// input -// - -void llama_context::input_set(const llama_ubatch & ubatch) { - const llama_hparams & hparams = model.hparams; - - if (ubatch.token) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens)); - } - - if (ubatch.embd) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd)); - } - - if (ubatch.pos && inp_pos) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos)); - } - - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - //GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs"); - - if (!inp_out_ids) { - LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__); - } else { - const int64_t n_tokens = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer)); - int32_t * data = (int32_t *) inp_out_ids->data; - - if (n_outputs == n_tokens) { - for (int i = 0; i < n_tokens; ++i) { - data[i] = i; - } - } else if (ubatch.output) { - int32_t n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - if (ubatch.output[i]) { - data[n_outputs++] = i; - } - } - // the graph needs to have been passed the correct number of outputs - GGML_ASSERT(n_outputs == n_outputs); - } else if (n_outputs == 1) { - // only keep last output - data[0] = n_tokens - 1; - } else { - GGML_ASSERT(n_outputs == 0); - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_mean); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer)); - - float * data = (float *) inp_mean->data; - memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean)); - - std::vector sum(n_tokens, 0); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - - sum[seq_id] += ubatch.n_seq_tokens; - } - - std::vector div(n_tokens, 0.0f); - for (int i = 0; i < n_tokens; ++i) { - const uint64_t s = sum[i]; - if (s > 0) { - div[i] = 1.0f/float(s); - } - } - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - for (int i = 0; i < n_seq_tokens; ++i) { - data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; - } - } - } - - if (cparams.embeddings && ( - cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || - cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); - - uint32_t * data = (uint32_t *) inp_cls->data; - memset(inp_cls->data, 0, n_tokens * 
ggml_element_size(inp_cls)); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos == 0) { - data[seq_id] = s*n_seq_tokens + i; - } - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); - - uint32_t * data = (uint32_t *) inp_cls->data; - memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); - - std::vector last_pos(n_tokens, -1); - std::vector last_row(n_tokens, -1); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos >= last_pos[seq_id]) { - last_pos[seq_id] = pos; - last_row[seq_id] = s*n_seq_tokens + i; - } - } - } - - for (int i = 0; i < n_tokens; ++i) { - if (last_row[i] >= 0) { - data[i] = last_row[i]; - } - } - } - - if (inp_kq_mask) { - if (cparams.causal_attn) { - const int64_t n_kv = ubatch.n_tokens; - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); - float * data = (float *) inp_kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f; - } - } - } - } - } - } else { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - const int64_t n_stride = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); - - float * data = (float *) inp_kq_mask->data; - - for (int h = 0; h < 1; ++h) { - for (int s1 = 0; s1 < n_seqs; ++s1) { - const llama_seq_id seq_id = ubatch.seq_id[s1][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const int32_t tj = s1*n_seq_tokens + j; - - for (int s0 = 0; s0 < n_seqs; ++s0) { - for (int i = 0; i < n_seq_tokens; ++i) { - const int32_t ti = s0*n_seq_tokens + i; - float f = -INFINITY; - - for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { - if (ubatch.seq_id[s0][s] == seq_id) { - if (hparams.use_alibi) { - f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); - } else { - f = 0.0f; - } - break; - } - } - - data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; - } - } - - for (int i = n_tokens; i < 
n_stride; ++i) { - data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; - } - } - } - } - } - } - - GGML_ASSERT( - // (!a || b) is a logical implication (a -> b) - // !hparams.causal_attn -> !cparams.causal_attn - (hparams.causal_attn || !cparams.causal_attn) && - "causal attention is not supported by this model" - ); -} - // // llama_context_kv_self // @@ -2235,7 +2245,7 @@ llama_context_kv_self::llama_context_kv_self( LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, get_ctx_padding(cparams)); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self.get_padding(cparams)); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -2271,6 +2281,13 @@ llama_context_kv_self::llama_context_kv_self( llama_context_kv_self::~llama_context_kv_self() = default; +void llama_context_kv_self::reserve() { + // simulate full KV cache + kv_self.n = kv_self.size; + + llama_context::reserve(); +} + llama_kv_cache * llama_context_kv_self::get_kv_self() { return &kv_self; } @@ -2282,6 +2299,8 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const { void llama_context_kv_self::kv_self_update() { auto & kv = kv_self; + bool need_reserve = false; + if (kv.has_shift) { if (!kv.can_shift) { GGML_ABORT("The current context does not support K-shift"); @@ -2332,20 +2351,30 @@ void llama_context_kv_self::kv_self_update() { need_reserve = true; } -} -ggml_cgraph * llama_context_kv_self::graph_init() { - inp_embd_enc = nullptr; - inp_pos_bucket = nullptr; - inp_kq_mask_cross = nullptr; + // reserve a worst case graph if needed + if (need_reserve) { + LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - inp_self_kq_mask = nullptr; - inp_self_kq_mask_cnv = nullptr; - inp_self_kq_mask_swa = nullptr; - inp_self_kq_mask_swa_cnv = nullptr; - inp_self_k_shift = nullptr; + // build worst-case graph + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - return llama_context::graph_init(); + // simulate full KV cache + kv_self.n = kv_self.size; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + graph_build(ctx_compute.get(), gf, ubatch); + + // initialize scheduler with the worst-case graph + ggml_backend_sched_reset(sched.get()); + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + } + } } int llama_context_kv_self::encode(llama_batch & inp_batch) { @@ -2406,14 +2435,11 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { //batch_manager->prepare(ubatch); - // TODO: do reserve - GGML_ASSERT(need_reserve == false); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + auto res = graph_build(ctx_compute.get(), gf, ubatch); ggml_backend_sched_alloc_graph(sched.get(), gf); @@ -2658,42 +2684,18 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the 
benefit from this will be more important - const uint32_t pad = get_ctx_padding(cparams); + const uint32_t pad = kv_self.get_padding(cparams); kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); } } //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, true); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); - } - - need_reserve = false; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + auto res = graph_build(ctx_compute.get(), gf, ubatch); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -2841,7 +2843,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f; + const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + kv_self.get_padding(cparams))/float(kv_self.n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > cparams.defrag_thold) { @@ -2858,12 +2860,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { return 0; } -uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const { - return kv_self.get_padding(cparams); -} - -// llama input - void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; @@ -3095,6 +3091,20 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { } } +ggml_cgraph * llama_context_kv_self::graph_init() { + inp_embd_enc = nullptr; + inp_pos_bucket = nullptr; + inp_kq_mask_cross = nullptr; + + inp_self_kq_mask = nullptr; + inp_self_kq_mask_cnv = nullptr; + inp_self_kq_mask_swa = nullptr; + inp_self_kq_mask_swa_cnv = nullptr; + inp_self_k_shift = nullptr; + + return llama_context::graph_init(); +} + ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) { inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx()); ggml_set_input(inp_self_k_shift); @@ -3106,9 +3116,8 @@ void llama_context_kv_self::build_attn_inp( ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa, - bool worst_case) { - const auto n_kv = worst_case ? 
kv_self.size : kv_self.n; + bool swa) { + const auto n_kv = kv_self.n; inp_self_kq_mask = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) @@ -3143,8 +3152,7 @@ ggml_tensor * llama_context_kv_self::build_attn( ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; const auto & n_ctx = cparams.n_ctx; @@ -3156,7 +3164,7 @@ ggml_tensor * llama_context_kv_self::build_attn( { GGML_ASSERT(!kv_self.recurrent); - const auto kv_head = worst_case ? kv_self.size - n_tokens : kv_self.head; + const auto kv_head = kv_self.head; GGML_ASSERT(kv_self.size == n_ctx); @@ -3211,7 +3219,7 @@ ggml_tensor * llama_context_kv_self::build_attn( const auto & kq_mask = is_sliding ? inp_self_kq_mask_swa_cnv : inp_self_kq_mask_cnv; - const auto n_kv = worst_case ? kv_self.size : kv_self.n; + const auto n_kv = kv_self.n; const int64_t n_head = hparams.n_head(il); const int64_t n_head_kv = hparams.n_head_kv(il); @@ -3626,14 +3634,12 @@ void llama_context_kv_self::build_kv_self_defrag( } ggml_tensor * llama_context_kv_self::build_inp_embd_enc( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) { + ggml_context * ctx0) { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; // TODO: not sure if this is correct - const int32_t n_outputs_enc = worst_case ? n_tokens : embd_enc.size() / n_embd; + const int32_t n_outputs_enc = embd_enc.size() / n_embd; inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); ggml_set_input(inp_embd_enc); @@ -3643,13 +3649,12 @@ ggml_tensor * llama_context_kv_self::build_inp_embd_enc( ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross( ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) { + int32_t n_tokens) { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; // TODO: not sure if this is correct - const int32_t n_outputs_enc = worst_case ? 
n_tokens : embd_enc.size() / n_embd; + const int32_t n_outputs_enc = embd_enc.size() / n_embd; inp_kq_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); ggml_set_input(inp_kq_mask_cross); @@ -3738,6 +3743,11 @@ llama_context_recurrent::llama_context_recurrent( llama_context_recurrent::~llama_context_recurrent() = default; +void llama_context_recurrent::reserve() { + // TODO: implement recurrent-specific reserve logic + llama_context::reserve(); +} + llama_kv_cache * llama_context_recurrent::get_kv_self() { return &kv_self; } @@ -3750,13 +3760,6 @@ void llama_context_recurrent::kv_self_update() { // noop } -ggml_cgraph * llama_context_recurrent::graph_init() { - inp_s_copy = nullptr; - inp_s_mask = nullptr; - - return llama_context::graph_init(); -} - int llama_context_recurrent::encode(llama_batch & inp_batch) { GGML_UNUSED(inp_batch); @@ -3917,34 +3920,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, true); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); - } - - need_reserve = false; - } - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + auto res = graph_build(ctx_compute.get(), gf, ubatch); // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -4147,24 +4127,32 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { } } +ggml_cgraph * llama_context_recurrent::graph_init() { + inp_s_copy = nullptr; + inp_s_mask = nullptr; + + return llama_context::graph_init(); +} + ggml_tensor * llama_context_recurrent::build_inp_s_copy( - ggml_context * ctx0, - bool worst_case) { - const auto n_kv = worst_case ? kv_self.size : kv_self.n; + ggml_context * ctx0) { + const auto n_kv = kv_self.n; inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); //cb(inp_s_copy, "inp_s_copy", -1); ggml_set_input(inp_s_copy); + return inp_s_copy; } ggml_tensor * llama_context_recurrent::build_inp_s_mask( - ggml_context * ctx0, - bool worst_case) { - const auto n_kv = worst_case ? 
kv_self.size : kv_self.n; + ggml_context * ctx0) { + const auto n_kv = kv_self.n; + inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); //cb(inp_s_mask, "inp_s_mask", -1); ggml_set_input(inp_s_mask); + return inp_s_mask; } @@ -4174,12 +4162,10 @@ ggml_tensor * llama_context_recurrent::build_copy_mask_state( ggml_tensor * s, ggml_tensor * state_copy, ggml_tensor * state_mask, - int32_t n_tokens, int32_t n_state, - int32_t n_seqs, - bool worst_case) { - const auto n_kv = worst_case ? kv_self.size : kv_self.n; - const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head; + int32_t n_seqs) { + const auto n_kv = kv_self.n; + const auto kv_head = kv_self.head; struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self.size); @@ -4210,13 +4196,10 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; - const auto & n_tokens = ubatch.n_tokens; - - const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head; + const auto kv_head = kv_self.head; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -4240,11 +4223,11 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer( // (ab)using the KV cache to store the states struct ggml_tensor * conv = build_copy_mask_state( ctx0, gf, conv_states_all, state_copy, state_mask, - n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case); + hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); struct ggml_tensor * ssm = build_copy_mask_state( ctx0, gf, ssm_states_all, state_copy, state_mask, - n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case); + hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -4345,20 +4328,18 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; const auto token_shift_count = hparams.token_shift_count; - const auto & n_tokens = ubatch.n_tokens; const int64_t n_seqs = ubatch.n_seqs; struct ggml_tensor * token_shift_all = kv_self.k_l[il]; struct ggml_tensor * token_shift = build_copy_mask_state( ctx0, gf, token_shift_all, state_copy, state_mask, - n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case); + hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); @@ -4369,17 +4350,15 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; - const auto & n_tokens = ubatch.n_tokens; - const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = worst_case ? (kv_self.recurrent ? 
0 : kv_self.size - n_tokens) : kv_self.head; + const auto kv_head = kv_self.head; return ggml_cpy( ctx0, @@ -4396,8 +4375,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { const auto & hparams = model.hparams; const auto n_tokens = ubatch.n_tokens; @@ -4407,7 +4385,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head; + const auto kv_head = kv_self.head; const auto & layer = model.layers[il]; @@ -4516,7 +4494,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( struct ggml_tensor * wkv_state = build_copy_mask_state( ctx0, gf, kv_self.v_l[il], state_copy, state_mask, - n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case); + hparams.n_embd_v_s(), n_seqs); struct ggml_tensor * wkv_output; if (is_qrwkv) { diff --git a/src/llama-context.h b/src/llama-context.h index 9d8b70220..d4ab5d509 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -22,16 +22,25 @@ using llama_loras = std::unordered_map; // basic transformer without KV cache struct llama_context : public llama_graph_i { +public: llama_context( const llama_model & model, const llama_context_params & params); virtual ~llama_context(); - // init scheduler and compute buffers + // init scheduler and compute buffers, reserve worst-case graphs // call once after the context is constructed virtual void init(); + virtual void synchronize(); + +protected: + // called by init() to reserve the worst-case graphs + // override in child classes + virtual void reserve(); + +public: const llama_model & get_model() const; const llama_cparams & get_cparams() const; @@ -93,33 +102,6 @@ struct llama_context : public llama_graph_i { int32_t il_start, int32_t il_end); - //// - - virtual void synchronize(); - - // zero-out inputs and create ggml_context - virtual ggml_cgraph * graph_init(); - - // TODO: add encode/decode graphs - virtual llama_graph_result graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - bool worst_case); - - // returns the result of ggml_backend_sched_graph_compute_async execution - virtual enum ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); - - // Make sure enough space is available for outputs. - // Returns max number of outputs for which space was reserved. 
- virtual int32_t output_reserve(int32_t n_outputs); - - // make the outputs have the same order they had in the user-provided batch - // TODO: maybe remove this - virtual void output_reorder(); - // encode a batch of tokens by evaluating the encoder part of the transformer // // - lctx: llama context // @@ -145,6 +127,60 @@ struct llama_context : public llama_graph_i { // virtual int decode(llama_batch & inp_batch); +protected: + // + // input + // + + // when the compute graph is built, it creates the input tensors that it needs + // the contents of the input tensors are set by the input_set() function + + virtual void input_set(const llama_ubatch & ubatch); + + // base input tensors + ggml_tensor * inp_tokens; // I32 [n_batch] + ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + ggml_tensor * inp_pos; // I32 [n_batch] + ggml_tensor * inp_out_ids; // I32 [n_outputs] + ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + ggml_tensor * inp_cls; // I32 [n_batch] + + // KQ mask input tensors + ggml_tensor * inp_kq_mask; // F32 [n_tokens, n_batch] + ggml_tensor * inp_kq_mask_cnv; // [n_tokens, n_batch] + + // + // output + // + + // Make sure enough space is available for outputs. + // Returns max number of outputs for which space was reserved. + virtual int32_t output_reserve(int32_t n_outputs); + + // make the outputs have the same order they had in the user-provided batch + // TODO: maybe remove this + virtual void output_reorder(); + + // + // graph + // + + // zero-out inputs and create the ctx_compute for the compute graph + virtual ggml_cgraph * graph_init(); + + // TODO: add encode/decode graphs + virtual llama_graph_result graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch); + + // returns the result of ggml_backend_sched_graph_compute_async execution + virtual enum ggml_status graph_compute( + ggml_cgraph * gf, + bool batched); + + ggml_context_ptr ctx_compute; + // // graph build API (generic) // @@ -193,9 +229,7 @@ struct llama_context : public llama_graph_i { int32_t n_tokens); virtual ggml_tensor * build_inp_out_ids( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case); + ggml_context * ctx0); virtual ggml_tensor * build_inp_mean( ggml_context * ctx0, @@ -209,8 +243,7 @@ struct llama_context : public llama_graph_i { ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa, - bool worst_case); + bool swa); virtual ggml_tensor * build_attn( ggml_context * ctx0, @@ -222,15 +255,32 @@ struct llama_context : public llama_graph_i { ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case); + int il); +public: + // // perf + // virtual llama_perf_context_data perf_get_data() const; virtual void perf_reset(); +protected: + mutable int64_t t_start_us = 0; + mutable int64_t t_load_us = 0; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; + + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; + + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls + +public: + // // state save/load + // virtual size_t state_get_size(); virtual size_t state_get_data( uint8_t * dst, size_t size); @@ -265,31 +315,15 @@ struct llama_context : public llama_graph_i { size_t n_token_count); protected: - // state save/load - virtual size_t state_get_data(llama_io_write_i & io); virtual size_t state_set_data(llama_io_read_i & io); virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id 
seq_id); virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id); - // input - - virtual void input_set(const llama_ubatch & ubatch); - - // base input tensors - ggml_tensor * inp_tokens; // I32 [n_batch] - ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - ggml_tensor * inp_pos; // I32 [n_batch] - ggml_tensor * inp_out_ids; // I32 [n_outputs] - ggml_tensor * inp_mean; // F32 [n_batch, n_batch] - ggml_tensor * inp_cls; // I32 [n_batch] - - // KQ mask input tensors - ggml_tensor * inp_kq_mask; // F32 [n_tokens, n_batch] - ggml_tensor * inp_kq_mask_cnv; // [n_tokens, n_batch] - + // // members + // const llama_model & model; @@ -311,7 +345,9 @@ protected: ggml_backend_sched_ptr sched; - ggml_context_ptr ctx_compute; + // buffer types used for the compute buffer of each backend + std::vector backend_ptrs; + std::vector backend_buft; // memory buffers used to evaluate the model std::vector buf_compute_meta; @@ -340,19 +376,7 @@ protected: std::vector output_ids; // map batch token positions to ids of the logits and embd buffers - bool need_reserve = false; bool has_evaluated_once = false; - - mutable int64_t t_start_us = 0; - mutable int64_t t_load_us = 0; - mutable int64_t t_p_eval_us = 0; - mutable int64_t t_eval_us = 0; - - mutable int64_t t_compute_start_us = 0; - mutable int64_t n_queued_tokens = 0; - - mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - mutable int32_t n_eval = 0; // number of eval calls }; // transformer with a self-attention KV cache @@ -364,18 +388,40 @@ public: virtual ~llama_context_kv_self(); +protected: + virtual void reserve() override; + +public: virtual llama_kv_cache * get_kv_self() override; virtual const llama_kv_cache * get_kv_self() const override; virtual void kv_self_update() override; - virtual ggml_cgraph * graph_init() override; - virtual int encode(llama_batch & inp_batch) override; virtual int decode(llama_batch & inp_batch) override; - // certain implementations could require a padding for the context size - uint32_t get_ctx_padding(const llama_cparams & cparams) const; +protected: + // + // input + // + + virtual void input_set(const llama_ubatch & ubatch) override; + + ggml_tensor * inp_self_kq_mask; // F32 [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_cnv; // [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_swa; // F32 [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_swa_cnv; // [kv_size, n_batch] + ggml_tensor * inp_self_k_shift; // I32 [kv_size] + + // + // graph + // + + virtual ggml_cgraph * graph_init() override; + + // + // graph build + // virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override; @@ -383,8 +429,7 @@ public: ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa, - bool worst_case) override; + bool swa) override; virtual ggml_tensor * build_attn( ggml_context * ctx0, @@ -396,8 +441,7 @@ public: ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case) override; + int il) override; virtual void build_kv_self_shift( ggml_context * ctx0, @@ -422,31 +466,27 @@ public: struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch] virtual ggml_tensor * build_inp_embd_enc( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) override; + ggml_context * ctx0) override; virtual ggml_tensor * build_inp_kq_mask_cross( ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) override; + int32_t n_tokens) override; + + // + // state save/load + // -protected: virtual size_t 
state_get_data(llama_io_write_i & io) override; virtual size_t state_set_data(llama_io_read_i & io) override; virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; - virtual void input_set(const llama_ubatch & ubatch) override; + // + // members + // llama_kv_cache kv_self; - - ggml_tensor * inp_self_kq_mask; // F32 [kv_size, n_batch] - ggml_tensor * inp_self_kq_mask_cnv; // [kv_size, n_batch] - ggml_tensor * inp_self_kq_mask_swa; // F32 [kv_size, n_batch] - ggml_tensor * inp_self_kq_mask_swa_cnv; // [kv_size, n_batch] - ggml_tensor * inp_self_k_shift; // I32 [kv_size] }; // a recurrent transformer (i.e. RWKV, Mamba) @@ -458,23 +498,43 @@ public: virtual ~llama_context_recurrent(); +protected: + virtual void reserve() override; + +public: virtual llama_kv_cache * get_kv_self() override; virtual const llama_kv_cache * get_kv_self() const override; virtual void kv_self_update() override; - virtual ggml_cgraph * graph_init() override; - virtual int encode(llama_batch & inp_batch) override; virtual int decode(llama_batch & inp_batch) override; +protected: + // + // input + // + + virtual void input_set(const llama_ubatch & ubatch) override; + + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + + // + // graph + // + + virtual ggml_cgraph * graph_init() override; + + // + // graph build + // + virtual ggml_tensor * build_inp_s_copy( - ggml_context * ctx0, - bool worst_case) override; + ggml_context * ctx0) override; virtual ggml_tensor * build_inp_s_mask( - ggml_context * ctx0, - bool worst_case) override; + ggml_context * ctx0) override; virtual ggml_tensor * build_copy_mask_state( ggml_context * ctx0, @@ -482,10 +542,8 @@ public: ggml_tensor * s, ggml_tensor * state_copy, ggml_tensor * state_mask, - int32_t n_tokens, int32_t n_state, - int32_t n_seqs, - bool worst_case) override; + int32_t n_seqs) override; virtual ggml_tensor * build_mamba_layer( ggml_context * ctx0, @@ -494,8 +552,7 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) override; + int il) override; virtual ggml_tensor * build_rwkv_token_shift_load( ggml_context * ctx0, @@ -503,15 +560,13 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) override; + int il) override; virtual ggml_tensor * build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, - int il, - bool worst_case) override; + int il) override; virtual ggml_tensor * build_rwkv6_time_mix( ggml_context * ctx0, @@ -521,23 +576,24 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) override; + int il) override; + + // + // state save/load + // -protected: virtual size_t state_get_data(llama_io_write_i & io) override; virtual size_t state_set_data(llama_io_read_i & io) override; virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; - virtual void input_set(const llama_ubatch & ubatch) override; + // + // members + // // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models? 
llama_kv_cache_recurrent kv_self; - - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] }; // For internal test use diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index d9d4e00e9..af556f5bb 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,8 +12,7 @@ ggml_tensor * llama_graph_i::build_attn( ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case) { + int il) { GGML_UNUSED(ctx0); GGML_UNUSED(gf); GGML_UNUSED(wo); @@ -24,7 +23,6 @@ ggml_tensor * llama_graph_i::build_attn( GGML_UNUSED(n_tokens); GGML_UNUSED(kq_scale); GGML_UNUSED(il); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); return nullptr; @@ -57,12 +55,8 @@ ggml_tensor * llama_graph_i::build_inp_self_k_shift( } ggml_tensor * llama_graph_i::build_inp_embd_enc( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) { + ggml_context * ctx0) { GGML_UNUSED(ctx0); - GGML_UNUSED(n_tokens); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); return nullptr; @@ -70,21 +64,17 @@ ggml_tensor * llama_graph_i::build_inp_embd_enc( ggml_tensor * llama_graph_i::build_inp_kq_mask_cross( ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) { + int32_t n_tokens) { GGML_UNUSED(ctx0); GGML_UNUSED(n_tokens); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); return nullptr; } ggml_tensor * llama_graph_i::build_inp_s_copy ( - ggml_context * ctx0, - bool worst_case) { + ggml_context * ctx0) { GGML_UNUSED(ctx0); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -92,10 +82,8 @@ ggml_tensor * llama_graph_i::build_inp_s_copy ( } ggml_tensor * llama_graph_i::build_inp_s_mask( - ggml_context * ctx0, - bool worst_case) { + ggml_context * ctx0) { GGML_UNUSED(ctx0); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -108,19 +96,15 @@ ggml_tensor * llama_graph_i::build_copy_mask_state( ggml_tensor * s, ggml_tensor * state_copy, ggml_tensor * state_mask, - int32_t n_tokens, int32_t n_state, - int32_t n_seqs, - bool worst_case) { + int32_t n_seqs) { GGML_UNUSED(ctx0); GGML_UNUSED(gf); GGML_UNUSED(s); GGML_UNUSED(state_copy); GGML_UNUSED(state_mask); - GGML_UNUSED(n_tokens); GGML_UNUSED(n_state); GGML_UNUSED(n_seqs); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -134,8 +118,7 @@ ggml_tensor * llama_graph_i::build_mamba_layer( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { GGML_UNUSED(ctx0); GGML_UNUSED(gf); GGML_UNUSED(cur); @@ -143,7 +126,6 @@ ggml_tensor * llama_graph_i::build_mamba_layer( GGML_UNUSED(state_mask); GGML_UNUSED(ubatch); GGML_UNUSED(il); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -156,15 +138,13 @@ ggml_tensor * llama_graph_i::build_rwkv_token_shift_load( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { GGML_UNUSED(ctx0); GGML_UNUSED(gf); GGML_UNUSED(state_copy); GGML_UNUSED(state_mask); GGML_UNUSED(ubatch); GGML_UNUSED(il); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -175,13 +155,11 @@ ggml_tensor * llama_graph_i::build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { GGML_UNUSED(ctx0); GGML_UNUSED(token_shift); GGML_UNUSED(ubatch); 
GGML_UNUSED(il); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); @@ -196,8 +174,7 @@ ggml_tensor * llama_graph_i::build_rwkv6_time_mix( ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case) { + int il) { GGML_UNUSED(ctx0); GGML_UNUSED(gf); GGML_UNUSED(cur); @@ -206,7 +183,6 @@ ggml_tensor * llama_graph_i::build_rwkv6_time_mix( GGML_UNUSED(state_mask); GGML_UNUSED(ubatch); GGML_UNUSED(il); - GGML_UNUSED(worst_case); LLAMA_LOG_ERROR("%s: not implemented\n", __func__); diff --git a/src/llama-graph.h b/src/llama-graph.h index 8d237431e..05349e587 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -69,9 +69,7 @@ public: int32_t n_tokens) = 0; virtual ggml_tensor * build_inp_out_ids( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case) = 0; + ggml_context * ctx0) = 0; virtual ggml_tensor * build_inp_mean( ggml_context * ctx0, @@ -85,8 +83,7 @@ public: ggml_context * ctx0, int32_t n_tokens, bool causal, - bool swa, - bool worst_case) = 0; + bool swa) = 0; virtual ggml_tensor * build_attn( ggml_context * ctx0, @@ -98,8 +95,7 @@ public: ggml_tensor * v_cur, int32_t n_tokens, float kq_scale, - int il, - bool worst_case); + int il); virtual void build_kv_self_shift( ggml_context * ctx0, @@ -114,22 +110,17 @@ public: ggml_context * ctx0); virtual ggml_tensor * build_inp_embd_enc( - ggml_context * ctx0, - int32_t n_tokens, - bool worst_case); + ggml_context * ctx0); virtual ggml_tensor * build_inp_kq_mask_cross( ggml_context * ctx0, - int32_t n_tokens, - bool worst_case); + int32_t n_tokens); virtual ggml_tensor * build_inp_s_copy( - ggml_context * ctx0, - bool worst_case); + ggml_context * ctx0); virtual ggml_tensor * build_inp_s_mask( - ggml_context * ctx0, - bool worst_case); + ggml_context * ctx0); virtual ggml_tensor * build_copy_mask_state( ggml_context * ctx0, @@ -137,10 +128,8 @@ public: ggml_tensor * s, ggml_tensor * state_copy, ggml_tensor * state_mask, - int32_t n_tokens, int32_t n_state, - int32_t n_seqs, - bool worst_case); + int32_t n_seqs); virtual ggml_tensor * build_mamba_layer( ggml_context * ctx0, @@ -149,8 +138,7 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case); + int il); virtual ggml_tensor * build_rwkv_token_shift_load( ggml_context * ctx0, @@ -158,15 +146,13 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case); + int il); virtual ggml_tensor * build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, - int il, - bool worst_case); + int il); virtual ggml_tensor * build_rwkv6_time_mix( ggml_context * ctx0, @@ -176,6 +162,5 @@ public: ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, - int il, - bool worst_case); + int il); }; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3aec6495f..e1b07c993 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -610,6 +610,7 @@ struct llama_kv_cache_slot_info llama_kv_cache::find_slot( // sanity check return llama_kv_cache_slot_info(n >= n_seqs); } + // otherwise, one cell per token. 
if (n_tokens > size) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a0a7816da..8eb99995e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3834,7 +3834,6 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_ctx_orig; - const bool worst_case; const bool flash_attn; const enum llama_pooling_type pooling_type; @@ -3851,8 +3850,7 @@ struct llm_build_context { llama_graph_i * lgf, const llama_model & model, const llama_cparams & cparams, - const llama_ubatch & ubatch, - bool worst_case) : + const llama_ubatch & ubatch) : model (model), hparams (model.hparams), cparams (cparams), @@ -3879,7 +3877,6 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (ubatch.n_tokens), n_ctx_orig (cparams.n_ctx_orig_yarn), - worst_case (worst_case), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -3910,7 +3907,7 @@ struct llm_build_context { // TODO: tmp struct ggml_tensor * build_inp_out_ids() { - ggml_tensor * cur = lgf->build_inp_out_ids(ctx0, n_tokens, worst_case); + ggml_tensor * cur = lgf->build_inp_out_ids(ctx0); cb(cur, "inp_out_ids", -1); return cur; @@ -3949,7 +3946,7 @@ struct llm_build_context { // TODO: tmp struct ggml_tensor * build_inp_embd_enc() { - ggml_tensor * cur = lgf->build_inp_embd_enc(ctx0, n_tokens, worst_case); + ggml_tensor * cur = lgf->build_inp_embd_enc(ctx0); cb(cur, "embd_enc", -1); return cur; @@ -3957,7 +3954,7 @@ struct llm_build_context { // TODO: tmp struct ggml_tensor * build_inp_kq_mask_cross() { - ggml_tensor * cur = lgf->build_inp_kq_mask_cross(ctx0, n_tokens, worst_case); + ggml_tensor * cur = lgf->build_inp_kq_mask_cross(ctx0, n_tokens); cb(cur, "KQ_mask_cross", -1); return cur; @@ -4258,7 +4255,7 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, n_tokens, kq_scale, il, worst_case); + ggml_tensor * cur = lgf->build_attn(ctx0, gf, wo, wo_b, q_cur, k_cur, v_cur, n_tokens, kq_scale, il); cb(cur, "kqv_out", il); return cur; @@ -4405,7 +4402,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; for (int il = 0; il < n_layer; ++il) { @@ -4566,7 +4563,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; for (int il = 0; il < n_layer; ++il) { @@ -4722,7 +4719,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? 
build_inp_pos() : nullptr; - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -4838,7 +4835,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -4943,7 +4940,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -5066,7 +5063,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5218,7 +5215,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5340,7 +5337,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -5441,7 +5438,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5555,7 +5552,7 @@ struct llm_build_context { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); cb(inpL, "inp_norm", -1); - lgf->build_attn_inp(ctx0, n_tokens, false, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, false, false); // iterate layers for (int il = 0; il < n_layer; ++il) { @@ -5700,7 +5697,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); inpL = build_norm(inpL, model.tok_norm, @@ -5803,7 +5800,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); if (model.pos_embd) { // inp_pos - contains the positions @@ -5945,7 +5942,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { @@ -6096,7 +6093,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor 
* inpSA = inpL; @@ -6210,7 +6207,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -6323,7 +6320,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -6441,7 +6438,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -6588,7 +6585,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { attn_norm_output = build_norm(inpL, @@ -6711,7 +6708,7 @@ struct llm_build_context { struct ggml_tensor * inp_pos = build_inp_pos(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - lgf->build_attn_inp(ctx0, n_tokens, true, true, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, true); for (int il = 0; il < n_layer; ++il) { auto * residual = inpL; @@ -6855,7 +6852,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { @@ -6961,7 +6958,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -7067,7 +7064,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { cur = build_norm(inpL, @@ -7178,7 +7175,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7297,7 +7294,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7425,7 +7422,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct 
ggml_tensor * inpSA = inpL; @@ -7626,7 +7623,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { // norm @@ -7734,7 +7731,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, true, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, true); for (int il = 0; il < n_layer; ++il) { // norm @@ -7864,7 +7861,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7977,8 +7974,8 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0, worst_case); - struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0, worst_case); + struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0); + struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0); for (int il = 0; il < n_layer; ++il) { // norm @@ -7988,7 +7985,7 @@ struct llm_build_context { cb(cur, "attn_norm", il); //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il); - cur = lgf->build_mamba_layer(ctx0, gf, cur, state_copy, state_mask, ubatch, il, worst_case); + cur = lgf->build_mamba_layer(ctx0, gf, cur, state_copy, state_mask, ubatch, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -8039,7 +8036,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { @@ -8187,7 +8184,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, true, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, true); // sliding window switch pattern const int32_t sliding_window_pattern = 4; @@ -8322,7 +8319,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -8442,7 +8439,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -8566,7 +8563,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -8687,7 +8684,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + 
lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { const int64_t n_head = hparams.n_head(il); @@ -8815,7 +8812,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { cur = build_norm(inpL, @@ -8959,7 +8956,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -9089,7 +9086,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -9252,7 +9249,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -9470,7 +9467,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -9951,7 +9948,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { cur = build_norm(inpL, @@ -10045,7 +10042,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -10175,7 +10172,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -10296,7 +10293,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -10414,8 +10411,8 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0, worst_case); - struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0, worst_case); + struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0); + struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ 
-10425,7 +10422,7 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; struct ggml_tensor * token_shift = lgf->build_rwkv_token_shift_load( - ctx0, gf, state_copy, state_mask, ubatch, il, worst_case + ctx0, gf, state_copy, state_mask, ubatch, il ); struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -10441,7 +10438,7 @@ struct llm_build_context { 1 ); - cur = lgf->build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case); + cur = lgf->build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -10464,7 +10461,7 @@ struct llm_build_context { ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)), 1 ); - ggml_build_forward_expand(gf, lgf->build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case)); + ggml_build_forward_expand(gf, lgf->build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il)); if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) { cur = ggml_scale(ctx0, cur, 0.5F); @@ -10506,8 +10503,8 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); - struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0, worst_case); - struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0, worst_case); + struct ggml_tensor * state_copy = lgf->build_inp_s_copy(ctx0); + struct ggml_tensor * state_mask = lgf->build_inp_s_mask(ctx0); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -10519,7 +10516,7 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; struct ggml_tensor * token_shift = lgf->build_rwkv_token_shift_load( - ctx0, gf, state_copy, state_mask, ubatch, il, worst_case + ctx0, gf, state_copy, state_mask, ubatch, il ); struct ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -10532,10 +10529,10 @@ struct llm_build_context { 1 ); - cur = lgf->build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case); + cur = lgf->build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); - ggml_build_forward_expand(gf, lgf->build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case)); + ggml_build_forward_expand(gf, lgf->build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il)); struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -10601,7 +10598,7 @@ struct llm_build_context { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); - lgf->build_attn_inp(ctx0, n_tokens, true, false, worst_case); + lgf->build_attn_inp(ctx0, n_tokens, true, false); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -10912,9 +10909,8 @@ llama_graph_result llama_model::build_graph( ggml_cgraph * gf, llama_graph_i * lgf, const llama_cparams & cparams, - const llama_ubatch & ubatch, - bool worst_case) const { - struct llm_build_context llm(ctx, lgf, *this, cparams, ubatch, worst_case); + const llama_ubatch & ubatch) const { + struct llm_build_context llm(ctx, lgf, *this, cparams, ubatch); switch (arch) 
{ case LLM_ARCH_LLAMA: diff --git a/src/llama-model.h b/src/llama-model.h index 94e762294..b2d75e593 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -374,8 +374,7 @@ struct llama_model { ggml_cgraph * gf, llama_graph_i * lgf, const llama_cparams & cparams, - const llama_ubatch & ubatch, - bool worst_case) const; + const llama_ubatch & ubatch) const; private: struct impl;
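For readers following the API change: the worst-case reservation that this patch deletes from the per-decode path is now meant to happen once, behind the new protected reserve() hook that init() calls. Below is a minimal sketch of what such a hook can look like; it simply reuses the ubatch construction from the removed decode-time code. The names it relies on (cparams, model, sched, ctx_compute, graph_init(), graph_build(), LLAMA_LOG_ERROR) are taken from this patch, but the exact body of reserve() in the tree may differ.

    // sketch: one-time worst-case reservation, called from llama_context::init()
    void llama_context::reserve() {
        uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // the token value is irrelevant here - it only selects the token-input path of the graph
        llama_token token = model.vocab.token_bos();

        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };

        // build the worst-case graph without the removed worst_case flag
        auto * gf = graph_init();
        graph_build(ctx_compute.get(), gf, ubatch);

        // let the scheduler pre-allocate compute buffers for the worst case
        ggml_backend_sched_reset(sched.get());
        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }
    }

With this in place, subclasses such as llama_context_kv_self and llama_context_recurrent only need to override reserve() to prepare their own worst-case state before delegating to the base implementation (the recurrent variant currently just forwards, per the TODO in this patch), instead of threading a worst_case flag through every build_* call.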