diff --git a/common/common.cpp b/common/common.cpp index b75195956..ec95f32d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -952,7 +952,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { - LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__); + LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index dbc9231ac..6b2a11ad6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -341,6 +341,10 @@ uint32_t llama_context::n_ubatch() const { return cparams.n_ubatch; } +uint32_t llama_context::n_seq_max() const { + return cparams.n_seq_max; +} + uint32_t llama_context::n_threads() const { return cparams.n_threads; } @@ -353,6 +357,20 @@ int32_t llama_context::max_nodes() const { return std::max(8192, 5*model.n_tensors()); } +llama_kv_cache * llama_context::get_kv_self() { + LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__); + return nullptr; +} + +const llama_kv_cache * llama_context::get_kv_self() const { + LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__); + return nullptr; +} + +void llama_context::kv_self_update() { + LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__); +} + enum llama_pooling_type llama_context::pooling_type() const { return cparams.pooling_type; } @@ -566,6 +584,9 @@ ggml_cgraph * llama_context::graph_init() { inp_mean = nullptr; inp_cls = nullptr; + inp_kq_mask = nullptr; + inp_kq_mask_cnv = nullptr; + struct ggml_init_params params = { /*.mem_size =*/ buf_compute_meta.size(), /*.mem_buffer =*/ buf_compute_meta.data(), @@ -612,179 +633,11 @@ enum ggml_status llama_context::graph_compute( return status; } -void llama_context::input_set(const llama_ubatch & ubatch) { - const llama_hparams & hparams = model.hparams; - - if (ubatch.token) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens)); - } - - if (ubatch.embd) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd)); - } - - if (ubatch.pos && inp_pos) { - const int64_t n_tokens = ubatch.n_tokens; - - ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos)); - } - - if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - //GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs"); - - if (!inp_out_ids) { - LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__); - } else { - const int64_t n_tokens = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer)); - int32_t * data = (int32_t *) inp_out_ids->data; - - if (n_outputs == n_tokens) { - for (int i = 0; i < n_tokens; ++i) { - data[i] = i; - } - } else if (ubatch.output) { - int32_t n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - if (ubatch.output[i]) { - data[n_outputs++] = i; - } - } - // the graph needs to have been passed the correct number of outputs - GGML_ASSERT(n_outputs == n_outputs); - } else if (n_outputs == 1) { - // only keep last output - data[0] = n_tokens - 1; - } else { - GGML_ASSERT(n_outputs == 0); - } - } - } - - if (cparams.embeddings 
&& cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_mean); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer)); - - float * data = (float *) inp_mean->data; - memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean)); - - std::vector sum(n_tokens, 0); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); - - sum[seq_id] += ubatch.n_seq_tokens; - } - - std::vector div(n_tokens, 0.0f); - for (int i = 0; i < n_tokens; ++i) { - const uint64_t s = sum[i]; - if (s > 0) { - div[i] = 1.0f/float(s); - } - } - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - for (int i = 0; i < n_seq_tokens; ++i) { - data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; - } - } - } - - if (cparams.embeddings && ( - cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || - cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); - - uint32_t * data = (uint32_t *) inp_cls->data; - memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos == 0) { - data[seq_id] = s*n_seq_tokens + i; - } - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - const int64_t n_tokens = ubatch.n_tokens; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - - GGML_ASSERT(inp_cls); - GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); - - uint32_t * data = (uint32_t *) inp_cls->data; - memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); - - std::vector last_pos(n_tokens, -1); - std::vector last_row(n_tokens, -1); - - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - - // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true - GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); - - for (int i = 0; i < n_seq_tokens; ++i) { - const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; - - if (pos >= last_pos[seq_id]) { - last_pos[seq_id] = pos; - last_row[seq_id] = s*n_seq_tokens + i; - } - } - } - - for (int i = 0; i < n_tokens; ++i) { - if (last_row[i] >= 0) { - data[i] = last_row[i]; - } - } - } - - GGML_ASSERT( - // (!a || b) is a logical implication (a -> b) - // !hparams.causal_attn -> !cparams.causal_attn - (hparams.causal_attn || !cparams.causal_attn) && - "causal attention is not supported by this model" - ); -} - int32_t llama_context::output_reserve(int32_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; - const int64_t n_outputs_max = std::max(n_outputs, cparams.n_seq_max); + const int64_t 
n_outputs_max = std::max(n_outputs, n_seq_max()); const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); @@ -887,6 +740,348 @@ void llama_context::output_reorder() { } } +int llama_context::encode(llama_batch & inp_batch) { + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(inp_batch, 0); + + const llama_batch & batch = batch_allocr.batch; + + const int32_t n_tokens = batch.n_tokens; + + const auto & hparams = model.hparams; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + if (batch.token) { + for (int32_t i = 0; i < n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return -1; + } + } + } + + // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot + GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + n_queued_tokens += n_tokens; + + const int64_t n_embd = hparams.n_embd; + + sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = sbatch.split_simple(n_tokens); + + // reserve output buffer + if (output_reserve(n_tokens) < n_tokens) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); + return -2; + }; + + for (int32_t i = 0; i < n_tokens; ++i) { + output_ids[i] = i; + } + + n_outputs = n_tokens; + + GGML_ASSERT(need_reserve == false); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + input_set(ubatch); + + const auto compute_status = graph_compute(gf, n_tokens > 1); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + + auto * t_embd = res.t_embd_pooled ? 
res.t_embd_pooled : res.t_embd; + + // extract embeddings + if (t_embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + GGML_ASSERT(embd != nullptr); + + // extract token embeddings + float * embd_out = embd; + + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + embd_seq_out.clear(); + + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + + for (int32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + return 0; +} + +int llama_context::decode(llama_batch & inp_batch) { + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(inp_batch, 0); + + const llama_batch & batch = batch_allocr.batch; + + const auto & vocab = model.vocab; + const auto & hparams = model.hparams; + + const int32_t n_vocab = vocab.n_tokens(); + + const int64_t n_tokens = batch.n_tokens; + const int64_t n_embd = hparams.n_embd; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + if (batch.token) { + for (int64_t i = 0; i < n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); + throw std::runtime_error("invalid token"); + } + } + } + + // micro-batching is not possible without KV cache + GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "llama_context requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + n_queued_tokens += n_tokens; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + embd_seq.clear(); + + int64_t n_outputs_all = 0; + + // count outputs + if (batch.logits && !embd_pooled) { + for (uint32_t i = 0; i < n_tokens; ++i) { + n_outputs_all += batch.logits[i] != 0; + } + } else if (logits_all || embd_pooled) { + n_outputs_all = n_tokens; + } else { + // keep last output only + n_outputs_all = 1; + } + + const bool logits_all = n_outputs_all == n_tokens; + + 
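+    // worked example of the counting above: a 4-token batch with
+    // batch.logits = {0, 0, 0, 1} and no pooled embeddings gives
+    // n_outputs_all == 1 and logits_all == false; a batch with
+    // batch.logits == nullptr while the context-level logits_all flag is set
+    // keeps every token as an output (n_outputs_all == n_tokens, logits_all == true).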
sbatch.from_batch(batch, n_embd, + /* simple_split */ true, + /* logits_all */ logits_all); + + const llama_ubatch ubatch = sbatch.split_simple(n_tokens); + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); + return -2; + }; + + n_outputs = n_outputs_all; + + GGML_ASSERT(need_reserve == false); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, false); + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + input_set(ubatch); + + const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); + if (compute_status != GGML_STATUS_SUCCESS) { + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + } + + auto * t_logits = cparams.embeddings ? nullptr : res.t_logits; + auto * t_embd = cparams.embeddings ? res.t_embd : nullptr; + + if (t_embd && res.t_embd_pooled) { + t_embd = res.t_embd_pooled; + } + + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits; + + if (n_outputs) { + GGML_ASSERT(n_outputs <= n_outputs_all); + GGML_ASSERT(n_outputs*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } + } + + // extract embeddings + if (t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd; + + if (n_outputs) { + GGML_ASSERT(n_outputs <= n_outputs_all); + GGML_ASSERT(n_outputs*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - a single float per sequence + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + 
GGML_ABORT("unknown pooling type"); + } + } + } + + // set output mappings + { + bool sorted_output = true; + + GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + + for (int64_t i = 0; i < n_outputs_all; ++i) { + int64_t out_id = sbatch.out_ids[i]; + output_ids[out_id] = i; + if (out_id != i) { + sorted_output = false; + } + } + + if (sorted_output) { + sbatch.out_ids.clear(); + } + } + + // wait for the computation to finish (automatically done when obtaining the model output) + //synchronize(); + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + + return 0; +} + void llama_context::build_cb( ggml_tensor * cur, const char * name, @@ -922,19 +1117,6 @@ void llama_context::build_cb( } } -llama_perf_context_data llama_context::perf_get_data() const { - llama_perf_context_data data = {}; - - data.t_start_ms = 1e-3 * t_start_us; - data.t_load_ms = 1e-3 * t_load_us; - data.t_p_eval_ms = 1e-3 * t_p_eval_us; - data.t_eval_ms = 1e-3 * t_eval_us; - data.n_p_eval = std::max(1, n_p_eval); - data.n_eval = std::max(1, n_eval); - - return data; -} - ggml_tensor * llama_context::build_cvec( ggml_context * ctx0, ggml_tensor * cur, @@ -1002,7 +1184,7 @@ ggml_tensor * llama_context::build_rope_factors(int il) { const auto & hparams = model.hparams; // choose long/short freq factors based on the context size - const auto n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + const auto n_ctx_per_seq = n_ctx() / n_seq_max(); if (model.layers[il].rope_freqs != nullptr) { return model.layers[il].rope_freqs; @@ -1153,6 +1335,166 @@ ggml_tensor * llama_context::build_inp_cls( return inp_cls; } +ggml_tensor * llama_context::build_attn( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + int32_t n_tokens, + float kq_scale, + int il, + bool worst_case) { + const auto & hparams = model.hparams; + + const auto & n_ctx = cparams.n_ctx; + + //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + const auto & kq_mask = inp_kq_mask_cnv; + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + //const auto & n_embd_head_k = hparams.n_embd_head_k; + const auto & n_embd_head_v = hparams.n_embd_head_v; + + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch + GGML_UNUSED(worst_case); + const auto n_kv = n_tokens; + + struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3)); + //cb(k, "k", il); + + struct ggml_tensor * cur; + + //if (cparams.flash_attn) { + if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3)); + v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv); + + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); + + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + + cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + //cb(kq, "kq", il); + + // note: this op tends to require high floating point range + // while for some models F16 is enough, for others it is not, so we default to F32 here + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx0, kq, 30); + } + + if (hparams.attn_soft_cap) { + kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); + kq = ggml_tanh(ctx0, kq); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + } + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + //cb(kq, "kq_soft_max_ext", il); + + // split cached v into n_head heads + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens))); + + v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv); + //cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + //cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + //cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + //cb(cur, "kqv_merged_cont", il); + + if (!cparams.offload_kqv) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + } + } + + ggml_build_forward_expand(gf, cur); + + if (wo) { + cur = build_lora_mm(ctx0, wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +void llama_context::build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa, + bool worst_case) { + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch + GGML_UNUSED(causal); + GGML_UNUSED(swa); + GGML_UNUSED(worst_case); + + inp_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp_kq_mask, "KQ_mask", -1); + ggml_set_input(inp_kq_mask); + + inp_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_kq_mask, GGML_TYPE_F16) : inp_kq_mask; +} + +// +// perf +// + +llama_perf_context_data llama_context::perf_get_data() const { + llama_perf_context_data data = {}; + + data.t_start_ms = 1e-3 * t_start_us; + data.t_load_ms = 1e-3 * t_load_us; + data.t_p_eval_ms = 1e-3 * t_p_eval_us; + data.t_eval_ms = 1e-3 * t_eval_us; + data.n_p_eval = std::max(1, n_p_eval); + data.n_eval = std::max(1, n_eval); + + return data; +} + +void llama_context::perf_reset() { + t_start_us = ggml_time_us(); + t_eval_us = n_eval = 0; + t_p_eval_us = n_p_eval = 0; +} + // // state // @@ -1620,10 +1962,277 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_ return io.n_bytes(); } -void llama_context::perf_reset() { - t_start_us = ggml_time_us(); - t_eval_us = n_eval = 0; - t_p_eval_us = n_p_eval = 0; +// +// input +// + +void llama_context::input_set(const llama_ubatch & ubatch) { + const llama_hparams & hparams = model.hparams; + + if (ubatch.token) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp_tokens)); + } + + if (ubatch.embd) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp_embd)); + } + + if (ubatch.pos && inp_pos) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos)); + } + + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(inp_out_ids && "every model that can must skip unused outputs"); + + if (!inp_out_ids) { + LLAMA_LOG_WARN("%s: 'inp_out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_out_ids->buffer)); + int32_t * data = (int32_t *) inp_out_ids->data; + + if (n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch.output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch.output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(n_outputs == n_outputs); + } else if (n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(n_outputs == 0); + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_mean); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_mean->buffer)); + + float * data = (float *) inp_mean->data; + memset(inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(inp_mean)); + + std::vector sum(n_tokens, 0); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch.n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + 
i] = div[seq_id]; + } + } + } + + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); + + uint32_t * data = (uint32_t *) inp_cls->data; + memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_cls->buffer)); + + uint32_t * data = (uint32_t *) inp_cls->data; + memset(inp_cls->data, 0, n_tokens * ggml_element_size(inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + + if (inp_kq_mask) { + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + if (cparams.causal_attn) { + // TODO: need to use the batch directly to construct the masks + GGML_ABORT("TODO"); + + //const int64_t n_kv = ubatch.n_tokens; + //const int64_t n_tokens = ubatch.n_tokens; + //const int64_t n_seq_tokens = ubatch.n_seq_tokens; + //const int64_t n_seqs = ubatch.n_seqs; + + //float * data = nullptr; + + //if (inp_kq_mask) { + // GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); + // data = (float *) inp_kq_mask->data; + //} + + //// For causal attention, use only the previous KV cells + //// of the correct sequence for each token of the ubatch. + //// It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
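+            // note on the reference code kept below: it is carried over from the
+            // KV-cache implementation and encodes the usual mask convention - a cell
+            // is masked with -INFINITY when it does not belong to the token's sequence
+            // or sits at a later position, and otherwise contributes 0.0f
+            // (or -|pos_cell - pos_token| when ALiBi is in use).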
+ //for (int h = 0; h < 1; ++h) { + // for (int s = 0; s < n_seqs; ++s) { + // const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // for (int j = 0; j < n_seq_tokens; ++j) { + // const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; + + // for (int i = 0; i < n_kv; ++i) { + // float f; + // if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + // f = -INFINITY; + // } else { + // if (hparams.use_alibi) { + // f = -std::abs(kv_self.cells[i].pos - pos); + // } else { + // f = 0.0f; + // } + // } + + // if (data) { + // data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + // } + // } + // } + // } + + // if (data) { + // for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + // for (int j = 0; j < n_kv; ++j) { + // data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + // } + // } + // } + //} + } else { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_stride = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer)); + + float * data = (float *) inp_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } + + GGML_ASSERT( + // (!a || b) is a logical implication (a -> b) + // !hparams.causal_attn -> !cparams.causal_attn + (hparams.causal_attn || !cparams.causal_attn) && + "causal attention is not supported by this model" + ); } // @@ -1684,11 +2293,6 @@ llama_context_kv_self::llama_context_kv_self( llama_context_kv_self::~llama_context_kv_self() = default; -uint32_t llama_context_kv_self::n_seq_max() const { - // TODO: add notion of n_seq_max to llama_kv_cache and use it here - return kv_self.size; -} - llama_kv_cache * llama_context_kv_self::get_kv_self() { return &kv_self; } @@ -1698,14 +2302,15 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const { } ggml_cgraph * llama_context_kv_self::graph_init() { - inp_KQ_mask = nullptr; - inp_KQ_mask_cnv = nullptr; - inp_KQ_mask_swa = nullptr; - inp_KQ_mask_swa_cnv = nullptr; - inp_KQ_mask_cross = nullptr; - inp_k_shift = nullptr; - inp_embd_enc = nullptr; - inp_pos_bucket = nullptr; + inp_embd_enc = nullptr; + inp_pos_bucket = nullptr; + inp_kq_mask_cross = nullptr; + + inp_self_kq_mask = nullptr; + inp_self_kq_mask_cnv = nullptr; + inp_self_kq_mask_swa = nullptr; + inp_self_kq_mask_swa_cnv = nullptr; + inp_self_k_shift = nullptr; return llama_context::graph_init(); } @@ -1979,8 +2584,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { const auto & n_ubatch = cparams.n_ubatch; - const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; - if (kv_self.recurrent) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) @@ -2033,7 +2636,7 @@ int 
llama_context_kv_self::decode(llama_batch & inp_batch) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv_self.get_padding(cparams); + const uint32_t pad = get_ctx_padding(cparams); kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } @@ -2246,10 +2849,10 @@ uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) c void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; - if (inp_k_shift) { - assert(ggml_backend_buffer_is_host(inp_k_shift->buffer)); + if (inp_self_k_shift) { + assert(ggml_backend_buffer_is_host(inp_self_k_shift->buffer)); - int32_t * data = (int32_t *) inp_k_shift->data; + int32_t * data = (int32_t *) inp_self_k_shift->data; for (uint32_t i = 0; i < kv_self.size; ++i) { data[i] = kv_self.cells[i].delta; @@ -2262,7 +2865,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { // call base functionality llama_context::input_set(ubatch); - if (inp_KQ_mask || inp_KQ_mask_swa) { + if (inp_self_kq_mask || inp_self_kq_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. if (cparams.causal_attn && !is_encoding) { const int64_t n_kv = kv_self.n; @@ -2273,14 +2876,14 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { float * data = nullptr; float * data_swa = nullptr; - if (inp_KQ_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp_KQ_mask->buffer)); - data = (float *) inp_KQ_mask->data; + if (inp_self_kq_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask->buffer)); + data = (float *) inp_self_kq_mask->data; } - if (inp_KQ_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(inp_KQ_mask_swa->buffer)); - data_swa = (float *) inp_KQ_mask_swa->data; + if (inp_self_kq_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask_swa->buffer)); + data_swa = (float *) inp_self_kq_mask_swa->data; } // For causal attention, use only the previous KV cells @@ -2341,11 +2944,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seqs = ubatch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_stride = hparams.causal_attn && !is_encoding ? kv_self.n : n_tokens; + const int64_t n_stride = hparams.causal_attn && !is_encoding ? 
kv_self.n : n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(inp_KQ_mask->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_self_kq_mask->buffer)); - float * data = (float *) inp_KQ_mask->data; + float * data = (float *) inp_self_kq_mask->data; for (int h = 0; h < 1; ++h) { for (int s1 = 0; s1 < n_seqs; ++s1) { @@ -2442,14 +3045,14 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { ggml_backend_tensor_set(inp_embd_enc, embd_enc.data(), 0, ggml_nbytes(inp_embd_enc)); } - if (!is_encoding && inp_KQ_mask_cross) { + if (!is_encoding && inp_kq_mask_cross) { const int64_t n_output_enc = embd_enc.size() / hparams.n_embd; const int64_t n_tokens = ubatch.n_tokens; - GGML_ASSERT(ggml_backend_buffer_is_host(inp_KQ_mask_cross->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask_cross->buffer)); GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing - float * data = (float *) inp_KQ_mask_cross->data; + float * data = (float *) inp_kq_mask_cross->data; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { @@ -2529,11 +3132,11 @@ void llama_context_kv_self::kv_self_update() { } } -ggml_tensor * llama_context_kv_self::build_inp_k_shift(ggml_context * ctx0) { - inp_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx()); - ggml_set_input(inp_k_shift); +ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) { + inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx()); + ggml_set_input(inp_self_k_shift); - return inp_k_shift; + return inp_self_k_shift; } void llama_context_kv_self::build_attn_inp( @@ -2542,28 +3145,28 @@ void llama_context_kv_self::build_attn_inp( bool causal, bool swa, bool worst_case) { - const auto & hparams = model.hparams; - const auto n_kv = worst_case ? kv_self.size : kv_self.n; - inp_KQ_mask = causal + inp_self_kq_mask = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_KQ_mask, "KQ_mask", -1); - ggml_set_input(inp_KQ_mask); + //cb(inp_self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp_self_kq_mask); - inp_KQ_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_KQ_mask, GGML_TYPE_F16) : inp_KQ_mask; + inp_self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_self_kq_mask, GGML_TYPE_F16) : inp_self_kq_mask; if (swa) { + const auto & hparams = model.hparams; + GGML_ASSERT(hparams.n_swa > 0); - inp_KQ_mask_swa = causal + inp_self_kq_mask_swa = causal ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_KQ_mask_swa, "KQ_mask_swa", -1); - ggml_set_input(inp_KQ_mask_swa); + //cb(inp_self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp_self_kq_mask_swa); - inp_KQ_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_KQ_mask_swa, GGML_TYPE_F16) : inp_KQ_mask_swa; + inp_self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_self_kq_mask_swa, GGML_TYPE_F16) : inp_self_kq_mask_swa; } } @@ -2598,7 +3201,7 @@ ggml_tensor * llama_context_kv_self::build_attn( // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens); struct ggml_tensor * v_cache_view = nullptr; @@ -2641,7 +3244,7 @@ ggml_tensor * llama_context_kv_self::build_attn( } }; - const auto & kq_mask = is_sliding ? inp_KQ_mask_swa_cnv : inp_KQ_mask_cnv; + const auto & kq_mask = is_sliding ? inp_self_kq_mask_swa_cnv : inp_self_kq_mask_cnv; const auto n_kv = worst_case ? kv_self.size : kv_self.n; @@ -2754,15 +3357,6 @@ ggml_tensor * llama_context_kv_self::build_attn( return cur; } -ggml_tensor * llama_context_kv_self::build_attn_soft_max( - ggml_context * ctx0, - ggml_tensor * kq, - float kq_scale) { - const auto & hparams = model.hparams; - - return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias); -} - void llama_context_kv_self::build_kv_self_shift( ggml_context * ctx0, ggml_cgraph * gf) { @@ -2775,7 +3369,7 @@ void llama_context_kv_self::build_kv_self_shift( //GGML_ASSERT(kv_self.size == n_ctx); - ggml_tensor * inp_k_shift = build_inp_k_shift(ctx0); + ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0); for (uint32_t il = 0; il < n_layer; ++il) { const int64_t n_head_kv = hparams.n_head_kv(il); @@ -2790,7 +3384,7 @@ void llama_context_kv_self::build_kv_self_shift( ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0); - ggml_tensor * cur = build_rope_shift(ctx0, k, inp_k_shift, rope_factors, kv_self.k_l[il]->buffer); + ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self.k_l[il]->buffer); ggml_build_forward_expand(gf, cur); } @@ -3082,7 +3676,7 @@ ggml_tensor * llama_context_kv_self::build_inp_embd_enc( return inp_embd_enc; } -ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross( +ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross( ggml_context * ctx0, int32_t n_tokens, bool worst_case) { @@ -3092,10 +3686,10 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross( // TODO: not sure if this is correct const int32_t n_outputs_enc = worst_case ? 
n_tokens : embd_enc.size() / n_embd; - inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - ggml_set_input(inp_KQ_mask_cross); + inp_kq_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp_kq_mask_cross); - return inp_KQ_mask_cross; + return inp_kq_mask_cross; } // @@ -3765,11 +4359,23 @@ int32_t llama_apply_adapter_cvec( // struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - return llama_kv_cache_view_init(*ctx->get_kv_self(), n_seq_max); + const auto * kv = ctx->get_kv_self(); + if (kv == nullptr) { + LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); + return {}; + } + + return llama_kv_cache_view_init(*kv, n_seq_max); } void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - llama_kv_cache_view_update(view, *ctx->get_kv_self()); + const auto * kv = ctx->get_kv_self(); + if (kv == nullptr) { + LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); + return; + } + + llama_kv_cache_view_update(view, *kv); } // @@ -3903,7 +4509,7 @@ void llama_kv_cache_defrag(llama_context * ctx) { } void llama_kv_self_defrag(llama_context * ctx) { - return llama_kv_cache_defrag(ctx->get_kv_self()); + llama_kv_cache_defrag(ctx->get_kv_self()); } // deprecated diff --git a/src/llama-context.h b/src/llama-context.h index 2b3d5f122..c605cec6f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,6 +20,7 @@ class llama_io_write_i; using llama_loras = std::unordered_map; +// basic transformer without KV cache struct llama_context : public llama_graph_i { llama_context( const llama_model & model, @@ -38,17 +39,19 @@ struct llama_context : public llama_graph_i { virtual uint32_t n_ctx_per_seq() const; virtual uint32_t n_batch() const; virtual uint32_t n_ubatch() const; - virtual uint32_t n_seq_max() const = 0; + virtual uint32_t n_seq_max() const; virtual uint32_t n_threads() const; virtual uint32_t n_threads_batch() const; virtual int32_t max_nodes() const; - virtual llama_kv_cache * get_kv_self() = 0; - virtual const llama_kv_cache * get_kv_self() const = 0; + // returns nullptr + virtual llama_kv_cache * get_kv_self(); + virtual const llama_kv_cache * get_kv_self() const; - virtual void kv_self_update() = 0; + // noop + virtual void kv_self_update(); virtual enum llama_pooling_type pooling_type() const; @@ -109,8 +112,6 @@ struct llama_context : public llama_graph_i { ggml_cgraph * gf, bool batched); - virtual void input_set(const llama_ubatch & ubatch); - // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. 
virtual int32_t output_reserve(int32_t n_outputs); @@ -128,7 +129,7 @@ struct llama_context : public llama_graph_i { // return positive int on warning // return negative int on error // - virtual int encode(llama_batch & inp_batch) = 0; + virtual int encode(llama_batch & inp_batch); // decode a batch of tokens by evaluating the transformer // in case of unsuccessful decoding (error or warning), @@ -142,7 +143,7 @@ struct llama_context : public llama_graph_i { // return positive int on warning // return negative int on error // - virtual int decode(llama_batch & inp_batch) = 0; + virtual int decode(llama_batch & inp_batch); // // graph build API (generic) @@ -204,6 +205,31 @@ struct llama_context : public llama_graph_i { ggml_context * ctx0, int32_t n_tokens); + virtual void build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa, + bool worst_case); + + virtual ggml_tensor * build_attn( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + int32_t n_tokens, + float kq_scale, + int il, + bool worst_case); + + // perf + + virtual llama_perf_context_data perf_get_data() const; + virtual void perf_reset(); + // state save/load virtual size_t state_get_size(); @@ -238,13 +264,7 @@ struct llama_context : public llama_graph_i { const llama_token * tokens, size_t n_token_count); - // perf - - virtual llama_perf_context_data perf_get_data() const; - virtual void perf_reset(); - protected: - // state save/load virtual size_t state_get_data(llama_io_write_i & io); @@ -253,14 +273,21 @@ protected: virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id); virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id); - // input tensors + // input - struct ggml_tensor * inp_tokens; // I32 [n_batch] - struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_out_ids; // I32 [n_outputs] - struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] - struct ggml_tensor * inp_cls; // I32 [n_batch] + virtual void input_set(const llama_ubatch & ubatch); + + // base input tensors + ggml_tensor * inp_tokens; // I32 [n_batch] + ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + ggml_tensor * inp_pos; // I32 [n_batch] + ggml_tensor * inp_out_ids; // I32 [n_outputs] + ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + ggml_tensor * inp_cls; // I32 [n_batch] + + // KQ mask input tensors + ggml_tensor * inp_kq_mask; // F32 [n_tokens, n_batch] + ggml_tensor * inp_kq_mask_cnv; // [n_tokens, n_batch] // members @@ -337,8 +364,6 @@ public: virtual ~llama_context_kv_self(); - virtual uint32_t n_seq_max() const override; - virtual llama_kv_cache * get_kv_self() override; virtual const llama_kv_cache * get_kv_self() const override; @@ -346,8 +371,6 @@ public: virtual ggml_cgraph * graph_init() override; - virtual void input_set(const llama_ubatch & ubatch) override; - virtual int encode(llama_batch & inp_batch) override; virtual int decode(llama_batch & inp_batch) override; @@ -357,17 +380,7 @@ public: // certain implementations could require a padding for the context size uint32_t get_ctx_padding(const llama_cparams & cparams) const; - // === KV cache === - - llama_kv_cache kv_self; - - ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - ggml_tensor * inp_KQ_mask_cnv; // [kv_size, n_batch] - ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] - ggml_tensor * inp_KQ_mask_swa_cnv; 
// [kv_size, n_batch] - ggml_tensor * inp_k_shift; // I32 [kv_size] - - virtual ggml_tensor * build_inp_k_shift(ggml_context * ctx0) override; + virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override; virtual void build_attn_inp( ggml_context * ctx0, @@ -389,11 +402,6 @@ public: int il, bool worst_case) override; - virtual ggml_tensor * build_attn_soft_max( - ggml_context * ctx0, - ggml_tensor * kq, - float kq_scale) override; - virtual void build_kv_self_shift( ggml_context * ctx0, ggml_cgraph * gf) override; @@ -414,14 +422,14 @@ public: struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] - struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch] virtual ggml_tensor * build_inp_embd_enc( ggml_context * ctx0, int32_t n_tokens, bool worst_case) override; - virtual ggml_tensor * build_inp_KQ_mask_cross( + virtual ggml_tensor * build_inp_kq_mask_cross( ggml_context * ctx0, int32_t n_tokens, bool worst_case) override; @@ -432,6 +440,16 @@ protected: virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; + + virtual void input_set(const llama_ubatch & ubatch) override; + + llama_kv_cache kv_self; + + ggml_tensor * inp_self_kq_mask; // F32 [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_cnv; // [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_swa; // F32 [kv_size, n_batch] + ggml_tensor * inp_self_kq_mask_swa_cnv; // [kv_size, n_batch] + ggml_tensor * inp_self_k_shift; // I32 [kv_size] }; // a recurrent transformer (ie.e RWKV, Mamba) @@ -447,8 +465,6 @@ public: virtual ggml_cgraph * graph_init() override; - virtual void input_set(const llama_ubatch & ubatch) override; - virtual ggml_tensor * build_inp_s_copy( ggml_context * ctx0, bool worst_case) override; @@ -506,6 +522,8 @@ public: bool worst_case) override; protected: + virtual void input_set(const llama_ubatch & ubatch) override; + struct ggml_tensor * inp_s_copy; // I32 [kv_size] struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 17605e74c..d9d4e00e9 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2,6 +2,84 @@ #include "llama-impl.h" +ggml_tensor * llama_graph_i::build_attn( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + int32_t n_tokens, + float kq_scale, + int il, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(wo); + GGML_UNUSED(wo_b); + GGML_UNUSED(q_cur); + GGML_UNUSED(k_cur); + GGML_UNUSED(v_cur); + GGML_UNUSED(n_tokens); + GGML_UNUSED(kq_scale); + GGML_UNUSED(il); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + +void llama_graph_i::build_kv_self_shift( + ggml_context * ctx0, + ggml_cgraph * gf) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); +} + +void llama_graph_i::build_kv_self_defrag( + ggml_context * ctx0, + ggml_cgraph * gf) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); +} + +ggml_tensor * llama_graph_i::build_inp_self_k_shift( + ggml_context * ctx0) { + GGML_UNUSED(ctx0); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + 
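These logging defaults replace what were pure-virtual methods on llama_graph_i, so a context type that has no KV cache can share the graph-build interface and fall back to a stub for the hooks it does not support. A minimal, self-contained sketch of that pattern follows; the names graph_iface, graph_no_kv and build_k_shift are illustrative only and are not part of llama.cpp:

    #include <cstdio>

    // base interface: optional hooks log and return a harmless default
    // instead of being declared pure virtual
    struct graph_iface {
        virtual ~graph_iface() = default;

        virtual void * build_k_shift() {
            std::fprintf(stderr, "%s: not implemented\n", __func__);
            return nullptr; // callers must tolerate a null result
        }
    };

    // a cache-less implementation compiles without overriding build_k_shift();
    // calling it simply logs and yields nullptr
    struct graph_no_kv : graph_iface {};

    int main() {
        graph_no_kv g;
        return g.build_k_shift() == nullptr ? 0 : 1;
    }

Compared with the previous pure-virtual declarations, the trade-off is that a missing override now surfaces as a runtime log line rather than a compile-time error.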
+ggml_tensor * llama_graph_i::build_inp_embd_enc( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(n_tokens); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + +ggml_tensor * llama_graph_i::build_inp_kq_mask_cross( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) { + GGML_UNUSED(ctx0); + GGML_UNUSED(n_tokens); + GGML_UNUSED(worst_case); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + ggml_tensor * llama_graph_i::build_inp_s_copy ( ggml_context * ctx0, bool worst_case) { diff --git a/src/llama-graph.h b/src/llama-graph.h index b64e0f5f4..8d237431e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -99,34 +99,29 @@ public: int32_t n_tokens, float kq_scale, int il, - bool worst_case) = 0; - - virtual ggml_tensor * build_attn_soft_max( - ggml_context * ctx0, - ggml_tensor * kq, - float kq_scale) = 0; + bool worst_case); virtual void build_kv_self_shift( ggml_context * ctx0, - ggml_cgraph * gf) = 0; + ggml_cgraph * gf); // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache virtual void build_kv_self_defrag( ggml_context * ctx0, - ggml_cgraph * gf) = 0; + ggml_cgraph * gf); - virtual ggml_tensor * build_inp_k_shift( - ggml_context * ctx0) = 0; + virtual ggml_tensor * build_inp_self_k_shift( + ggml_context * ctx0); virtual ggml_tensor * build_inp_embd_enc( ggml_context * ctx0, int32_t n_tokens, - bool worst_case) = 0; + bool worst_case); - virtual ggml_tensor * build_inp_KQ_mask_cross( + virtual ggml_tensor * build_inp_kq_mask_cross( ggml_context * ctx0, int32_t n_tokens, - bool worst_case) = 0; + bool worst_case); virtual ggml_tensor * build_inp_s_copy( ggml_context * ctx0, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 8a87f9129..3aec6495f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1079,14 +1079,26 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t cell_count) // int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) { + if (!kv) { + return 0; + } + return kv->n_tokens(); } int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) { + if (!kv) { + return 0; + } + return kv->used; } void llama_kv_cache_clear(llama_kv_cache * kv) { + if (!kv) { + return; + } + kv->clear(); } @@ -1095,6 +1107,10 @@ bool llama_kv_cache_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (!kv) { + return true; + } + return kv->seq_rm(seq_id, p0, p1); } @@ -1104,10 +1120,18 @@ void llama_kv_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (!kv) { + return; + } + kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) { + if (!kv) { + return; + } + kv->seq_keep(seq_id); } @@ -1117,6 +1141,10 @@ void llama_kv_cache_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { + if (!kv) { + return; + } + kv->seq_add(seq_id, p0, p1, delta); } @@ -1126,18 +1154,34 @@ void llama_kv_cache_seq_div( llama_pos p0, llama_pos p1, int d) { + if (!kv) { + return; + } + kv->seq_div(seq_id, p0, p1, d); } llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) { + if (!kv) { + return 0; + } + return kv->seq_pos_max(seq_id); } void llama_kv_cache_defrag(llama_kv_cache * kv) { + if (!kv) { + return; + } + kv->defrag(); } bool llama_kv_cache_can_shift(const llama_kv_cache * kv) { + if (!kv) { + return false; + } + return kv->can_shift; } diff --git 
a/src/llama-model.cpp b/src/llama-model.cpp index debbacbb6..a0a7816da 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3956,8 +3956,8 @@ struct llm_build_context { } // TODO: tmp - struct ggml_tensor * build_inp_KQ_mask_cross() { - ggml_tensor * cur = lgf->build_inp_KQ_mask_cross(ctx0, n_tokens, worst_case); + struct ggml_tensor * build_inp_kq_mask_cross() { + ggml_tensor * cur = lgf->build_inp_kq_mask_cross(ctx0, n_tokens, worst_case); cb(cur, "KQ_mask_cross", -1); return cur; @@ -5568,7 +5568,6 @@ struct llm_build_context { // self-attention if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - cb(Qcur, "Qcur", il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, @@ -5578,7 +5577,6 @@ struct llm_build_context { } Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - cb(Kcur, "Kcur", il); if (model.layers[il].attn_k_norm) { Kcur = build_norm(Kcur, @@ -5586,11 +5584,12 @@ struct llm_build_context { model.layers[il].attn_k_norm_b, LLM_NORM, il); } + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { // compute Q and K and RoPE them cur = build_lora_mm(model.layers[il].wqkv, cur); @@ -5600,10 +5599,6 @@ struct llm_build_context { Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - Qcur = ggml_rope_ext( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -5617,40 +5612,17 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - cb(kq, "kq", il); - - //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); - kq = lgf->build_attn_soft_max(ctx0, kq, 1.0f/sqrtf(float(n_embd_head))); - cb(kq, "kq_soft_max_ext", il); - - struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - cb(cur, "kqv_merged_cont", il); - - ggml_build_forward_expand(gf, cur); - - cur = build_lora_mm(model.layers[il].wo, cur); - if (model.layers[il].bo) { - cb(cur, "kqv_wo", il); - } - - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); - } + cur = build_attn(gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, 
Kcur, Vcur, n_tokens, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -9652,7 +9624,7 @@ struct llm_build_context { // struct ggml_tensor * pos_bucket_enc = build_pos_bucket(false); // // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - // struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false); + // struct ggml_tensor * KQ_mask_enc = build_inp_kq_mask(false); // for (int il = 0; il < n_layer; ++il) { // struct ggml_tensor * inpSA = inpL; @@ -9781,8 +9753,8 @@ struct llm_build_context { // struct ggml_tensor * embd_enc = build_inp_embd_enc(); // struct ggml_tensor * pos_bucket_dec = build_pos_bucket(true); - // struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask(); - // struct ggml_tensor * KQ_mask_cross = build_inp_KQ_mask_cross(); + // struct ggml_tensor * KQ_mask_dec = build_inp_kq_mask(); + // struct ggml_tensor * KQ_mask_cross = build_inp_kq_mask_cross(); // for (int il = 0; il < n_layer; ++il) { // struct ggml_tensor * inpSA = inpL; diff --git a/src/llama.cpp b/src/llama.cpp index 3db164477..9bacc9e9b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -328,6 +328,11 @@ struct llama_context * llama_init_from_model( try { // TODO: make static method of llama_context switch (model->arch) { + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + ctx = new llama_context(*model, params); + break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_MAMBA:
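With the cases added above, encoder-only BERT-style models are served by the cache-less llama_context, and the KV-cache entry points touched earlier in this patch (llama_kv_cache_view_init, llama_kv_self_defrag, the null-checked llama_kv_cache_* wrappers) warn or no-op instead of dereferencing a missing cache. A hedged usage sketch follows: llama_init_from_model, llama_kv_cache_view_init and llama_kv_self_defrag appear in this patch; the loader calls are the llama.h API of this revision as best understood, and the model file name is an assumption:

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        // assumed file name for an embedding-only (BERT-style) model
        llama_model * model = llama_model_load_from_file("bert-embedding.gguf", mparams);
        if (!model) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx = llama_init_from_model(model, cparams);
        if (!ctx) {
            return 1;
        }

        // for a context without a KV cache this logs a warning and returns an empty view
        llama_kv_cache_view view = llama_kv_cache_view_init(ctx, /*n_seq_max=*/1);
        (void) view;

        // no-op when the context has no KV cache
        llama_kv_self_defrag(ctx);

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }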