gguf-py : use ThreadPoolExecutor when writing tensors
- gguf-py : handle (limited) retries for remote tensors
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -10,10 +10,10 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
-from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
 
 import numpy as np
 
@@ -62,20 +62,49 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
 @dataclass
-class ThreadedTensorWriteInfo:
+class _ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
     bar: Any | None  # optional tqdm progress bar
 
-    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
         # This is called from a thread pool,
         # and each thread should have its own file handle per output file
         # so that they can have different seek locations.
-        if self.filename not in open_files:
-            open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
 
         f.seek(self.offset)
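Note: the per-worker file-handle pattern above can be illustrated in isolation. The sketch below is not part of the commit (the path, payloads, and helper names are invented); it shows why giving each thread its own handle to the same pre-sized file makes concurrent seek-and-write safe, which is what _ThreadedOpenFiles and the executor initializer arrange:

    import threading
    from concurrent.futures import ThreadPoolExecutor
    from pathlib import Path

    local_data = threading.local()

    def _init_worker(local):
        # runs once in each worker thread, like _ThreadedOpenFiles.init_thread_local
        local.handles = {}

    def _write_at(path: Path, offset: int, payload: bytes):
        handles = local_data.handles
        if path not in handles:
            # each thread opens its own handle, so seek positions never race
            handles[path] = open(path, "r+b")
        f = handles[path]
        f.seek(offset)
        f.write(payload)
        f.flush()  # for the demo only; the real helper closes handles in __del__

    path = Path("example.bin")
    path.write_bytes(bytes(16))  # pre-size the file, as the GGUF writer does
    with ThreadPoolExecutor(max_workers=2, initializer=_init_worker,
                            initargs=(local_data,)) as ex:
        for i, chunk in enumerate([b"aaaa", b"bbbb", b"cccc", b"dddd"]):
            ex.submit(_write_at, path, 4 * i, chunk)
    print(path.read_bytes())  # b'aaaabbbbccccdddd'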
@@ -462,9 +491,6 @@ class GGUFWriter:
 
         if self.temp_file is None:
             bar = None
-            # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
-
             # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
 
@@ -476,6 +502,21 @@ class GGUFWriter:
 
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
+            # Allow opening the files only once per worker
+            local_data = threading.local()
+
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
                 # Fill the tensor queue with all the pending tensor writes
                 for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
                     offset = offsets[i]
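Note: the initializer/initargs pair runs once in each worker thread, and threading.local() gives every thread that touches local_data an independent attribute namespace. A minimal standalone illustration (not from the commit):

    import threading

    local_data = threading.local()

    def worker():
        local_data.value = "set by worker"  # visible only inside this thread
        print(local_data.value)

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    print(hasattr(local_data, "value"))  # False: the main thread never set it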
@@ -488,48 +529,31 @@ class GGUFWriter:
                     nbytes = ti.tensor.nbytes
                     offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
                     padding = offset - (start_offset + nbytes)
-                    tensor_queue.put(
-                        ThreadedTensorWriteInfo(
+                    futures.append(
+                        executor.submit(
+                            thread_write_tensor,
+                            _ThreadedTensorWriteInfo(
                                 filename=filename,
                                 offset=start_offset,
                                 post_pad=padding,
                                 tensor=ti.tensor,
                                 bar=bar,
+                            ),
                         )
                     )
                     ti.tensor = None  # avoid keeping a reference to written tensors
 
-            # Write tensors in parallel
-            # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
-                # Opening the files only once per thread
-                open_files: dict[Path, BufferedWriter] = {}
-                try:
-                    while tensor := queue.get_nowait():
-                        tensor.write_chunk(open_files)
-                        del tensor
-                        queue.task_done()
-                except Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            # so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                # not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)
 
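Note: wait(futures, return_when=FIRST_EXCEPTION) returns as soon as any future raises, and Future.exception() retrieves the error without re-raising it inside the worker. A small self-contained sketch of the same error-surfacing pattern (names are illustrative):

    from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait

    def task(n: int) -> int:
        if n == 2:
            raise ValueError("boom")
        return n * n

    with ThreadPoolExecutor(max_workers=2) as ex:
        futures = [ex.submit(task, n) for n in range(4)]
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)

    exc = None
    if any(f for f in done if not f.cancelled() and (exc := f.exception()) is not None):
        print("a worker failed:", exc)  # the writer above re-raises instead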
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@ from typing import Literal
 
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+            break  # success, stop retrying
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data
 
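Note: the retry loop uses a linear backoff, not an exponential one: attempt i sleeps 2*i + 1 seconds, matching the "# 1 3 5 7 9 11 13" comment. A quick check of the schedule:

    MAX_RETRIES = 8
    delays = [2 * i + 1 for i in range(MAX_RETRIES - 1)]  # no sleep after the last attempt
    print(delays)       # [1, 3, 5, 7, 9, 11, 13]
    print(sum(delays))  # 49 seconds of waiting at most, before giving up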
@@ -169,7 +199,14 @@ class SafetensorRemote:
                     offset_start_relative, offset_end_relative = meta["data_offsets"]
                     size = offset_end_relative - offset_start_relative
                     offset_start = data_start_offset + offset_start_relative
-                    res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                    res[name] = RemoteTensor(
+                        name=name,
+                        dtype=dtype,
+                        shape=tuple(shape),
+                        offset_start=offset_start,
+                        size=size,
+                        url=url,
+                    )
                 except KeyError as e:
                     raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
@@ -217,8 +254,6 @@ class SafetensorRemote:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ class SafetensorRemote:
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        """
-        import requests
-        from urllib.parse import urlparse
-
        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")