llama : add high-throughput mode (#14363)
* kv-cache : prepare K/V buffers for separation
* batched-bench : fix oob write
* llama : add "virtual sequences"
* llama : use "stream" vs "virtual sequence"
* graph : fix stream splitting when KV cache is not used
* kv-cache : add multi-stream save/load support
* llama : add "--attn-streams" flag
* kv-cache : fix handling when find_slot fails
* kv-cache : restore find_slot impl
* kv-cache : add comments
* kv-cache : add bounds checks for sequence id
* cont : add n_seq_max to batch allocr
* kv-cache : perform stream copies lazily after llama_synchronize
* kv-cache : avoid throwing exceptions across the C boundary
* CUDA: 4D FlashAttention support (#14628)
  * CUDA: 4D FlashAttention support
  * CUDA: fix WMMA FA kernel
* llama : rename attn_streams -> kv_unified
* common : rename kv_split -> kv_unified

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
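The user-facing knob added by this PR is the kv_unified context parameter (alongside the existing n_seq_max): leaving kv_unified set to false asks for one KV stream per sequence, which is the high-throughput mode. Below is a hedged usage sketch of how a client might request it through the public C API; this is my own construction rather than code from the commit, the model path and numbers are placeholders, and exact defaults may differ between versions.

    #include "llama.h"

    int main(void) {
        llama_backend_init();

        struct llama_model_params mparams = llama_model_default_params();
        struct llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        struct llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 8192;   // total KV cells; the constructor divides this by n_seq_max per sequence
        cparams.n_seq_max  = 8;      // number of parallel sequences (KV streams)
        cparams.kv_unified = false;  // per-stream K/V buffers, i.e. the high-throughput mode
                                     // (forced back to true unless ggml_set_rows() support is
                                     //  enabled via LLAMA_SET_ROWS, see the first hunk below)

        struct llama_context * ctx = llama_init_from_model(model, cparams);

        // ... build batches that spread tokens across seq_id 0..7 and call llama_decode() ...

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }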
@@ -98,10 +98,20 @@ llama_context::llama_context(
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
     cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
+
+    {
+        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+        const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+        if (!supports_set_rows && !cparams.kv_unified) {
+            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
+            cparams.kv_unified = true;
+        }
+    }
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
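The block added above gates the non-unified cache on ggml_set_rows() support, read from the LLAMA_SET_ROWS environment variable and defaulting to off. As a small hedged illustration (POSIX setenv, not part of this commit), an embedding application could opt in programmatically before creating the context:

    #include <stdlib.h>   // setenv (POSIX)

    // Hypothetical helper: opt in to the ggml_set_rows() path before the llama_context
    // is created, so a non-unified (per-stream) KV cache is not forced back to unified
    // by the getenv("LLAMA_SET_ROWS") check in the hunk above.
    static void enable_set_rows(void) {
        setenv("LLAMA_SET_ROWS", "1", /*overwrite=*/ 1);
    }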
@@ -112,6 +122,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
@@ -267,7 +278,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
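For orientation, the n_seqs change above reflects how the two modes count attention streams during the worst-case reservation: with a unified cache all sequences share a single KV stream, while a non-unified cache keeps one stream per sequence (with n_ctx_per_seq cells each, per the constructor hunk above). A rough mental model, stated as my own reading of the change rather than code from this diff:

    #include <stdbool.h>
    #include <stdint.h>

    // Assumption / mental model only, not code from this commit: number of KV streams
    // and cells per stream in the two cache modes.
    static uint32_t kv_n_streams(bool kv_unified, uint32_t n_seq_max) {
        return kv_unified ? 1u : n_seq_max;
    }

    static uint32_t kv_cells_per_stream(bool kv_unified, uint32_t n_ctx, uint32_t n_seq_max) {
        return kv_unified ? n_ctx : n_ctx / n_seq_max;
    }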
@@ -300,7 +311,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mctx.get());
+            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,6 +322,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
+            // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens = balloc->get_n_tokens();
 
+    // [TAG_NO_CACHE_PAD]
+    // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
     const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
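The extra n_seq_max argument passed to balloc->init() above bounds the sequence ids a batch may use. For context, here is a hedged sketch (my construction, not from this commit) of the text-generation pattern this mode targets: decoding one token for each of n_par parallel sequences in a single call through the public C API.

    #include "llama.h"

    // Decode the next token of each of n_par sequences in one llama_decode() call.
    // Assumes the context was created with n_seq_max >= n_par (and kv_unified == false
    // for per-stream KV buffers); tokens[i]/pos[i] hold sequence i's next token and position.
    static int decode_parallel(struct llama_context * ctx,
                               const llama_token * tokens,
                               const llama_pos   * pos,
                               int32_t             n_par) {
        struct llama_batch batch = llama_batch_init(n_par, /*embd =*/ 0, /*n_seq_max =*/ 1);

        for (int32_t i = 0; i < n_par; ++i) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = pos[i];
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = i;   // sequence i -> its own KV stream
            batch.logits  [i]    = 1;   // request logits for every sequence
        }
        batch.n_tokens = n_par;

        const int ret = llama_decode(ctx, batch);

        llama_batch_free(batch);

        return ret;
    }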
@@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
            return;
        }
@@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf =*/ true,
         /*.op_offload =*/ true,
         /*.swa_full =*/ true,
+        /*.kv_unified =*/ false,
     };
 
     return result;