mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
convert : workaround for AutoConfig dummy labels (#13881)
This commit is contained in:
@ -3690,14 +3690,20 @@ class BertModel(TextModel):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.vocab_size = None
|
self.vocab_size = None
|
||||||
|
|
||||||
|
if cls_out_labels := self.hparams.get("id2label"):
|
||||||
|
if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0":
|
||||||
|
# Remove dummy labels added by AutoConfig
|
||||||
|
cls_out_labels = None
|
||||||
|
self.cls_out_labels = cls_out_labels
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
super().set_gguf_parameters()
|
super().set_gguf_parameters()
|
||||||
self.gguf_writer.add_causal_attention(False)
|
self.gguf_writer.add_causal_attention(False)
|
||||||
self._try_set_pooling_type()
|
self._try_set_pooling_type()
|
||||||
|
|
||||||
if cls_out_labels := self.hparams.get("id2label"):
|
if self.cls_out_labels:
|
||||||
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
|
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
|
||||||
self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
|
self.gguf_writer.add_array(key_name, [v for k, v in sorted(self.cls_out_labels.items())])
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||||
@ -3749,7 +3755,7 @@ class BertModel(TextModel):
|
|||||||
if name.startswith("cls.seq_relationship"):
|
if name.startswith("cls.seq_relationship"):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if self.hparams.get("id2label"):
|
if self.cls_out_labels:
|
||||||
# For BertForSequenceClassification (direct projection layer)
|
# For BertForSequenceClassification (direct projection layer)
|
||||||
if name == "classifier.weight":
|
if name == "classifier.weight":
|
||||||
name = "classifier.out_proj.weight"
|
name = "classifier.out_proj.weight"
|
||||||
|
Reference in New Issue
Block a user