mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 04:35:05 +00:00
batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
This commit is contained in:
@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
std::vector<uint64_t> sum(n_tokens, 0);
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
uint32_t * data = (uint32_t *) cls->data;
|
||||
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
std::vector<int> last_pos(n_tokens, -1);
|
||||
std::vector<int> last_row(n_tokens, -1);
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
@ -283,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
||||
if (hparams.use_alibi) {
|
||||
@ -322,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
@ -377,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_enc; ++i) {
|
||||
float f = -INFINITY;
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
||||
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
||||
|
Reference in New Issue
Block a user