diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 797773193..8fcff0de7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3690,14 +3690,20 @@ class BertModel(TextModel):
         super().__init__(*args, **kwargs)
         self.vocab_size = None
 
+        if cls_out_labels := self.hparams.get("id2label"):
+            if len(cls_out_labels) == 2 and cls_out_labels.get(0) == "LABEL_0":
+                # Remove dummy labels added by AutoConfig (.get() so a mapping without an int key 0 cannot raise KeyError)
+                cls_out_labels = None
+        self.cls_out_labels = cls_out_labels
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
         self._try_set_pooling_type()
 
-        if cls_out_labels := self.hparams.get("id2label"):
+        if self.cls_out_labels:
             key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
-            self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
+            self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
@@ -3749,7 +3755,7 @@ class BertModel(TextModel):
         if name.startswith("cls.seq_relationship"):
             return []
 
-        if self.hparams.get("id2label"):
+        if self.cls_out_labels:
             # For BertForSequenceClassification (direct projection layer)
             if name == "classifier.weight":
                 name = "classifier.out_proj.weight"