mirror of https://github.com/ggml-org/llama.cpp.git
context : minor cleanup
ggml-ci
@@ -10,30 +10,6 @@
 #include <stdexcept>
 #include <cinttypes>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
 //
 // llama_context
 //
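This hunk deletes the file-scope helper; the same body reappears further down as a function-local lambda in `llama_context_kv_self::set_inputs`, its only call site. For reference, the T5-style relative-position bucketing it implements can be exercised standalone. A minimal sketch: the function body is taken from the hunk (lightly commented), while the `main` harness and the `llama_pos` typedef (an `int32_t` in llama.h) are supplied here to make it self-contained.

```cpp
// Standalone sketch of the removed helper. Small relative distances get
// their own bucket; distances past max_exact share log-spaced buckets up
// to max_distance. main() is illustrative only, not part of the commit.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

typedef int32_t llama_pos; // matches the typedef in llama.h

static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1; // half the buckets for each sign of the offset
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // with 32 buckets, bidirectional: 16 per direction, 8 exact + 8 log-spaced
    for (llama_pos d : {1, 4, 7, 8, 16, 40, 127, 500}) {
        printf("offset %4d -> bucket %d\n", d, llama_relative_position_bucket(d, 0, 32, /*bidirectional=*/true));
    }
    return 0;
}
```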
@@ -346,6 +322,7 @@ public:
         return size_written;
     }
 
+private:
     size_t size_written = 0;
 };
 
@@ -378,6 +355,7 @@ public:
         return size_written;
     }
 
+private:
     uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_written = 0;
@@ -406,6 +384,7 @@ public:
         return size_read;
     }
 
+private:
     const uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_read = 0;
@@ -430,6 +409,7 @@ public:
         return size_written;
     }
 
+private:
     llama_file * file;
     size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ public:
         return size_read;
     }
 
+private:
     llama_file * file;
     size_t size_read = 0;
     std::vector<uint8_t> temp_buffer;
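The five single-line hunks above touch the state-serialization I/O helpers; each keeps a running byte count (`size_written` or `size_read`) alongside its data members, exposed through a getter like the `return size_written;` visible in each hunk. A hedged sketch of the buffer-writer variant, reconstructed from the members shown here only; the real base interface, method signatures, and error handling are not part of this diff, so treat those details as assumptions:

```cpp
// Self-contained approximation of the buffer-writer pattern. Member names
// (ptr, buf_size, size_written) mirror the diff; everything else is guessed.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

struct io_write_buffer_sketch {
    io_write_buffer_sketch(uint8_t * p, size_t size) : ptr(p), buf_size(size) {}

    void write(const void * src, size_t size) {
        if (size > buf_size) {
            throw std::runtime_error("write beyond end of buffer"); // placeholder error handling
        }
        memcpy(ptr, src, size);
        ptr          += size;
        buf_size     -= size;
        size_written += size;
    }

    // the `return size_written;` lines in the hunks above belong to a getter like this
    size_t n_bytes() const {
        return size_written;
    }

private:
    uint8_t * ptr;
    size_t buf_size     = 0;
    size_t size_written = 0;
};
```

The first of these hunks shows a class carrying only `size_written`, which suggests a dummy writer that counts bytes without storing them, letting callers size a buffer before serializing for real.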
@@ -2132,6 +2113,30 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
         GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
+        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+            // TODO move to hparams if a T5 variant appears that uses a different value
+            const int64_t max_distance = 128;
+
+            if (bidirectional) {
+                n_buckets >>= 1;
+            }
+
+            const int64_t max_exact = n_buckets >> 1;
+
+            int32_t relative_position = x - y;
+            int32_t relative_bucket = 0;
+            if (bidirectional) {
+                relative_bucket += (relative_position > 0) * n_buckets;
+                relative_position = abs(relative_position);
+            } else {
+                relative_position = -std::min<int32_t>(relative_position, 0);
+            }
+            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+            return relative_bucket;
+        };
+
         int32_t * data = (int32_t *) inp_pos_bucket->data;
 
         if (!is_encoding) {
@@ -2139,7 +2144,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
             }
@@ -2147,7 +2152,7 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
            }
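Both call-site hunks fill the `inp_pos_bucket` input as a flattened row-major buffer: entry (j, i) is the bucket of key position i relative to query token j, indexed as `data[h*(n_kv*n_tokens) + j*n_kv + i]`, with `n_kv` replaced by `n_tokens` on the encoder side. A tiny sketch of that layout with made-up sizes and a placeholder bucket function:

```cpp
// Illustrates the indexing only: sizes, positions and bucket() are
// stand-ins, not values from llama.cpp.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv = 4, n_tokens = 2;                 // hypothetical sizes
    const int32_t kv_pos[]    = {0, 1, 2, 3};         // positions of cached keys
    const int32_t batch_pos[] = {2, 3};               // positions of the new tokens

    auto bucket = [](int32_t x, int32_t y) { return x - y; }; // placeholder for relative_position_bucket

    std::vector<int32_t> data(n_kv * n_tokens);
    for (int h = 0; h < 1; ++h) {                     // single "head" dimension, as in the diff
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                data[h*(n_kv*n_tokens) + j*n_kv + i] = bucket(kv_pos[i], batch_pos[j]);
            }
        }
    }

    for (int j = 0; j < n_tokens; ++j) {              // one row per query token
        for (int i = 0; i < n_kv; ++i) {
            printf("%3d ", data[j*n_kv + i]);
        }
        printf("\n");
    }
    return 0;
}
```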