mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 05:33:37 -04:00
model : add Kimi-K2 support (#14654)
* Kimi-K2 conversion * add Kimi_K2 pre type * Kimi-K2 * Kimi-K2 unicode * Kimi-K2 * LLAMA_MAX_EXPERTS 384 * fix vocab iteration * regex space fix * add kimi-k2 to pre_computed_hashes * Updated with kimi-k2 get_vocab_base_pre hash * fix whitespaces * fix flake errors * remove more unicode.cpp whitespaces * change set_vocab() flow * add moonshotai-Kimi-K2.jinja to /models/templates/ * update moonshotai-Kimi-K2.jinja * add kimi-k2 chat template * add kimi-k2 * update NotImplementedError Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * except Exception Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * LLM_CHAT_TEMPLATE_KIMI_K2 if(add_ass){} --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -65,6 +65,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -188,6 +189,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
|
||||
return LLM_CHAT_TEMPLATE_KIMI_K2;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -680,6 +683,26 @@ int32_t llm_chat_apply_template(
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
|
||||
// moonshotai/Kimi-K2-Instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|im_system|>system<|im_middle|>";
|
||||
} else if (role == "user") {
|
||||
ss << "<|im_user|>user<|im_middle|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|im_assistant|>assistant<|im_middle|>";
|
||||
} else if (role == "tool") {
|
||||
ss << "<|im_system|>tool<|im_middle|>";
|
||||
}
|
||||
|
||||
ss << message->content << "<|im_end|>";
|
||||
|
||||
if (add_ass) {
|
||||
ss << "<|im_assistant|>assistant<|im_middle|>";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
@@ -45,6 +45,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@@ -6,7 +6,7 @@
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
|
||||
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
|
||||
|
||||
enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
|
@@ -405,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_KIMI_K2:
|
||||
regex_exprs = {
|
||||
// K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp
|
||||
// The custom handler implements all K2 patterns with proper Han character exclusion
|
||||
"\\p{Han}+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
|
||||
regex_exprs = {
|
||||
"\\p{N}+",
|
||||
@@ -1954,6 +1961,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "hunyuan") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "kimi-k2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
|
@@ -45,6 +45,7 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
|
||||
LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
207
src/unicode.cpp
207
src/unicode.cpp
@@ -557,6 +557,178 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// K2 system regex patterns (from tokenization_kimi.py):
|
||||
// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
||||
static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
bpe_offsets.reserve(offsets.size());
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||
};
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||
const uint32_t cpt = _get_cpt(pos);
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// Pattern 1: [\p{Han}]+ (Chinese characters)
|
||||
if (unicode_cpt_is_han(cpt)) {
|
||||
while (unicode_cpt_is_han(_get_cpt(pos))) {
|
||||
pos++;
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Pattern 2 & 3: Letter words excluding Han characters with optional contractions
|
||||
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
|
||||
// [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
|
||||
// Check if current char is a letter OR if current char could be a leading char and next char is a letter
|
||||
bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
|
||||
(!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
|
||||
_get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
|
||||
|
||||
if (is_letter_pattern) {
|
||||
// Handle optional leading non-letter/non-number character
|
||||
bool has_leading_char = false;
|
||||
if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
|
||||
has_leading_char = true;
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Match letter sequence (excluding Han characters)
|
||||
bool has_letters = false;
|
||||
while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
|
||||
has_letters = true;
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Only proceed if we found letters (after potentially skipping leading char)
|
||||
if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
|
||||
if (!has_letters) pos++; // consume the first letter if we didn't already
|
||||
|
||||
// Continue consuming letters
|
||||
while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
|
||||
if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
|
||||
uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
|
||||
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
|
||||
pos += 2;
|
||||
} else if (pos + 2 < offset_end) {
|
||||
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
|
||||
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'v' && cpt_next_next == 'e') ||
|
||||
(cpt_next == 'l' && cpt_next_next == 'l')) {
|
||||
pos += 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_add_token(pos);
|
||||
continue;
|
||||
} else if (has_leading_char) {
|
||||
// We consumed a leading char but found no letters, backtrack
|
||||
pos--;
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
|
||||
if (flags.is_number) {
|
||||
size_t ini = pos;
|
||||
while (_get_flags(pos).is_number) {
|
||||
if (++pos - ini >= 3) {
|
||||
_add_token(pos);
|
||||
ini = pos;
|
||||
}
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
|
||||
auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
|
||||
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
|
||||
pos += (cpt == ' ');
|
||||
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
|
||||
flags2 = _get_flags(++pos);
|
||||
}
|
||||
// Match optional [\r\n]*
|
||||
uint32_t cpt2 = _get_cpt(pos);
|
||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||
cpt2 = _get_cpt(++pos);
|
||||
}
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Count whitespace characters
|
||||
size_t num_whitespaces = 0;
|
||||
size_t last_end_r_or_n = 0;
|
||||
while (_get_flags(pos + num_whitespaces).is_whitespace) {
|
||||
uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
|
||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||
}
|
||||
num_whitespaces++;
|
||||
}
|
||||
|
||||
// Pattern 6: \s*[\r\n]+ (whitespace with newlines)
|
||||
if (last_end_r_or_n > 0) {
|
||||
pos = last_end_r_or_n;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Pattern 7: \s+(?!\S) (trailing whitespace)
|
||||
if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
|
||||
pos += num_whitespaces - 1;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Pattern 8: \s+ (general whitespace)
|
||||
if (num_whitespaces > 0) {
|
||||
pos += num_whitespaces;
|
||||
_add_token(pos);
|
||||
continue;
|
||||
}
|
||||
|
||||
// No matches - consume single character
|
||||
_add_token(++pos);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
@@ -567,6 +739,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
||||
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
|
||||
|
||||
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
|
||||
} else if (regex_expr == "\\p{Han}+") {
|
||||
// K2's first pattern - handle all K2 patterns together
|
||||
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
@@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) {
|
||||
return cpt; // Return the original code point if no lowercase mapping is found
|
||||
}
|
||||
|
||||
bool unicode_cpt_is_han(uint32_t cpt) {
|
||||
// Han character ranges (Chinese/CJK characters)
|
||||
// CJK Unified Ideographs (most common)
|
||||
if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true;
|
||||
|
||||
// CJK Extension A
|
||||
if (cpt >= 0x3400 && cpt <= 0x4DBF) return true;
|
||||
|
||||
// CJK Extension B
|
||||
if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true;
|
||||
|
||||
// CJK Extension C
|
||||
if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true;
|
||||
|
||||
// CJK Extension D
|
||||
if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true;
|
||||
|
||||
// CJK Extension E
|
||||
if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true;
|
||||
|
||||
// CJK Extension F
|
||||
if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true;
|
||||
|
||||
// CJK Compatibility Ideographs
|
||||
if (cpt >= 0xF900 && cpt <= 0xFAFF) return true;
|
||||
|
||||
// CJK Compatibility Ideographs Supplement
|
||||
if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
|
@@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
||||
|
||||
uint32_t unicode_tolower(uint32_t cpt);
|
||||
|
||||
bool unicode_cpt_is_han(uint32_t cpt);
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
|
||||
|
Reference in New Issue
Block a user