context : minor cleanup

ggml-ci
Georgi Gerganov
2025-02-13 12:50:53 +02:00
parent f7c7757bab
commit e08f38df69


@@ -10,30 +10,6 @@
#include <stdexcept>
#include <cinttypes>
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    // TODO move to hparams if a T5 variant appears that uses a different value
    const int64_t max_distance = 128;
    if (bidirectional) {
        n_buckets >>= 1;
    }
    const int64_t max_exact = n_buckets >> 1;
    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}
//
// llama_context
//
@@ -346,6 +322,7 @@ public:
        return size_written;
    }
private:
    size_t size_written = 0;
};
@@ -378,6 +355,7 @@ public:
        return size_written;
    }
private:
    uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_written = 0;
@@ -406,6 +384,7 @@ public:
        return size_read;
    }
private:
    const uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_read = 0;
@@ -430,6 +409,7 @@ public:
        return size_written;
    }
private:
    llama_file * file;
    size_t size_written = 0;
    std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ public:
        return size_read;
    }
private:
    llama_file * file;
    size_t size_read = 0;
    std::vector<uint8_t> temp_buffer;
@@ -2132,6 +2113,30 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
        GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
            // TODO move to hparams if a T5 variant appears that uses a different value
            const int64_t max_distance = 128;
            if (bidirectional) {
                n_buckets >>= 1;
            }
            const int64_t max_exact = n_buckets >> 1;
            int32_t relative_position = x - y;
            int32_t relative_bucket = 0;
            if (bidirectional) {
                relative_bucket += (relative_position > 0) * n_buckets;
                relative_position = abs(relative_position);
            } else {
                relative_position = -std::min<int32_t>(relative_position, 0);
            }
            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
            return relative_bucket;
        };
        int32_t * data = (int32_t *) inp_pos_bucket->data;
        if (!is_encoding) {
@@ -2139,7 +2144,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
            for (int h = 0; h < 1; ++h) {
                for (int j = 0; j < n_tokens; ++j) {
                    for (int i = 0; i < n_kv; ++i) {
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                    }
                }
            }
@@ -2147,7 +2152,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
            for (int h = 0; h < 1; ++h) {
                for (int j = 0; j < n_tokens; ++j) {
                    for (int i = 0; i < n_tokens; ++i) {
                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                    }
                }
            }
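
For reference, here is a minimal standalone sketch of the bucketing scheme the lambda above implements (the T5-style relative attention bias): offsets smaller than max_exact each map to their own bucket, larger offsets are binned logarithmically up to max_distance, and in the bidirectional (encoder) case half of the buckets are reserved for positive offsets. It is an illustration only, not part of the commit; the n_buckets value of 32 and the sample positions are assumptions, whereas in the actual code the bucket count comes from hparams.n_rel_attn_bkts.

// standalone illustration (not part of the commit): the same bucketing, compiled on its own
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128; // same hard-coded value as in the commit

    if (bidirectional) {
        n_buckets >>= 1; // half of the buckets are reserved for positive offsets
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        relative_bucket  += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        // causal case: only keys at or before the query position are distinguished
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    // offsets below max_exact map 1:1 to buckets, larger offsets are spaced logarithmically
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, (int32_t) n_buckets - 1);

    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}

int main() {
    const uint64_t n_buckets = 32; // hypothetical value; the real one is hparams.n_rel_attn_bkts

    // causal (decoder) case: query at position 100, a few key positions from the KV cache
    const int32_t key_positions[] = {99, 90, 50, 0};
    for (int32_t key_pos : key_positions) {
        printf("key %3d vs query 100 -> bucket %d\n", key_pos, relative_position_bucket(key_pos, 100, n_buckets, /*bidirectional=*/false));
    }

    return 0;
}

With these inputs the sketch should print buckets 1, 10, 24 and 30: nearby keys keep distinct buckets while distant ones are grouped ever more coarsely.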