gguf-py : add more clarifying comments for multi-thread writes

Author: Francis Couture-Harpin
Date:   2025-04-08 21:55:15 -04:00
Parent: 06e1d3119a
Commit: d8bab9efa1

@@ -63,14 +63,17 @@ class WriterState(Enum):
 @dataclass
-class TensorWriteInfo:
+class ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
-    bar: Any | None
+    bar: Any | None  # optional tqdm progress bar

     def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+        # This is called from a thread pool,
+        # and each thread should have its own file handle per output file
+        # so that they can have different seek locations.
         if self.filename not in open_files:
             open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
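The new comment spells out the core constraint: a file handle carries a single seek position, so sharing one handle between workers would let concurrent seeks and writes clobber each other. A minimal standalone sketch of the same open-per-thread pattern (illustrative only, not the gguf-py code; the file is assumed to already exist and be pre-sized):

    from io import BufferedWriter
    from pathlib import Path

    def write_at(open_files: dict[Path, BufferedWriter], path: Path, offset: int, payload: bytes) -> None:
        # Each worker thread passes its own `open_files` dict, so no handle
        # (and therefore no seek position) is ever shared between threads.
        if path not in open_files:
            open_files[path] = open(path, "r+b")  # "r+b": write in place without truncating
        f = open_files[path]
        f.seek(offset)
        f.write(payload)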
@@ -460,8 +463,9 @@ class GGUFWriter:
         if self.temp_file is None:
             bar = None
             # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[TensorWriteInfo] = Queue()
+            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
+            # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]

             if progress:
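For context on the new comment: by this point each output file should already have had its header, KV metadata, and tensor-info section written, so `fout.tell()` gives the position where the tensor data region begins in that (possibly split) file. A toy illustration of the same bookkeeping, with made-up header sizes:

    import tempfile

    # Two pretend split files whose headers/metadata were already written.
    fouts = []
    for header_size in (96, 128):  # hypothetical sizes
        f = tempfile.TemporaryFile("w+b")
        f.write(b"\x00" * header_size)
        fouts.append(f)

    offsets = [f.tell() for f in fouts]
    print(offsets)  # [96, 128]: where tensor data would begin in each file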
@@ -472,6 +476,7 @@ class GGUFWriter:
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

+            # Fill the tensor queue with all the pending tensor writes
             for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
                 offset = offsets[i]
@@ -484,7 +489,7 @@ class GGUFWriter:
                     offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
                     padding = offset - (start_offset + nbytes)
                     tensor_queue.put(
-                        TensorWriteInfo(
+                        ThreadedTensorWriteInfo(
                             filename=filename,
                             offset=start_offset,
                             post_pad=padding,
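The `offset`/`padding` arithmetic above is plain round-up-to-alignment bookkeeping: each tensor's end offset is rounded up to the data alignment, the gap is recorded as `post_pad` (zero bytes written after the payload), and the next tensor starts on the aligned boundary. A rough worked example with made-up numbers (GGUF's default alignment is 32 bytes):

    def ggml_pad(x: int, n: int) -> int:
        # Round x up to the next multiple of n.
        return ((x + n - 1) // n) * n

    start_offset = 1024
    nbytes = 100        # hypothetical tensor payload size
    alignment = 32

    end = start_offset + nbytes             # 1124
    next_offset = ggml_pad(end, alignment)  # 1152
    post_pad = next_offset - end            # 28 zero bytes of padding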
@@ -496,12 +501,13 @@ class GGUFWriter:
             # Write tensors in parallel
             # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
+            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
+                # Opening the files only once per thread
                 open_files: dict[Path, BufferedWriter] = {}
                 try:
-                    while t := queue.get_nowait():
-                        t.write_chunk(open_files)
-                        del t
+                    while tensor := queue.get_nowait():
+                        tensor.write_chunk(open_files)
+                        del tensor

                         queue.task_done()
                 except Empty:
                     pass
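Not shown in this hunk is how the worker is driven, but the `get_nowait()`/`task_done()`/`Empty` trio is the standard shape for a pool of threads draining a pre-filled queue while the main thread waits on `queue.join()`. A self-contained sketch of that pattern (thread count and the work items are illustrative, not the actual gguf-py code):

    from queue import Empty, Queue
    from threading import Thread

    def drain(queue: Queue) -> None:
        # Same shape as write_tensors_from_thread: pull items until the queue
        # is empty, marking each one done so that queue.join() can return.
        try:
            while item := queue.get_nowait():
                item()                # do the work (e.g. write one tensor chunk)
                queue.task_done()
        except Empty:
            pass

    work: Queue = Queue()
    for i in range(8):
        work.put(lambda i=i: print(f"chunk {i} written"))

    threads = [Thread(target=drain, args=(work,)) for _ in range(4)]  # illustrative thread count
    for t in threads:
        t.start()
    work.join()   # returns once task_done() has been called for every queued item
    for t in threads:
        t.join()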