quantize : improve tensor-type pattern matching (#13033)

Ed Addario
2025-05-13 18:12:31 +01:00
committed by GitHub
parent 71bdbdb587
commit e5c834f718
2 changed files with 26 additions and 87 deletions
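Note: every hunk shown below is in the quantize tool's quantize.cpp; the matching logic it now relies on lives in llama-quant.cpp. The change drops the tool-side whitelist of tensor names and instead lets the library treat each --tensor-type argument as a pattern to match against the model's actual tensor names. For illustration (the TENSOR=TYPE syntax is the existing convention; the layer-range pattern is a hypothetical example, not taken from this diff), a plain suffix such as --tensor-type attn_v=q8_0 still works, and a pattern such as --tensor-type "blk\.[0-2]\.attn_v=q8_0" can target individual layers.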


@@ -57,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
     { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
 };
 
+// Quantization types. Changes to this struct must be replicated in llama-quant.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
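Moving the struct above parse_ggml_type keeps the tool's declarations together and leaves llama-quant.cpp as the single consumer of the parsed entries. A minimal sketch of how the overrides are expected to reach the library, assuming the opaque tensor_types pointer that llama_model_quantize_params gained together with the --tensor-type option (illustrative, not part of this diff):

    #include <vector>
    #include "llama.h"

    // Illustrative only: hand the entries built by parse_tensor_type() to the
    // library; llama-quant.cpp casts the opaque pointer back to a
    // std::vector<tensor_quantization> when picking each tensor's type.
    static bool quantize_with_overrides(const char * fname_inp, const char * fname_out,
                                        std::vector<tensor_quantization> & overrides) {
        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; // base type for unmatched tensors
        if (!overrides.empty()) {
            params.tensor_types = &overrides;     // consumed in llama-quant.cpp
        }
        return llama_model_quantize(fname_inp, fname_out, &params) == 0;
    }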
@@ -244,56 +250,10 @@ static ggml_type parse_ggml_type(const char * arg) {
             return type;
         }
     }
-    fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
+    fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
     return GGML_TYPE_COUNT;
 }
 
-// Allowed tensors for arbitrary quantization with --tensor-type option
-static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
-    "attn_k",
-    "attn_kv_a_mqa",
-    "attn_kv_b",
-    "attn_o",
-    "attn_output",
-    "attn_q",
-    "attn_q_a",
-    "attn_q_b",
-    "attn_qkv",
-    "attn_v",
-    "channel_mix_key",
-    "channel_mix_receptance",
-    "channel_mix_value",
-    "cls",
-    "cls.output",
-    "cross_attn_k",
-    "cross_attn_o",
-    "cross_attn_q",
-    "cross_attn_v",
-    "ffn_act",
-    "ffn_down",
-    "ffn_down_exps",
-    "ffn_down_shexp",
-    "ffn_gate",
-    "ffn_gate_exps",
-    "ffn_gate_shexp",
-    "ffn_up",
-    "ffn_up_exps",
-    "ffn_up_shexp",
-    "ssm_in",
-    "ssm_out",
-    "time_mix_gate",
-    "time_mix_key",
-    "time_mix_output",
-    "time_mix_receptance",
-    "time_mix_value",
-};
-
-// changes to this struct must be replicated in llama-quant.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
     const char * sep = strchr(data, '=');
     if (sep == nullptr) {
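With ALLOWED_TENSOR_TYPE gone, the name half of each TENSOR=TYPE pair is no longer validated in the tool at all. A sketch of the library-side lookup this commit relies on, assuming std::regex-based matching of the user string against each real tensor name (a reconstruction for context, not part of this hunk):

    #include <regex>
    #include <string>
    #include <vector>

    // Assumes the tensor_quantization struct and ggml_type enum from above.
    // The user-supplied name acts as a regular expression searched within the
    // full tensor name (e.g. "blk.0.attn_v.weight"), so plain suffixes
    // ("attn_v") and per-layer patterns ("blk\.[0-2]\.attn_v") both match.
    static ggml_type resolve_tensor_type(const std::string & tensor_name,
                                         const std::vector<tensor_quantization> & overrides,
                                         ggml_type default_type) {
        for (const auto & tq : overrides) {
            if (std::regex_search(tensor_name, std::regex(tq.name))) {
                return tq.quant; // first matching pattern wins
            }
        }
        return default_type;
    }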
@@ -306,7 +266,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
         printf("\n%s: missing tensor name\n\n", __func__);
         return false;
     }
-
     if (const size_t qt_len = strlen(sep); qt_len == 1) {
         printf("\n%s: missing quantization type\n\n", __func__);
         return false;
@@ -315,37 +274,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
     std::string tn(data, tn_len);
     std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
     sep++;
-    const std::string qt(sep);
-
-    bool found = false;
-    for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
-        std::string tensor;
-        tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
-        // handle special case of cls.output
-        std::string cls_output = "cls.output";
-        if (tn.find(cls_output) != std::string::npos) {
-            tensor = "cls.output";
-        }
-        // check if an allowed tensor exists and it's at the end of the kv string
-        if (tensor == allowed) {
-            found = true;
-            break;
-        }
-    }
-    if (!found) {
-        printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
-        return false;
-    }
-
-    if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
-        printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
-        return false;
-    }
-
     tensor_quantization tqz;
     tqz.name = tn;
-    tqz.quant = parse_ggml_type(qt.c_str());
+    tqz.quant = parse_ggml_type(sep);
     tensor_type.emplace_back(std::move(tqz));
+    if (tqz.quant == GGML_TYPE_COUNT) {
+        printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
+        return false;
+    }
     return true;
 }
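The rewritten tail parses the quantization type once and validates it after the entry is recorded. Reading tqz.quant after std::move(tqz) is well defined here: the implicit move constructor moves the name string but merely copies the enum, and the error path prints sep rather than the moved-from string. A minimal usage sketch (hypothetical driver, not part of the tool):

    #include <cstdio>
    #include <vector>

    // Hypothetical driver: collect one override per command-line argument.
    int main(int argc, char ** argv) {
        std::vector<tensor_quantization> tensor_types;
        for (int i = 1; i < argc; ++i) {
            if (!parse_tensor_type(argv[i], tensor_types)) {
                return 1; // the parser already printed the reason
            }
        }
        // e.g. "attn_v=q8_0" yields { name = "attn_v", quant = GGML_TYPE_Q8_0 }
        printf("parsed %zu override(s)\n", tensor_types.size());
        return 0;
    }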