diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 820d5128e..159b1307a 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -14,6 +14,12 @@
 #include <thread>
 #include <unordered_map>
 
+// Quantization types. Changes to this struct must be replicated in quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -48,12 +54,6 @@ struct quantize_state_impl {
         {}
 };
 
-// changes to this struct must be replicated in quantize.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static void llama_tensor_dequantize_impl(
     ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
@@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
         // unless the user specifies a type
         if (params->tensor_types) {
             const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+            const std::string tensor_name(tensor->name);
             for (const auto & [tname, qtype] : tensor_types) {
-                if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
-                    if (qtype != new_type) {
-                        LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
+                    if (qtype != new_type) {
+                        LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
+                        new_type = qtype;
+                        break; // if two or more types are specified for the tensor, first match wins
                     }
-                    new_type = qtype;
-                    break;
                 }
             }
         }
     }
+
     if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
         new_type = params->token_embedding_type;
     }
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 0355311dc..3f54af7c5 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -57,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
     { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
 };
 
+// Quantization types. Changes to this struct must be replicated in llama-quant.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE      = "quantize.imatrix.file";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET   = "quantize.imatrix.dataset";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
@@ -244,56 +250,10 @@ static ggml_type parse_ggml_type(const char * arg) {
             return type;
         }
     }
-    fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
+    fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
     return GGML_TYPE_COUNT;
 }
 
-// Allowed tensors for arbitrary quantization with --tensor-type option
-static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
-    "attn_k",
-    "attn_kv_a_mqa",
-    "attn_kv_b",
-    "attn_o",
-    "attn_output",
-    "attn_q",
-    "attn_q_a",
-    "attn_q_b",
-    "attn_qkv",
-    "attn_v",
-    "channel_mix_key",
-    "channel_mix_receptance",
-    "channel_mix_value",
-    "cls",
-    "cls.output",
-    "cross_attn_k",
-    "cross_attn_o",
-    "cross_attn_q",
-    "cross_attn_v",
-    "ffn_act",
-    "ffn_down",
-    "ffn_down_exps",
-    "ffn_down_shexp",
-    "ffn_gate",
-    "ffn_gate_exps",
-    "ffn_gate_shexp",
-    "ffn_up",
-    "ffn_up_exps",
-    "ffn_up_shexp",
-    "ssm_in",
-    "ssm_out",
-    "time_mix_gate",
-    "time_mix_key",
-    "time_mix_output",
-    "time_mix_receptance",
-    "time_mix_value",
-};
-
-// changes to this struct must be replicated in llama-quant.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
     const char * sep = strchr(data, '=');
     if (sep == nullptr) {
@@ -306,7 +266,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
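// --- Reviewer sketch (not part of the patch) -------------------------------
// A minimal, self-contained illustration of the "first match wins" override
// rule that the llama-quant.cpp hunk above introduces. The resolve_type()
// helper and the std::string stand-in for ggml_type are hypothetical; only
// the loop body mirrors the patched code.
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

struct tensor_quantization {
    std::string name;  // regex pattern matched against the tensor name
    std::string quant; // stand-in for ggml_type in this sketch
};

static std::string resolve_type(const std::string & tensor_name,
                                const std::vector<tensor_quantization> & overrides,
                                std::string new_type) {
    for (const auto & [tname, qtype] : overrides) {
        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
            if (qtype != new_type) {
                new_type = qtype;
                break; // first matching pattern wins; later matches are ignored
            }
        }
    }
    return new_type;
}

int main() {
    // Both patterns match "blk.0.ffn_down.weight"; only the first one applies,
    // so the tensor resolves to q6_k rather than q4_k.
    const std::vector<tensor_quantization> overrides = {
        { "ffn_down", "q6_k" },
        { "ffn_",     "q4_k" },
    };
    printf("%s\n", resolve_type("blk.0.ffn_down.weight", overrides, "q4_0").c_str());
    return 0;
}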