diff --git a/common/common.cpp b/common/common.cpp index c2c94e7ae..e4e71ad13 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1290,6 +1290,9 @@ std::vector common_tokenize( int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + if (n_tokens == std::numeric_limits::min()) { + throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit"); + } if (n_tokens < 0) { result.resize(-n_tokens); int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); diff --git a/include/llama.h b/include/llama.h index 3475d5965..b04720bee 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1088,6 +1088,7 @@ extern "C" { /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @return Returns the number of tokens on success, no more than n_tokens_max /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit) /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so. /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// as plaintext. Does not insert a leading space. diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 4ab120d9b..4aaf4c825 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -3074,6 +3074,11 @@ int32_t llama_vocab::tokenize( bool add_special, bool parse_special) const { auto res = tokenize(std::string(text, text_len), add_special, parse_special); + if (res.size() >= static_cast(std::numeric_limits::max())) { + LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size()); + return std::numeric_limits::min(); + } + if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size());