diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 889d9fdfe..db8ad4f05 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -63,14 +63,17 @@ class WriterState(Enum):
 
 
 @dataclass
-class TensorWriteInfo:
+class ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
-    bar: Any | None
+    bar: Any | None  # optional tqdm progress bar
 
     def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+        # This is called from a thread pool,
+        # and each thread should have its own file handle per output file
+        # so that they can have different seek locations.
         if self.filename not in open_files:
             open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
@@ -460,8 +463,9 @@ class GGUFWriter:
         if self.temp_file is None:
             bar = None
             # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[TensorWriteInfo] = Queue()
+            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
 
+            # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
 
             if progress:
@@ -472,6 +476,7 @@ class GGUFWriter:
 
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
+            # Fill the tensor queue with all the pending tensor writes
             for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
                 offset = offsets[i]
 
@@ -484,7 +489,7 @@ class GGUFWriter:
                     offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
                     padding = offset - (start_offset + nbytes)
                     tensor_queue.put(
-                        TensorWriteInfo(
+                        ThreadedTensorWriteInfo(
                             filename=filename,
                             offset=start_offset,
                             post_pad=padding,
@@ -496,12 +501,13 @@ class GGUFWriter:
 
             # Write tensors in parallel
             # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
+            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
+                # Opening the files only once per thread
                 open_files: dict[Path, BufferedWriter] = {}
                 try:
-                    while t := queue.get_nowait():
-                        t.write_chunk(open_files)
-                        del t
+                    while tensor := queue.get_nowait():
+                        tensor.write_chunk(open_files)
+                        del tensor
                         queue.task_done()
                 except Empty:
                     pass
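
For context, here is a minimal standalone sketch of the pattern this patch relies on: pre-size an output file, queue up `(offset, payload)` chunks, and let a few worker threads drain the queue, each with its own per-thread file handle so their seek positions never interfere. The names `ChunkWriteInfo`, `write_chunks_from_thread`, the two-thread pool, and the `truncate`-based pre-allocation are all illustrative assumptions, not gguf-py API; they only mirror the shape of the code in the diff above.

```python
from __future__ import annotations

from dataclasses import dataclass
from io import BufferedWriter
from pathlib import Path
from queue import Empty, Queue
from threading import Thread


@dataclass
class ChunkWriteInfo:
    # Hypothetical stand-in for ThreadedTensorWriteInfo
    filename: Path
    offset: int
    data: bytes

    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
        # Each worker passes its own dict, so a handle is never shared across threads
        if self.filename not in open_files:
            open_files[self.filename] = open(self.filename, "r+b")
        f = open_files[self.filename]
        f.seek(self.offset)
        f.write(self.data)


def write_chunks_from_thread(queue: Queue[ChunkWriteInfo]):
    # Open each output file at most once per thread, as in the patch
    open_files: dict[Path, BufferedWriter] = {}
    try:
        while chunk := queue.get_nowait():
            chunk.write_chunk(open_files)
            queue.task_done()
    except Empty:
        pass
    finally:
        for f in open_files.values():
            f.close()


if __name__ == "__main__":
    path = Path("out.bin")
    chunks = [b"a" * 1024, b"b" * 1024, b"c" * 1024]

    # Pre-allocate the file so workers can write at arbitrary offsets
    with open(path, "wb") as f:
        f.truncate(sum(len(c) for c in chunks))

    # Fill the queue with all pending writes before starting the workers
    work: Queue[ChunkWriteInfo] = Queue()
    offset = 0
    for data in chunks:
        work.put(ChunkWriteInfo(filename=path, offset=offset, data=data))
        offset += len(data)

    threads = [Thread(target=write_chunks_from_thread, args=(work,)) for _ in range(2)]
    for t in threads:
        t.start()
    work.join()
    for t in threads:
        t.join()
```

The per-thread `open_files` dict is the key design choice: sharing one `BufferedWriter` across threads would require locking around every `seek`/`write` pair, whereas independent handles let each worker write its chunks at fixed offsets without coordination.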