mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 03:55:20 +00:00
llama : validate seq id batch input (#13809)
* llama : validate seq id batch input ggml-ci * cont : fix the fix ggml-ci
This commit is contained in:
@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
|
// TODO: move the validation to the llama_batch_allocr
|
||||||
if (batch.token) {
|
if (batch.token) {
|
||||||
for (int32_t i = 0; i < n_tokens; ++i) {
|
for (int32_t i = 0; i < n_tokens; ++i) {
|
||||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||||
|
throw -1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
|
// TODO: move the validation to the llama_batch_allocr
|
||||||
if (batch.token) {
|
if (batch.token) {
|
||||||
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
||||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||||
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
||||||
throw std::runtime_error("invalid token");
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user