diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index bde665953..e234e3683 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -10,30 +10,6 @@
 #include
 #include
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
 //
 // llama_context
 //
@@ -346,6 +322,7 @@ public:
         return size_written;
     }
 
+private:
     size_t size_written = 0;
 };
 
@@ -378,6 +355,7 @@ public:
         return size_written;
     }
 
+private:
     uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_written = 0;
@@ -406,6 +384,7 @@ public:
         return size_read;
     }
 
+private:
     const uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_read = 0;
@@ -430,6 +409,7 @@ public:
         return size_written;
     }
 
+private:
     llama_file * file;
     size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ public:
         return size_read;
     }
 
+private:
     llama_file * file;
     size_t size_read = 0;
     std::vector<uint8_t> temp_buffer;
@@ -2132,6 +2113,30 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
         GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
+        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+            // TODO move to hparams if a T5 variant appears that uses a different value
+            const int64_t max_distance = 128;
+
+            if (bidirectional) {
+                n_buckets >>= 1;
+            }
+
+            const int64_t max_exact = n_buckets >> 1;
+
+            int32_t relative_position = x - y;
+            int32_t relative_bucket = 0;
+            if (bidirectional) {
+                relative_bucket += (relative_position > 0) * n_buckets;
+                relative_position = abs(relative_position);
+            } else {
+                relative_position = -std::min(relative_position, 0);
+            }
+            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+            return relative_bucket;
+        };
+
         int32_t * data = (int32_t *) inp_pos_bucket->data;
 
         if (!is_encoding) {
@@ -2139,7 +2144,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
             }
@@ -2147,7 +2152,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
            }
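The patch moves the file-scope `llama_relative_position_bucket` helper into a lambda local to `llama_context_kv_self::set_inputs`, its only caller. For reference, below is a minimal standalone sketch (not part of the patch) of the same T5-style bucketing formula, so its behaviour can be checked in isolation. The function and `main` driver here are illustrative only; it uses plain `int32_t` in place of `llama_pos` and assumes 32 buckets (a common T5 value), whereas the real code takes the count from `hparams.n_rel_attn_bkts`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Copy of the bucketing formula from the patch, with max_distance fixed at 128.
// Small offsets get one bucket each; larger offsets are binned logarithmically.
static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1; // half the buckets are reserved for "key after query"
    }

    const int64_t max_exact = n_buckets >> 1; // offsets below this map 1:1 to buckets

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min(relative_position, 0); // causal: only look back
    }
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // Compare encoder (bidirectional) and decoder (causal) bucketing for a few
    // key positions against a fixed query position, assuming 32 buckets.
    const uint64_t n_buckets  = 32;
    const int32_t  query_pos  = 100;
    const int32_t  key_positions[] = {0, 1, 4, 16, 64, 200};

    for (int32_t key_pos : key_positions) {
        printf("key %3d vs query %3d -> encoder bucket %2d, decoder bucket %2d\n",
               key_pos, query_pos,
               relative_position_bucket(key_pos, query_pos, n_buckets, true),
               relative_position_bucket(key_pos, query_pos, n_buckets, false));
    }
    return 0;
}
```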