convert : allow using lazy remote tensors

It's a bit slow for now since everything is blocking and single-threaded.
This commit is contained in:
Francis Couture-Harpin
2025-04-08 10:26:24 -04:00
parent 08ecbbe398
commit 3a3682de0b
2 changed files with 51 additions and 18 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
import json
@@ -71,6 +72,20 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
@dataclass
class RemoteTensor:
dtype: str
shape: tuple[int, ...]
offset_start: int
size: int
url: str
def data(self) -> bytes:
# TODO: handle request errors (maybe with limited retries?)
data = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)
return data
class SafetensorRemote:
"""
Uility class to handle remote safetensor files.
@@ -94,7 +109,7 @@ class SafetensorRemote:
ALIGNMENT = 8 # bytes
@classmethod
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a Hugging Face model repository.
@@ -105,10 +120,7 @@ class SafetensorRemote:
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
if is_single_file:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
for key, val in cls.get_list_tensors(url).items():
tensors[key] = (*val, url) # populate the url
return tensors
return cls.get_list_tensors(url)
# case 2: model has multiple files
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
@@ -124,17 +136,17 @@ class SafetensorRemote:
all_files = list(set(weight_map.values()))
all_files.sort() # make sure we load shard files in order
# get the list of tensors
tensors = {}
tensors: dict[str, RemoteTensor] = {}
for file in all_files:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
for key, val in cls.get_list_tensors(url).items():
tensors[key] = (*val, url) # populate the url
tensors[key] = val
return tensors
raise ValueError(f"Model {model_id} does not have any safetensor files")
@classmethod
def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a remote safetensor file.
@@ -142,7 +154,7 @@ class SafetensorRemote:
Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
"""
metadata, data_start_offset = cls.get_metadata(url)
res: dict[str, tuple[str, list[int], int, int]] = {}
res: dict[str, RemoteTensor] = {}
for name, meta in metadata.items():
if name == "__metadata__":
@@ -155,7 +167,7 @@ class SafetensorRemote:
offset_start_relative, offset_end_relative = meta["data_offsets"]
size = offset_end_relative - offset_start_relative
offset_start = data_start_offset + offset_start_relative
res[name] = (dtype, shape, offset_start, size)
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")