Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-12 22:23:13 +00:00)
llama : use "stream" vs "virtual sequence"
ggml-ci
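The change set below renames the KV cache's "virtual sequence" concept to "stream": llama_cparams drops n_seq_virt in favor of a kv_unified flag, the unified cache now carries n_stream (1 when unified, n_seq_max otherwise) plus a seq_to_stream map, and the related members are renamed accordingly (seq_id_virt -> strm_id, k_seq/v_seq -> k_stream/v_stream). The following standalone sketch is illustrative only (not llama.cpp code; the 64-entry bound stands in for LLAMA_MAX_SEQ) and shows the sequence-to-stream mapping that the diff sets up.

// Illustrative sketch only: how sequence ids map to KV streams after this change.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_streams {
    uint32_t n_stream;                   // 1 when the cache is unified, n_seq_max otherwise
    std::vector<uint32_t> seq_to_stream; // sequence id -> stream index

    kv_streams(bool unified, uint32_t n_seq_max, uint32_t max_seq = 64)
        : n_stream(unified ? 1 : n_seq_max), seq_to_stream(max_seq, 0) {
        if (n_stream > 1) {
            // non-unified: each sequence below n_seq_max gets its own stream
            for (uint32_t s = 0; s < n_stream; ++s) {
                seq_to_stream[s] = s;
            }
        }
    }
};

int main() {
    kv_streams unified(true,  4);
    kv_streams split  (false, 4);
    std::printf("unified: seq 3 -> stream %u\n", unified.seq_to_stream[3]); // 0
    std::printf("split  : seq 3 -> stream %u\n", split.seq_to_stream[3]);   // 3
    return 0;
}

With a unified cache every sequence shares stream 0; with streams enabled, each sequence id gets its own stream, which is the mapping the renamed code below relies on.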
@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {

 // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
 // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
+llama_batch batch = llama_batch_init(n_ctx, 0, 1);

 int32_t n_total_prompt = 0;
 int32_t n_total_gen = 0;
@@ -290,7 +290,6 @@ int main(int argc, char ** argv) {
 // all sequences have ended - clear the entire KV cache
 for (int i = 1; i <= n_clients; ++i) {
 llama_memory_seq_rm(mem, i, -1, -1);
-
 // but keep the system prompt
 llama_memory_seq_cp(mem, 0, i, -1, -1);
 }
@@ -34,7 +34,7 @@ llama_context::llama_context(
 }

 const char * LLAMA_HT = getenv("LLAMA_HT");
-cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;

 cparams.n_threads = params.n_threads;
 cparams.n_threads_batch = params.n_threads_batch;
@@ -270,7 +270,7 @@ llama_context::llama_context(

 // reserve worst-case graph
 if (!hparams.vocab_only && memory) {
-const uint32_t n_seqs = 1; // reserve worst-case graph for single-sequence batches
+const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
 const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

 LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -303,7 +303,7 @@ llama_context::llama_context(

 // reserve with tg graph to get the number of splits and nodes
 {
-auto * gf = graph_reserve(1, 1, 1, mctx.get());
+auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
 if (!gf) {
 throw std::runtime_error("failed to allocate compute tg buffers");
 }
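Note how the worst-case graph reservation follows the new flag: LLAMA_HT set to a positive integer turns the unified cache off, and the context then reserves graphs for n_seq_max sequences instead of one. A minimal standalone sketch of that decision (values assumed; it only mirrors the two lines changed above):

// Illustrative sketch of the LLAMA_HT -> kv_unified decision and the reservation size it implies.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

int main() {
    const uint32_t n_seq_max = 8; // hypothetical value

    // same condition as in the diff: LLAMA_HT set to a positive integer disables the unified cache
    const char * LLAMA_HT = std::getenv("LLAMA_HT");
    const bool kv_unified = (LLAMA_HT && std::atoi(LLAMA_HT) > 0) ? false : true;

    // worst-case graph reservation: one sequence per stream when the cache is not unified
    const uint32_t n_seqs = kv_unified ? 1 : n_seq_max;

    std::printf("kv_unified = %d, reserve graphs for n_seqs = %u\n", kv_unified ? 1 : 0, n_seqs);
    return 0;
}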
@@ -11,7 +11,6 @@ struct llama_cparams {
 uint32_t n_batch;
 uint32_t n_ubatch;
 uint32_t n_seq_max;
-uint32_t n_seq_virt;
 int32_t n_threads; // number of threads to use for generation
 int32_t n_threads_batch; // number of threads to use for batch processing

@@ -34,6 +33,7 @@ struct llama_cparams {
 bool no_perf;
 bool warmup;
 bool op_offload;
+bool kv_unified;

 enum llama_pooling_type pooling_type;

@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 {
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

 const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

 inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1033,9 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 float kq_scale) const {
 const bool v_trans = v->nb[1] > v->nb[2];

-const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+// split the batch into streams if needed
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

-q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

 q = ggml_permute(ctx0, q, 0, 2, 1, 3);
 k = ggml_permute(ctx0, k, 0, 2, 1, 3);
@@ -1085,7 +1086,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
 }

-cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+// recombine streams
+cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
 } else {
 ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1130,7 +1132,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

 cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+// recombine streams
+cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);

 if (!cparams.offload_kqv) {
 // all nodes between the KV store and the attention output are run on the CPU
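In build_attn_mha() the batch is now split into streams before the attention and recombined afterwards, via the reshape calls above. A standalone sketch of the shape arithmetic (assumed sizes; the equal-split requirement comes from split_equal):

// Illustrative shape arithmetic for the stream split in build_attn_mha() (not llama.cpp code).
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical ubatch: 4 unique sequences, 32 tokens total, non-unified cache
    const int64_t n_tokens    = 32;
    const int64_t n_stream    = 4;   // == ubatch.n_seqs_unq when !kv_unified, else 1
    const int64_t n_embd_head = 128;
    const int64_t n_head      = 16;

    // equal splits across streams are required, hence the assert
    assert(n_tokens % n_stream == 0);

    // q: [n_embd_head, n_head, n_tokens] -> [n_embd_head, n_head, n_tokens/n_stream, n_stream]
    const int64_t q_split[4] = { n_embd_head, n_head, n_tokens/n_stream, n_stream };

    // after attention the streams are recombined into a 2-D tensor:
    // [n_embd_head*n_head, (n_tokens/n_stream)*n_stream] == [n_embd_head*n_head, n_tokens]
    const int64_t recombined[2] = { n_embd_head*n_head, (n_tokens/n_stream)*n_stream };

    std::printf("split q    : [%lld, %lld, %lld, %lld]\n",
        (long long) q_split[0], (long long) q_split[1], (long long) q_split[2], (long long) q_split[3]);
    std::printf("recombined : [%lld, %lld]\n",
        (long long) recombined[0], (long long) recombined[1]);
    return 0;
}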
@@ -1207,13 +1210,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 {
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

 const auto n_kv = mctx_cur->get_n_kv();
-const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

 inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1455,7 +1458,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

 auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

-const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

 {
 const auto n_kv = mctx_cur->get_base()->get_n_kv();
@@ -1463,7 +1466,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1477,7 +1480,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
 ggml_set_input(inp->self_kq_mask_swa);

 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -257,8 +257,8 @@ public:
 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
 ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seq, 1, n_seq]
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

 const llama_hparams & hparams;
 const llama_cparams & cparams;
@@ -293,10 +293,10 @@ public:
 ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
 ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seq, 1, n_seq]
-ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seq, 1, n_seq]
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

 const llama_hparams & hparams;
 const llama_cparams & cparams;
@@ -343,8 +343,8 @@ public:
 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
 ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seq, 1, n_seq]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seq, 1, n_seq]
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

 const llama_hparams & hparams;
 const llama_cparams & cparams;
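The mask tensors documented above now have shape [n_kv, n_batch/n_stream, 1, n_stream], and the builders pad the per-stream row count with GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD). A worked example of that padding (standalone sketch; the padding granularity of 32 is an assumed example value, not necessarily the real GGML_KQ_MASK_PAD):

// Worked example of the mask shape [n_kv, GGML_PAD(n_tokens/n_stream, PAD), 1, n_stream].
#include <cstdint>
#include <cstdio>

// round x up to the next multiple of n (equivalent to ggml's GGML_PAD for power-of-two n)
static int64_t pad_to(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t n_kv     = 1024; // current KV length
    const int64_t n_tokens = 48;   // tokens in the ubatch
    const int64_t n_stream = 4;    // streams in the ubatch (1 for a unified cache)
    const int64_t PAD      = 32;   // assumed padding granularity

    const int64_t ne[4] = { n_kv, pad_to(n_tokens/n_stream, PAD), 1, n_stream };

    // 48 tokens over 4 streams -> 12 rows per stream, padded up to 32
    std::printf("self_kq_mask ne = [%lld, %lld, %lld, %lld]\n",
        (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
    return 0;
}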
@@ -18,17 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 bool v_trans,
 bool offload,
 bool swa_full,
+bool unified,
 uint32_t kv_size,
 uint32_t n_seq_max,
-uint32_t n_seq_virt,
 uint32_t n_ubatch,
-uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
+uint32_t n_pad) : hparams(model.hparams), unified(unified) {
 llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
 llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };

 const uint32_t size_base = kv_size;

-uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
+uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

 // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
 if (swa_full) {
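The SWA cache sizing above keeps the n_seq_max factor only while the cache stays unified: a single shared stream must hold the window of every sequence, whereas per-sequence streams each need just one window. A worked example with assumed values:

// Worked example of the SWA cache sizing in the diff (standalone sketch; all values assumed).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_to(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const uint32_t n_swa     = 4096;  // SWA window
    const uint32_t n_seq_max = 4;
    const uint32_t n_ubatch  = 512;
    const uint32_t n_pad     = 256;
    const uint32_t size_base = 65536;

    // unified: one shared stream must hold the window of every sequence
    const uint32_t swa_unified = std::min(size_base, pad_to(n_swa*n_seq_max + n_ubatch, n_pad));
    // non-unified: each stream only needs the window of its own sequence
    const uint32_t swa_split   = std::min(size_base, pad_to(n_swa*1 + n_ubatch, n_pad));

    std::printf("size_swa (unified)    = %u\n", swa_unified); // 16896
    std::printf("size_swa (per-stream) = %u\n", swa_split);   // 4608
    return 0;
}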
@@ -42,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

 kv_base = std::make_unique<llama_kv_cache_unified>(
 model, std::move(filter_base), type_k, type_v,
-v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
+v_trans, offload, unified, size_base, n_seq_max, n_pad,
 0, LLAMA_SWA_TYPE_NONE);

 LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

 kv_swa = std::make_unique<llama_kv_cache_unified>(
 model, std::move(filter_swa), type_k, type_v,
-v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
+v_trans, offload, unified, size_swa, n_seq_max, n_pad,
 hparams.n_swa, hparams.swa_type);
 }

@@ -101,7 +101,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

 // first try simple split
 do {
-if (n_seq_virt > 1) {
+if (!unified) {
 // requires equal splits, so we skip the simple split
 break;
 }
@@ -146,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

 std::vector<llama_ubatch> ubatches;
 while (true) {
-auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
+auto ubatch = balloc.split_equal(n_ubatch, !unified);

 if (ubatch.n_tokens == 0) {
 break;
@@ -20,9 +20,9 @@ public:
 bool v_trans,
 bool offload,
 bool swa_full,
+bool unified,
 uint32_t kv_size,
 uint32_t n_seq_max,
-uint32_t n_seq_virt,
 uint32_t n_ubatch,
 uint32_t n_pad);

@@ -69,7 +69,7 @@ public:
 private:
 const llama_hparams & hparams;

-const uint32_t n_seq_virt = 1;
+const bool unified;

 std::unique_ptr<llama_kv_cache_unified> kv_base;
 std::unique_ptr<llama_kv_cache_unified> kv_swa;
@@ -23,14 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 ggml_type type_v,
 bool v_trans,
 bool offload,
+bool unified,
 uint32_t kv_size,
 uint32_t n_seq_max,
-uint32_t n_seq_virt,
 uint32_t n_pad,
 uint32_t n_swa,
 llama_swa_type swa_type) :
 model(model), hparams(model.hparams), v_trans(v_trans),
-n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

 GGML_ASSERT(kv_size % n_pad == 0);

@@ -46,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 auto it = ctx_map.find(buft);
 if (it == ctx_map.end()) {
 ggml_init_params params = {
-/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()),
+/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
 /*.mem_buffer =*/ NULL,
 /*.no_alloc =*/ true,
 };
@@ -65,25 +65,25 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 return it->second;
 };

-GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
+GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);

-v_heads.resize(n_seq_virt);
-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+v_heads.resize(n_stream);
+for (uint32_t s = 0; s < n_stream; ++s) {
 v_heads[s] = 0;
 }

-v_cells.resize(n_seq_virt);
-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+v_cells.resize(n_stream);
+for (uint32_t s = 0; s < n_stream; ++s) {
 v_cells[s].resize(kv_size);
 }

-// by default, all sequence ids are mapped to the 0th virtual sequence
-seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
+// by default, all sequence ids are mapped to the 0th stream
+seq_to_stream.resize(LLAMA_MAX_SEQ, 0);

-if (n_seq_virt > 1) {
-seq_virt_idx.resize(n_seq_virt, 0);
-for (uint32_t s = 0; s < n_seq_virt; ++s) {
-seq_virt_idx[s] = s;
+if (n_stream > 1) {
+seq_to_stream.resize(n_stream, 0);
+for (uint32_t s = 0; s < n_stream; ++s) {
+seq_to_stream[s] = s;
 }
 }

@@ -124,23 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 ggml_tensor * k;
 ggml_tensor * v;

-k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
-v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);
+k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

 ggml_format_name(k, "cache_k_l%d", il);
 ggml_format_name(v, "cache_v_l%d", il);

-std::vector<ggml_tensor *> k_seq;
-std::vector<ggml_tensor *> v_seq;
+std::vector<ggml_tensor *> k_stream;
+std::vector<ggml_tensor *> v_stream;

-for (uint32_t s = 0; s < n_seq_virt; ++s) {
-k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
-v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+for (uint32_t s = 0; s < n_stream; ++s) {
+k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
 }

 map_layer_ids[il] = layers.size();

-layers.push_back({ il, k, v, k_seq, v_seq, });
+layers.push_back({ il, k, v, k_stream, v_stream, });
 }

 // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
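Each layer's K/V tensor above gains a stream dimension, and k_stream/v_stream are 2-D views carved out at byte offset s*nb[2]; seq_cp() later copies whole streams through these views. The sketch below reproduces the same offset arithmetic on a plain buffer (illustrative only, no ggml):

// Offset arithmetic behind the per-stream views k_stream/v_stream (illustrative sketch).
// A 3-D cache [n_embd, kv_size, n_stream] stored contiguously is sliced into n_stream 2-D views.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const size_t n_embd   = 8;   // row size (n_embd_k_gqa in the diff)
    const size_t kv_size  = 16;  // cells per stream
    const size_t n_stream = 4;

    std::vector<float> k(n_embd*kv_size*n_stream, 0.0f);

    // nb1 = bytes per cell row, nb2 = bytes per stream plane (like ggml's nb strides)
    const size_t nb1 = n_embd*sizeof(float);
    const size_t nb2 = kv_size*nb1;

    for (size_t s = 0; s < n_stream; ++s) {
        // the view for stream s starts s*nb2 bytes into the buffer, same as s*k->nb[2] in the diff
        float * k_stream_s = (float *)((char *) k.data() + s*nb2);
        std::printf("stream %zu view starts at element %zu\n", s, (size_t)(k_stream_s - k.data()));
    }
    return 0;
}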
@@ -184,7 +184,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 const size_t memory_size_v = size_v_bytes();

 LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
+(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
 }
@@ -201,7 +201,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }

 void llama_kv_cache_unified::clear(bool data) {
-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t s = 0; s < n_stream; ++s) {
 v_cells[s].reset();
 v_heads[s] = 0;
 }
@@ -214,8 +214,8 @@ void llama_kv_cache_unified::clear(bool data) {
 }

 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-auto & cells = v_cells[seq_virt_idx[seq_id]];
-auto & head = v_heads[seq_virt_idx[seq_id]];
+auto & cells = v_cells[seq_to_stream[seq_id]];
+auto & head = v_heads[seq_to_stream[seq_id]];

 uint32_t new_head = cells.size();

@@ -263,8 +263,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 }

 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-const auto s0 = seq_virt_idx[seq_id_src];
-const auto s1 = seq_virt_idx[seq_id_dst];
+const auto s0 = seq_to_stream[seq_id_src];
+const auto s1 = seq_to_stream[seq_id_dst];

 if (s0 == s1) {
 auto & cells = v_cells[s0];
@@ -306,13 +306,13 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id

 GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");

-//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
+//LLAMA_LOG_WARN("%s: copying KV buffer from %d (stream = %d) to %d (stream = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);

 for (uint32_t il = 0; il < layers.size(); ++il) {
 const auto & layer = layers[il];

-ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
-ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);
+ggml_backend_tensor_copy(layer.k_stream[s0], layer.k_stream[s1]);
+ggml_backend_tensor_copy(layer.v_stream[s0], layer.v_stream[s1]);

 // TODO: do we need synchronization here?
 }
@@ -330,14 +330,14 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id

 v_heads[s1] = v_heads[s0];

-//for (uint32_t s = 0; s < n_seq_virt; ++s) {
+//for (uint32_t s = 0; s < n_stream; ++s) {
 // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
 //}
 }

 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-auto & cells = v_cells[seq_virt_idx[seq_id]];
-auto & head = v_heads[seq_virt_idx[seq_id]];
+auto & cells = v_cells[seq_to_stream[seq_id]];
+auto & head = v_heads[seq_to_stream[seq_id]];

 uint32_t new_head = cells.size();

@@ -356,8 +356,8 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }

 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-auto & cells = v_cells[seq_virt_idx[seq_id]];
-auto & head = v_heads[seq_virt_idx[seq_id]];
+auto & cells = v_cells[seq_to_stream[seq_id]];
+auto & head = v_heads[seq_to_stream[seq_id]];

 if (shift == 0) {
 return;
@@ -398,7 +398,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }

 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-auto & cells = v_cells[seq_virt_idx[seq_id]];
+auto & cells = v_cells[seq_to_stream[seq_id]];

 if (d == 1) {
 return;
@@ -429,13 +429,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }

 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
-const auto & cells = v_cells[seq_virt_idx[seq_id]];
+const auto & cells = v_cells[seq_to_stream[seq_id]];

 return cells.seq_pos_min(seq_id);
 }

 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-const auto & cells = v_cells[seq_virt_idx[seq_id]];
+const auto & cells = v_cells[seq_to_stream[seq_id]];

 return cells.seq_pos_max(seq_id);
 }
@@ -451,7 +451,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(

 std::vector<llama_ubatch> ubatches;
 while (true) {
-auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
+auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

 if (ubatch.n_tokens == 0) {
 break;
@@ -487,9 +487,9 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 defrag_info dinfo;

 // see if we need to defrag
-if (n_seq_virt == 1) {
-// note : for now do not consider defrag for n_seq_virt > 1
-const auto & cells = v_cells[seq_virt_idx[0]];
+if (n_stream == 1) {
+// note : for now do not consider defrag for n_stream > 1
+const auto & cells = v_cells[seq_to_stream[0]];

 bool do_defrag = optimize;

@@ -551,8 +551,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
 {
 state_t state = { sinfo_new, v_heads, {} };

-for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
-auto & cells = v_cells[sinfo_new.seq_id_virt[s]];
+for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+auto & cells = v_cells[sinfo_new.strm_id[s]];

 state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
 }
@@ -570,9 +570,9 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
 for (auto it = states.rbegin(); it != states.rend(); ++it) {
 const auto & sinfo = it->sinfo;

-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-auto & cells = v_cells[sinfo.seq_id_virt[s]];
-auto & head = v_heads[sinfo.seq_id_virt[s]];
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+auto & cells = v_cells[sinfo.strm_id[s]];
+auto & head = v_heads[sinfo.strm_id[s]];

 cells.set(sinfo.idxs[s], it->v_cells[s]);
 head = it->v_heads_old[s];
@@ -625,7 +625,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 updated = true;
 }

-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t s = 0; s < n_stream; ++s) {
 auto & cells = v_cells[s];

 cells.reset_shift();
@@ -635,9 +635,9 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 if (!dinfo.empty()) {
 LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-// note: for now do not consider defrag for n_seq_virt > 1
-auto & cells = v_cells[seq_virt_idx[0]];
-auto & head = v_heads[seq_virt_idx[0]];
+// note: for now do not consider defrag for n_stream > 1
+auto & cells = v_cells[seq_to_stream[0]];
+auto & head = v_heads[seq_to_stream[0]];

 // apply moves:
 {
@@ -687,7 +687,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d

 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
 if (debug > 0) {
-const auto & cells = v_cells[seq_virt_idx[1]];
+const auto & cells = v_cells[seq_to_stream[1]];

 const uint32_t head_cur = v_heads[1];

@@ -752,7 +752,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 uint32_t n_tokens = ubatch.n_tokens;
 uint32_t n_seqs = 1;

-if (n_seq_virt > 1) {
+if (n_stream > 1) {
 GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

 n_seqs = ubatch.n_seqs_unq;
@@ -760,10 +760,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }

 slot_info res = {
 /*.s0 =*/ LLAMA_MAX_SEQ,
 /*.s1 =*/ 0,
-/*.seq_id_virt =*/ { },
+/*.strm_id =*/ { },
 /*.idxs =*/ { },
 };

 res.resize(n_seqs);
@@ -771,20 +771,20 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 for (uint32_t s = 0; s < n_seqs; ++s) {
 const auto seq_id = ubatch.seq_id_unq[s];

-if (n_seq_virt > 1) {
+if (n_stream > 1) {
 GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
 GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
 }

-res.s0 = std::min<llama_seq_id>(res.s0, seq_virt_idx[seq_id]);
-res.s1 = std::max<llama_seq_id>(res.s1, seq_virt_idx[seq_id]);
+res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);

-res.seq_id_virt[s] = seq_virt_idx[seq_id];
+res.strm_id[s] = seq_to_stream[seq_id];
 res.idxs[s].resize(n_tokens);

-const auto & cells = v_cells[seq_virt_idx[seq_id]];
+const auto & cells = v_cells[seq_to_stream[seq_id]];

-uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];
+uint32_t head_cur = v_heads[seq_to_stream[seq_id]];

 // if we have enough unused cells before the current head ->
 //   better to start searching from the beginning of the cache, hoping to fill it
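For a non-unified cache, find_slot() above resolves each unique sequence of the ubatch to its stream, records it in strm_id, and tracks the lowest/highest stream touched in s0/s1. A standalone sketch of that bookkeeping (illustrative; the struct only mirrors the field names from the diff, and the large initial s0 stands in for LLAMA_MAX_SEQ):

// Illustrative sketch of how find_slot() fills slot_info for a non-unified cache (not llama.cpp code).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct slot_info {
    int32_t s0, s1;                          // lowest/highest stream touched
    std::vector<int32_t> strm_id;            // stream per unique sequence
    std::vector<std::vector<uint32_t>> idxs; // cell indices per stream
};

int main() {
    const std::vector<int32_t> seq_id_unq = { 2, 5, 3 }; // unique sequences in the ubatch
    std::vector<int32_t> seq_to_stream(8);
    for (int32_t s = 0; s < 8; ++s) seq_to_stream[s] = s; // identity map when n_stream > 1

    slot_info res = { /*s0 =*/ 1 << 20, /*s1 =*/ 0, {}, {} }; // s0 starts high so std::min works
    for (int32_t seq_id : seq_id_unq) {
        const int32_t strm = seq_to_stream[seq_id];
        res.s0 = std::min(res.s0, strm);
        res.s1 = std::max(res.s1, strm);
        res.strm_id.push_back(strm);
        res.idxs.push_back({}); // filled later by the per-stream cell search
    }
    std::printf("streams touched: [%d, %d], n_stream = %zu\n", res.s0, res.s1, res.strm_id.size());
    return 0;
}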
@@ -891,13 +891,13 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 seq_pos_max_rm[s] = -1;
 }

-assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());
+assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
 for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
 const uint32_t i = s*sinfo.size() + ii;

-auto & cells = v_cells[sinfo.seq_id_virt[s]];
+auto & cells = v_cells[sinfo.strm_id[s]];

 const auto idx = sinfo.idxs.at(s).at(ii);

@@ -928,9 +928,9 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 continue;
 }

-GGML_ASSERT(s < seq_virt_idx.size());
+GGML_ASSERT(s < seq_to_stream.size());

-auto & cells = v_cells[seq_virt_idx[s]];
+auto & cells = v_cells[seq_to_stream[s]];

 if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
 LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
@@ -941,8 +941,8 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 }

 // move the head at the end of the slot
-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-auto & head = v_heads[sinfo.seq_id_virt[s]];
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+auto & head = v_heads[sinfo.strm_id[s]];

 head = sinfo.idxs[s].back() + 1;
 }
@@ -953,15 +953,19 @@ bool llama_kv_cache_unified::get_can_shift() const {
 }

 uint32_t llama_kv_cache_unified::get_size() const {
-const auto & cells = v_cells[seq_virt_idx[0]];
+const auto & cells = v_cells[seq_to_stream[0]];

 return cells.size();
 }

+uint32_t llama_kv_cache_unified::get_n_stream() const {
+return n_stream;
+}
+
 bool llama_kv_cache_unified::get_has_shift() const {
 bool result = false;

-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t s = 0; s < n_stream; ++s) {
 result |= v_cells[s].get_has_shift();
 }

@@ -971,7 +975,7 @@ bool llama_kv_cache_unified::get_has_shift() const {
 uint32_t llama_kv_cache_unified::get_n_kv() const {
 uint32_t result = 0;

-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t s = 0; s < n_stream; ++s) {
 const auto & cells = v_cells[s];

 result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
@@ -1053,7 +1057,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 // TODO: fallback to old ggml_cpy() method for backwards compatibility
 // will be removed when ggml_set_rows() is adopted by all backends

-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");

 ggml_tensor * k_view = ggml_view_1d(ctx, k,
 n_tokens*n_embd_k_gqa,
@@ -1097,7 +1101,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 // TODO: fallback to old ggml_cpy() method for backwards compatibility
 // will be removed when ggml_set_rows() is adopted by all backends

-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported");

 ggml_tensor * v_view = nullptr;

@@ -1148,13 +1152,13 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
 }

 const uint32_t n_tokens = ubatch->n_tokens;
-GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
+GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 int64_t * data = (int64_t *) dst->data;

-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+const int64_t offs = sinfo.strm_id[s]*get_size();

 for (uint32_t i = 0; i < sinfo.size(); ++i) {
 data[s*sinfo.size() + i] = offs + sinfo.idxs.at(s).at(i);
@@ -1168,14 +1172,14 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
 }

 const uint32_t n_tokens = ubatch->n_tokens;
-GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
+GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 int64_t * data = (int64_t *) dst->data;

 if (!v_trans) {
-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+const int64_t offs = sinfo.strm_id[s]*get_size();

 for (uint32_t i = 0; i < sinfo.size(); ++i) {
 data[s*sinfo.size() + i] = offs + sinfo.idxs.at(s).at(i);
@@ -1187,8 +1191,8 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba

 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();

-for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-const int64_t offs = sinfo.seq_id_virt[s]*kv_size*n_embd_v_gqa;
+for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+const int64_t offs = sinfo.strm_id[s]*kv_size*n_embd_v_gqa;

 for (uint32_t i = 0; i < sinfo.size(); ++i) {
 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
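The index setters above offset every cell index into its stream's slice of the cache: K rows (and non-transposed V rows) use strm_id[s]*get_size(), while the transposed V layout uses strm_id[s]*kv_size*n_embd_v_gqa. A small sketch of that arithmetic (standalone, assumed sizes):

// Illustrative sketch of the per-stream index offsets used in set_input_k_idxs()/set_input_v_idxs().
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t kv_size      = 1024; // cells per stream
    const int64_t n_embd_v_gqa = 512;  // only relevant for the transposed V layout

    const int64_t strm_id  = 2;  // stream this slot belongs to
    const int64_t cell_idx = 37; // cell index within the stream

    // K rows (and non-transposed V rows) are stacked stream after stream
    const int64_t k_row = strm_id*kv_size + cell_idx;

    // transposed V: each stream occupies kv_size*n_embd_v_gqa scalars, so the base offset scales with that
    const int64_t v_base = strm_id*kv_size*n_embd_v_gqa;

    std::printf("k row index     = %lld\n", (long long) k_row);
    std::printf("v stream offset = %lld\n", (long long) v_base);
    return 0;
}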
@@ -1204,7 +1208,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {

 int32_t * data = (int32_t *) dst->data;

-for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t s = 0; s < n_stream; ++s) {
 const auto & cells = v_cells[s];

 for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -1219,13 +1223,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 float * data = (float *) dst->data;

 const int64_t n_kv = dst->ne[0];
-const int64_t n_seq_virt = dst->ne[3]; // num virtual sequences in the current ubatch
+const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch

-GGML_ASSERT(n_tokens%n_seq_virt == 0);
+GGML_ASSERT(n_tokens%n_stream == 0);

-const int64_t n_tokens_per_seq = n_tokens/n_seq_virt;
-const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);
+// n_tps == n_tokens_per_stream
+const int64_t n_tps = n_tokens/n_stream;
+const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);

 // Use only the previous KV cells of the correct sequence for each token of the ubatch.
 // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -1240,13 +1245,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 // xxxxx-----
 // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
 for (uint32_t h = 0; h < 1; ++h) {
-for (uint32_t s = 0; s < n_seq_virt; ++s) {
-for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
-const uint32_t i = s*n_tokens_per_seq + ii;
+for (uint32_t s = 0; s < n_stream; ++s) {
+for (uint32_t ii = 0; ii < n_tps; ++ii) {
+const uint32_t i = s*n_tps + ii;

 const llama_seq_id seq_id = ubatch->seq_id[i][0];

-const auto & cells = v_cells[seq_virt_idx[seq_id]];
+const auto & cells = v_cells[seq_to_stream[seq_id]];

 const llama_pos p1 = ubatch->pos[i];

@@ -1278,14 +1283,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 f = -INFINITY;
 }

-data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
+data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
 }

 // mask padded tokens
 if (data) {
-for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
+for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
 for (uint32_t j = 0; j < n_kv; ++j) {
-data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
+data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
 }
 }
 }
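set_input_kq_mask() above fills the mask stream by stream: n_tps = n_tokens/n_stream rows per stream, padded to n_tps_pad, with the flat index h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j. A tiny sketch of that indexing (standalone, assumed sizes):

// Illustrative sketch of the flattened kq_mask indexing from set_input_kq_mask() (not llama.cpp code).
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n_kv     = 64;
    const int64_t n_tokens = 8;
    const int64_t n_stream = 2;
    const int64_t pad      = 4; // assumed GGML_KQ_MASK_PAD-like granularity

    const int64_t n_tps     = n_tokens/n_stream;       // tokens per stream
    const int64_t n_tps_pad = ((n_tps + pad - 1)/pad)*pad;

    std::vector<float> data(1*n_stream*n_tps_pad*n_kv, 0.0f);

    // write one mask value for head-broadcast h=0, stream s=1, row ii=2, kv column j=10
    const int64_t h = 0, s = 1, ii = 2, j = 10;
    const int64_t idx = h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j;
    assert(idx < (int64_t) data.size());
    data[idx] = -INFINITY;

    std::printf("n_tps = %lld, n_tps_pad = %lld, flat index = %lld\n",
        (long long) n_tps, (long long) n_tps_pad, (long long) idx);
    return 0;
}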
@@ -1297,7 +1302,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
 const int64_t n_tokens = ubatch->n_tokens;

-GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
+GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
 const auto & cells = v_cells[0];

 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
@@ -1406,7 +1411,7 @@ public:

 void set_input(const llama_ubatch * ubatch) override;

-ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]
+ggml_tensor * k_shift; // I32 [kv_size*n_stream]

 const llama_kv_cache_unified * kv_self;
 };
@@ -1430,7 +1435,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

 auto inp = std::make_unique<llm_graph_input_k_shift>(this);

-inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
+inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
 ggml_set_input(inp->k_shift);

 for (const auto & layer : layers) {
@@ -1446,7 +1451,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

 ggml_tensor * k =
 ggml_view_3d(ctx, layer.k,
-n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
+n_embd_head_k, n_head_kv, get_size()*n_stream,
 ggml_row_size(layer.k->type, n_embd_head_k),
 ggml_row_size(layer.k->type, n_embd_k_gqa),
 0);
@@ -1468,7 +1473,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 const defrag_info & dinfo) const {
 auto res = std::make_unique<llm_graph_result>();

-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");

 const auto & cells = v_cells[0];

@@ -1614,7 +1619,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }

 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");

 const auto & cells = v_cells[0];

@@ -1766,7 +1771,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
 std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
 uint32_t cell_count = 0;

-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not implemented yet");

 const auto & cells = v_cells[0];

@@ -1824,7 +1829,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 }

 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not implemented yet");

 const auto & cells = v_cells[0];

@@ -1854,7 +1859,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 }

 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not implemented yet");

 const auto & cells = v_cells[0];

@@ -1945,7 +1950,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 }

 bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not implemented yet");

 auto & cells = v_cells[0];
 auto & head = v_heads[0];
@@ -2041,7 +2046,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 }

 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
-GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+GGML_ASSERT(n_stream == 1 && "n_stream > 1 not implemented yet");

 auto & cells = v_cells[0];
 auto & head = v_heads[0];
@@ -2182,13 +2187,17 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
 n_kv = kv->get_size();

+const uint32_t n_stream = kv->get_n_stream();
+
 // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
-// note: this is slot info for a single-virt-sequence batch. therefore we can use it to compute worst-case graphs
-// for the respective batch contents that would fit to this setup
 sinfos.resize(1);
-sinfos[0].seq_id_virt.resize(1, 0);
-sinfos[0].idxs.resize(1);
-sinfos[0].idxs[0].resize(1, 0);
+sinfos[0].s0 = 0;
+sinfos[0].s1 = n_stream - 1;
+sinfos[0].idxs.resize(n_stream);
+for (uint32_t s = 0; s < n_stream; ++s) {
+sinfos[0].strm_id.push_back(s);
+sinfos[0].idxs[s].resize(1, 0);
+}
 }

 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
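
The reworked constructor above builds one dummy slot that spans every stream, so the worst-case graph can be reserved for all streams at once instead of a single virtual sequence. A minimal sketch of the resulting layout (the struct is a reduced stand-in for the real slot info, keeping only the fields touched in this hunk; the concrete n_stream value is just an example):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // reduced stand-in for the slot info modified in the hunk above (illustrative only)
    struct slot_info_sketch {
        int32_t s0;
        int32_t s1;
        std::vector<int32_t>               strm_id; // [ns]
        std::vector<std::vector<uint32_t>> idxs;    // [ns]
    };

    int main() {
        const uint32_t n_stream = 4; // e.g. a non-unified cache serving 4 sequences

        // same construction as in the constructor above
        std::vector<slot_info_sketch> sinfos(1);
        sinfos[0].s0 = 0;
        sinfos[0].s1 = n_stream - 1;
        sinfos[0].idxs.resize(n_stream);
        for (uint32_t s = 0; s < n_stream; ++s) {
            sinfos[0].strm_id.push_back(s);
            sinfos[0].idxs[s].resize(1, 0);
        }

        // the dummy slot covers streams [0, n_stream) with one placeholder cell index per stream
        printf("streams %d..%d, %zu index lists of size %zu\n",
               (int) sinfos[0].s0, (int) sinfos[0].s1,
               sinfos[0].idxs.size(), sinfos[0].idxs[0].size());
        return 0;
    }

For n_stream = 4 this prints "streams 0..3, 4 index lists of size 1", which is all the graph builder needs to size the worst case.
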
@@ -41,11 +41,12 @@ public:
     // data for ggml_set_rows
     using idx_vec_t = std::vector<uint32_t>;
 
+    // number of streams: ns = s1 - s0 + 1
     llama_seq_id s0;
     llama_seq_id s1;
 
-    std::vector<llama_seq_id> seq_id_virt;
-    std::vector<idx_vec_t> idxs;
+    std::vector<llama_seq_id> strm_id; // [ns]
+    std::vector<idx_vec_t> idxs; // [ns]
 
     uint32_t head() const {
         GGML_ASSERT(idxs.size() == 1);
@@ -54,18 +55,18 @@ public:
     }
 
     void resize(size_t n) {
-        seq_id_virt.resize(n);
+        strm_id.resize(n);
         idxs.resize(n);
     }
 
     size_t size() const {
-        GGML_ASSERT(idxs.size() == seq_id_virt.size());
+        GGML_ASSERT(idxs.size() == strm_id.size());
 
         return idxs.at(0).size();
     }
 
-    size_t n_seq_virt() const {
-        return seq_id_virt.size();
+    size_t n_stream() const {
+        return strm_id.size();
     }
 
     bool empty() const {
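
After the rename, strm_id and idxs are parallel arrays with one entry per stream (ns = s1 - s0 + 1, per the new comment), so size() reports the cells used per stream while n_stream() reports the stream count. A small sketch of that distinction, using only the members shown in these hunks (the type is illustrative, unrelated members are omitted):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using idx_vec_t = std::vector<uint32_t>;

    // stripped-down version of the slot info from the hunks above (illustrative)
    struct slot_info_sketch {
        std::vector<int32_t>   strm_id; // [ns]
        std::vector<idx_vec_t> idxs;    // [ns]

        size_t size()     const { return idxs.at(0).size(); } // cells used per stream
        size_t n_stream() const { return strm_id.size();    } // number of streams
    };

    int main() {
        // two streams, each occupying three cells of its own KV buffer
        slot_info_sketch si;
        si.strm_id = { 0, 1 };
        si.idxs    = { { 4u, 5u, 6u }, { 0u, 1u, 2u } };

        assert(si.n_stream() == 2);
        assert(si.size()     == 3);
        return 0;
    }
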
@@ -86,9 +87,9 @@ public:
             ggml_type type_v,
             bool v_trans,
             bool offload,
+            bool unified,
             uint32_t kv_size,
             uint32_t n_seq_max,
-            uint32_t n_seq_virt,
             uint32_t n_pad,
             uint32_t n_swa,
             llama_swa_type swa_type);
@@ -130,7 +131,8 @@ public:
     // llama_kv_cache_unified specific API
     //
 
     uint32_t get_size() const;
+    uint32_t get_n_stream() const;
 
     bool get_has_shift() const;
 
@@ -193,14 +195,14 @@ private:
         ggml_tensor * k;
         ggml_tensor * v;
 
-        std::vector<ggml_tensor *> k_seq;
-        std::vector<ggml_tensor *> v_seq;
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };
 
     bool v_trans = true; // the value tensor is transposed
 
     const uint32_t n_seq_max = 1;
-    const uint32_t n_seq_virt = 1;
+    const uint32_t n_stream = 1;
 
     // required padding
     const uint32_t n_pad = 1;
@@ -226,8 +228,8 @@ private:
 
     std::vector<llama_kv_cells_unified> v_cells;
 
-    // maps from a sequence id to a virtual sequence id
-    std::vector<uint32_t> seq_virt_idx;
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
 
     std::vector<kv_layer> layers;
 
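
seq_to_stream replaces the old virtual-sequence index. The hunk only renames the member, so the sketch below is an assumption about how such a map would typically be populated: identity when every sequence gets its own stream, all zeros when the cache is unified and every sequence shares stream 0.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const uint32_t n_seq_max = 4;

        // illustrative only: the actual initialization is not part of this hunk
        for (bool unified : { true, false }) {
            std::vector<uint32_t> seq_to_stream(n_seq_max);
            for (uint32_t s = 0; s < n_seq_max; ++s) {
                seq_to_stream[s] = unified ? 0 : s;
            }

            printf("unified=%d:", unified ? 1 : 0);
            for (uint32_t v : seq_to_stream) {
                printf(" %u", v);
            }
            printf("\n");
        }
        return 0;
    }
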
@@ -14710,7 +14710,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         } else {
             const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-            cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+            uint32_t n_ctx_per_stream = cparams.n_ctx;
+
+            if (!cparams.kv_unified) {
+                n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
+                n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
+            } else {
+                n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
+
+                cparams.n_ctx = n_ctx_per_stream;
+            }
 
             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
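
In the non-unified case the requested context is split across the streams and each stream is rounded up to the cache padding, so the effective cparams.n_ctx can come out slightly larger than requested. A standalone sketch of the arithmetic (pad_up is a local stand-in that assumes GGML_PAD rounds up to a multiple of the padding; the concrete numbers are only an example):

    #include <cstdint>
    #include <cstdio>

    // local stand-in for GGML_PAD: round x up to a multiple of n
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1)/n)*n;
    }

    int main() {
        const uint32_t n_ctx_req = 8192; // requested context
        const uint32_t n_seq_max = 3;
        const uint32_t padding   = 256;  // example cache padding

        // non-unified: divide the context across streams, then pad each stream
        uint32_t n_ctx_per_stream  = (n_ctx_req + n_seq_max - 1)/n_seq_max; // 2731
        n_ctx_per_stream           = pad_up(n_ctx_per_stream, padding);     // 2816
        const uint32_t n_ctx_split = n_ctx_per_stream*n_seq_max;            // 8448

        // unified: a single stream holds the whole (padded) context
        const uint32_t n_ctx_unified = pad_up(n_ctx_req, padding);          // 8192

        printf("per stream: %u, split total: %u, unified total: %u\n",
               n_ctx_per_stream, n_ctx_split, n_ctx_unified);
        return 0;
    }

With n_ctx = 8192, n_seq_max = 3 and a padding of 256, each stream gets 2816 cells and the reported total becomes 8448.
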
@@ -14724,9 +14735,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
                     params.swa_full,
-                    cparams.n_ctx,
+                    cparams.kv_unified,
+                    n_ctx_per_stream,
                     cparams.n_seq_max,
-                    cparams.n_seq_virt,
                     cparams.n_ubatch,
                     padding);
             } else {
@@ -14739,9 +14750,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     params.type_v,
                     !cparams.flash_attn,
                     cparams.offload_kqv,
-                    cparams.n_ctx,
+                    cparams.kv_unified,
+                    n_ctx_per_stream,
                     cparams.n_seq_max,
-                    cparams.n_seq_virt,
                     padding,
                     hparams.n_swa,
                     hparams.swa_type);
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
+    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -119,22 +119,18 @@ int main(int argc, char ** argv) {
 
         const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
 
-        //if (n_ctx_req > n_kv_max) {
-        //    continue;
-        //}
+        if (n_ctx_req > n_kv_max) {
+            continue;
+        }
 
         common_batch_clear(batch);
 
         for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
            for (int i = 0; i < pp; ++i) {
-                common_batch_add(batch, 0, i, { j }, false);
+                common_batch_add(batch, 0, i, { j }, i == pp - 1);
            }
        }
 
-        if (batch.n_tokens > 0) {
-            batch.logits[batch.n_tokens - 1] = true;
-        }
-
        const auto t_pp_start = ggml_time_us();
 
        llama_memory_clear(mem, false);
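
The benchmark change above requests logits for the last prompt token of every parallel sequence inside the add loop, instead of flipping only the final entry of the whole batch afterwards. A reduced sketch of what the loop produces (common_batch_add is mimicked with a plain vector push; the pp, pl and is_pp_shared values are just an example):

    #include <cstdio>
    #include <vector>

    struct tok {
        int  pos;
        int  seq_id;
        bool logits;
    };

    int main() {
        const int  pp           = 3;     // prompt tokens per sequence
        const int  pl           = 2;     // parallel sequences
        const bool is_pp_shared = false;

        std::vector<tok> batch;
        for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
            for (int i = 0; i < pp; ++i) {
                // mirrors: common_batch_add(batch, 0, i, { j }, i == pp - 1);
                batch.push_back({ i, j, i == pp - 1 });
            }
        }

        // each sequence's last prompt token carries logits = true
        for (const auto & t : batch) {
            printf("seq %d pos %d logits %d\n", t.seq_id, t.pos, t.logits ? 1 : 0);
        }
        return 0;
    }
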