diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index b508a4f8d..0e0af806d 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1655,6 +1655,168 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
     return llama_context::graph_init();
 }
 
+int llama_context_kv_self::encode(llama_batch & inp_batch) {
+    is_encoding = true;
+
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+    const int32_t n_tokens = batch.n_tokens;
+
+    const auto & hparams = model.hparams;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+
+    n_queued_tokens += n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;
+
+    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+
+    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
+
+    // reserve output buffer
+    if (output_reserve(n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+        return -2;
+    };
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        output_ids[i] = i;
+    }
+
+    inp_embd_enc = NULL;
+    n_outputs = n_tokens;
+
+    //batch_manager->prepare(ubatch);
+
+    // TODO: do reserve
+    GGML_ASSERT(need_reserve == false);
+
+    ggml_backend_sched_reset(sched.get());
+    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
+
+    auto * gf = res.gf;
+
+    ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+    input_set(ubatch);
+
+    const auto compute_status = graph_compute(gf, n_tokens > 1);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
+
+    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
+
+    // extract embeddings
+    if (t_embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+        GGML_ASSERT(backend_embd != nullptr);
+
+        if (llama_model_has_decoder(&model)) {
+            embd_enc.resize(n_tokens*n_embd);
+            float * embd_out = embd_enc.data();
+
+            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
+            // remember the sequence ids used during the encoding - needed for cross attention later
+            seq_ids_enc.resize(n_tokens);
+            for (int32_t i = 0; i < n_tokens; i++) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                    llama_seq_id seq_id = ubatch.seq_id[i][s];
+                    seq_ids_enc[i].insert(seq_id);
+                }
+            }
+        } else {
+            GGML_ASSERT(embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd;
+
+                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = embd_seq;
+                        embd_seq_out.clear();
+
+                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
+                        for (int32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                        //       wait for an encoder model that requires this pooling type in order to test it
+                        //       https://github.com/ggerganov/llama.cpp/pull/9510
+                        GGML_ABORT("RANK pooling not implemented yet");
+                    }
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+    }
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
+}
+
 int llama_context_kv_self::decode(llama_batch & inp_batch) {
     is_encoding = false;
 
@@ -2020,168 +2182,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context_kv_self::encode(llama_batch & inp_batch) {
-    is_encoding = true;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const int32_t n_tokens = batch.n_tokens;
-
-    const auto & hparams = model.hparams;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (int32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
-
-    if (t_compute_start_us == 0) {
-        t_compute_start_us = ggml_time_us();
-    }
-
-    n_queued_tokens += n_tokens;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
-    // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
-        return -2;
-    };
-
-    for (int32_t i = 0; i < n_tokens; ++i) {
-        output_ids[i] = i;
-    }
-
-    inp_embd_enc = NULL;
-    n_outputs = n_tokens;
-
-    //batch_manager->prepare(ubatch);
-
-    // TODO: do reserve
-    GGML_ASSERT(need_reserve == false);
-
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
-    auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
-
-    auto * gf = res.gf;
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    input_set(ubatch);
-
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
-    }
-
-    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
-
-    // extract embeddings
-    if (t_embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
-        GGML_ASSERT(backend_embd != nullptr);
-
-        if (llama_model_has_decoder(&model)) {
-            embd_enc.resize(n_tokens*n_embd);
-            float * embd_out = embd_enc.data();
-
-            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-            // remember the sequence ids used during the encoding - needed for cross attention later
-            seq_ids_enc.resize(n_tokens);
-            for (int32_t i = 0; i < n_tokens; i++) {
-                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = ubatch.seq_id[i][s];
-                    seq_ids_enc[i].insert(seq_id);
-                }
-            }
-        } else {
-            GGML_ASSERT(embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd;
-
-                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings
-                        auto & embd_seq_out = embd_seq;
-                        embd_seq_out.clear();
-
-                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-                        for (int32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                        //       wait for an encoder model that requires this pooling type in order to test it
-                        //       https://github.com/ggerganov/llama.cpp/pull/9510
-                        GGML_ABORT("RANK pooling not implemented yet");
-                    }
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
-    return 0;
-}
-
 llama_pos llama_context_kv_self::pos_max() const {
     return kv_self.pos_max();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index e3ab12e59..9f6abfc82 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -116,6 +116,17 @@ struct llama_context : public llama_graph_i {
     // TODO: maybe remove this
     virtual void output_reorder();
 
+    // encode a batch of tokens by evaluating the encoder part of the transformer
+    //
+    //   - lctx:      llama context
+    //   - batch:     batch to evaluate
+    //
+    //   return 0 on success
+    //   return positive int on warning
+    //   return negative int on error
+    //
+    virtual int encode(llama_batch & inp_batch) = 0;
+
     // decode a batch of tokens by evaluating the transformer
     // in case of unsuccessful decoding (error or warning),
     // the kv_cache state will be returned to its original state
@@ -130,17 +141,6 @@ struct llama_context : public llama_graph_i {
     //
     virtual int decode(llama_batch & inp_batch) = 0;
 
-    // encode a batch of tokens by evaluating the encoder part of the transformer
-    //
-    //   - lctx:      llama context
-    //   - batch:     batch to evaluate
-    //
-    //   return 0 on success
-    //   return positive int on warning
-    //   return negative int on error
-    //
-    virtual int encode(llama_batch & inp_batch) = 0;
-
     //
     // graph build API (generic)
     //
@@ -336,8 +336,8 @@ public:
 
     virtual void input_set(const llama_ubatch & ubatch) override;
 
-    virtual int decode(llama_batch & inp_batch) override;
     virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
 
     // max token position across all sequences in the current context
     llama_pos pos_max() const;
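
For reference, a minimal sketch of how a caller might exercise the relocated encode() path. It goes through the public llama.h wrappers (llama_encode, llama_get_embeddings_seq, llama_batch_init), which are assumed here to forward to llama_context_kv_self::encode(); the return-code convention follows the header comment above (0 on success, positive on warning, negative on error), and the per-sequence embedding lookup corresponds to the LLAMA_POOLING_TYPE_MEAN/CLS/LAST branch of the pooling switch. This is a hedged sketch, not part of the patch.

    // Sketch only: assumes a context created with embeddings enabled and a pooling type
    // such as LLAMA_POOLING_TYPE_MEAN, so that encode() fills the per-sequence embeddings.
    #include "llama.h"

    #include <cstdio>
    #include <vector>

    static int encode_and_print(llama_context * ctx, const std::vector<llama_token> & tokens) {
        llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd*/ 0, /*n_seq_max*/ 1);

        for (size_t i = 0; i < tokens.size(); ++i) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = (llama_pos) i;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = true; // the encoder outputs all tokens anyway (logits_all)
        }
        batch.n_tokens = (int32_t) tokens.size();

        // 0 = success, positive = warning, negative = error (see the header comment)
        const int ret = llama_encode(ctx, batch);
        if (ret != 0) {
            fprintf(stderr, "llama_encode failed with status %d\n", ret);
            llama_batch_free(batch);
            return ret;
        }

        // with MEAN/CLS/LAST pooling the per-sequence embedding is retrieved by sequence id
        const int     n_embd = llama_n_embd(llama_get_model(ctx));
        const float * embd   = llama_get_embeddings_seq(ctx, 0);
        if (embd) {
            for (int i = 0; i < n_embd && i < 8; ++i) {
                printf("%s%f", i == 0 ? "embd: " : " ", embd[i]);
            }
            printf("\n");
        }

        llama_batch_free(batch);
        return 0;
    }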
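
The llama_model_has_decoder() branch above (embd_enc plus the seq_ids_enc bookkeeping) serves encoder-decoder models such as T5: the prompt is encoded once, and generation then runs through decode() with cross attention over the stored encoder output. Below is a hedged sketch of that flow using the public llama.h API (llama_model_has_encoder, llama_model_decoder_start_token, llama_token_bos), assumed here to route to the encode()/decode() methods in this diff.

    // Sketch only: encode the source sequence once, then feed the decoder start token to decode().
    #include "llama.h"

    #include <cstdio>
    #include <vector>

    static int encode_then_decode(llama_context * ctx, const std::vector<llama_token> & prompt) {
        if (prompt.empty()) {
            return -1;
        }

        const llama_model * model = llama_get_model(ctx);

        llama_batch batch = llama_batch_init((int32_t) prompt.size(), 0, 1);
        for (size_t i = 0; i < prompt.size(); ++i) {
            batch.token   [i]    = prompt[i];
            batch.pos     [i]    = (llama_pos) i;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = false;
        }
        batch.n_tokens = (int32_t) prompt.size();

        if (llama_model_has_encoder(model)) {
            if (llama_encode(ctx, batch) != 0) { // stores the encoder output and sequence ids for cross attention
                fprintf(stderr, "%s: encode failed\n", __func__);
                llama_batch_free(batch);
                return -1;
            }
        }

        // seed the decoder with its start token, falling back to BOS if the model has none
        llama_token dec_start = llama_model_decoder_start_token(model);
        if (dec_start == -1) {
            dec_start = llama_token_bos(model);
        }

        // reuse the batch (capacity >= 1) for the first decoder step
        batch.n_tokens      = 1;
        batch.token   [0]   = dec_start;
        batch.pos     [0]   = 0;
        batch.n_seq_id[0]   = 1;
        batch.seq_id  [0][0] = 0;
        batch.logits  [0]   = true;

        const int ret = llama_decode(ctx, batch); // same 0 / positive / negative convention
        llama_batch_free(batch);
        return ret;
    }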