Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-27 12:05:03 +00:00)
convert : fix squeeze for ssm_conv tensors (#12573)
* convert : fix squeeze for ssm_conv tensors

* convert : match ssm_conv tensors by type

---------

Co-authored-by: Francis Couture-Harpin <git@compilade.net>
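For context, a Mamba ssm_conv weight is a depthwise 1D-conv kernel whose PyTorch shape carries a singleton channel dimension, and squeeze() drops that dimension before the tensor is written to GGUF. Below is a minimal standalone sketch of that reshaping; the sizes (d_inner = 8192, d_conv = 4) are example values chosen to match the shape comment in the diff, and the variable names are illustrative only, not taken from this commit.

import torch

# Hypothetical Mamba conv1d weight: (d_inner, channels=1, d_conv) = (8192, 1, 4).
# ggml lists dimensions in reverse order, so this appears as ne = [4, 1, 8192, 1].
w = torch.zeros(8192, 1, 4)

# squeeze() removes every size-1 dimension -> (8192, 4), i.e. ne = [4, 8192, 1, 1],
# which matches the 2D shape the converter writes for ssm_conv tensors.
w2d = w.squeeze()
print(tuple(w2d.shape))  # (8192, 4)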
@@ -3803,8 +3803,6 @@ class MambaModel(Model):
     _tok_embd = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
         output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
         tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)

@@ -3814,6 +3812,10 @@ class MambaModel(Model):
             logger.debug("A_log --> A ==> " + new_name)
             data_torch = -torch.exp(data_torch)

+        # [4 1 8192 1] -> [4 8192 1 1]
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+
         # assuming token_embd.weight is seen before output.weight
         if self._tok_embd is not None and new_name == output_name:
             if torch.equal(self._tok_embd, data_torch):