diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6b2a11ad6..648a669b1 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -48,6 +48,7 @@ llama_context::llama_context(
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
@@ -2127,60 +2128,44 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
     }
 
     if (inp_kq_mask) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
-            // TODO: need to use the batch directly to construct the masks
-            GGML_ABORT("TODO");
+            const int64_t n_kv         = ubatch.n_tokens;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
 
-            //const int64_t n_kv         = ubatch.n_tokens;
-            //const int64_t n_tokens     = ubatch.n_tokens;
-            //const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            //const int64_t n_seqs       = ubatch.n_seqs;
+            GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
+            float * data = (float *) inp_kq_mask->data;
 
-            //float * data = nullptr;
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
 
-            //if (inp_kq_mask) {
-            //    GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
-            //    data = (float *) inp_kq_mask->data;
-            //}
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
 
-            //// For causal attention, use only the previous KV cells
-            //// of the correct sequence for each token of the ubatch.
-            //// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            //for (int h = 0; h < 1; ++h) {
-            //    for (int s = 0; s < n_seqs; ++s) {
-            //        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
 
-            //        for (int j = 0; j < n_seq_tokens; ++j) {
-            //            const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
+                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
+                                    if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
 
-            //            for (int i = 0; i < n_kv; ++i) {
-            //                float f;
-            //                if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-            //                    f = -INFINITY;
-            //                } else {
-            //                    if (hparams.use_alibi) {
-            //                        f = -std::abs(kv_self.cells[i].pos - pos);
-            //                    } else {
-            //                        f = 0.0f;
-            //                    }
-            //                }
-
-            //                if (data) {
-            //                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-            //                }
-            //            }
-            //        }
-            //    }
-
-            //    if (data) {
-            //        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            //            for (int j = 0; j < n_kv; ++j) {
-            //                data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            //            }
-            //        }
-            //    }
-            //}
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
+                    }
+                }
+            }
         } else {
             const int64_t n_tokens     = ubatch.n_tokens;
             const int64_t n_seq_tokens = ubatch.n_seq_tokens;
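
Note: the new causal branch builds the KQ mask directly from the ubatch (n_kv == n_tokens, no KV cache): entry (tj, ti) is 0.0f when token ti belongs to tj's sequence and `pos[ti] <= pos[tj]`, and -INFINITY otherwise; with ALiBi the visible entries carry a -|pos[ti] - pos[tj]| bias instead of 0. Below is a minimal standalone sketch of that rule for illustration. `toy_ubatch` and `fill_causal_mask` are hypothetical stand-ins, not llama.cpp API, and the toy stores sequence ids per token rather than per sequence group as `llama_ubatch` does:

```cpp
// Sketch of the causal mask rule from the diff above (assumed simplification:
// per-token seq ids instead of llama_ubatch's per-group layout).
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <limits>
#include <vector>

struct toy_ubatch {
    std::vector<int>              pos;    // position of each token in its sequence
    std::vector<std::vector<int>> seq_id; // sequence id(s) of each token
};

// mask[tj*n + ti] is the additive attention bias for query tj / key ti:
// 0.0f (or -|pos delta| with ALiBi) if ti is in tj's sequence and not in
// tj's future, -INFINITY otherwise
static void fill_causal_mask(const toy_ubatch & ub, bool use_alibi, std::vector<float> & mask) {
    const size_t n = ub.pos.size();

    mask.assign(n*n, -std::numeric_limits<float>::infinity());

    for (size_t tj = 0; tj < n; ++tj) {
        const int seq_id = ub.seq_id[tj][0];

        for (size_t ti = 0; ti < n; ++ti) {
            for (int s : ub.seq_id[ti]) {
                if (s == seq_id && ub.pos[ti] <= ub.pos[tj]) {
                    mask[tj*n + ti] = use_alibi ? -std::abs(ub.pos[ti] - ub.pos[tj]) : 0.0f;
                    break;
                }
            }
        }
    }
}

int main() {
    // two interleaved sequences: tokens 0,2 belong to seq 0 and tokens 1,3 to seq 1
    const toy_ubatch ub = { {0, 0, 1, 1}, {{0}, {1}, {0}, {1}} };

    std::vector<float> mask;
    fill_causal_mask(ub, /*use_alibi=*/false, mask);

    const size_t n = ub.pos.size();
    for (size_t tj = 0; tj < n; ++tj) {
        for (size_t ti = 0; ti < n; ++ti) {
            printf("%5s", std::isinf(mask[tj*n + ti]) ? "-inf" : "0");
        }
        printf("\n");
    }
}
```

For the two interleaved sequences above this prints a block-diagonal causal mask (each token sees only earlier-or-equal positions of its own sequence), which is the same visibility the `ubatch.seq_id`/`ubatch.pos` checks in the diff enforce without consulting `kv_self.cells`.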