diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index 89a9d494d..c42b9cb15 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Any
 
 import os
 import json
+import requests
+import threading
+from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -110,6 +113,10 @@ class SafetensorRemote:
     BASE_DOMAIN = "https://huggingface.co"
     ALIGNMENT = 8 # bytes
 
+    # use multithreaded download for files larger than 100 MB
+    MULTITHREAD_THRESHOLD = 100 * 1024 * 1024
+    MULTITHREAD_COUNT = 8 # number of threads
+
     @classmethod
     def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
         """
@@ -211,29 +218,139 @@ class SafetensorRemote:
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
+
     @classmethod
     def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         """
-        Get raw byte data from a remote file by range.
-        If size is not specified, it will read the entire file.
-        """
-        import requests
-        from urllib.parse import urlparse
+        Get raw byte data from a remote file by range, using a single- or multi-threaded download.
+        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
+        If size is >= MULTITHREAD_THRESHOLD, it uses multiple threads.
+        Otherwise, it uses a single request.
+        """
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
-        headers = {}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        if size > -1:
-            headers["Range"] = f"bytes={start}-{start + size}"
-        response = requests.get(url, allow_redirects=True, headers=headers)
-        response.raise_for_status()
+        common_headers = cls._get_request_headers()
 
-        # Get raw byte data
-        return response.content[:size]
+        # --- Multithreading Path ---
+        if size >= cls.MULTITHREAD_THRESHOLD and cls.MULTITHREAD_COUNT > 1:
+            # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024 * 1024):.2f} MB")
+            num_threads = cls.MULTITHREAD_COUNT
+            results: list[Any] = [None] * num_threads # each slot holds either bytes or an exception
+            threads: list[threading.Thread] = []
+
+            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
+                """Worker function run by each download thread."""
+                thread_headers = headers.copy()
+                # the Range header end byte is inclusive
+                range_end = chunk_start + chunk_size - 1
+                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
+                try:
+                    # stream=False makes requests buffer the full chunk body before returning
+                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)
+                    response.raise_for_status() # check for HTTP errors
+
+                    content = response.content
+                    if len(content) != chunk_size:
+                        # critical check: a short read here would silently corrupt the assembled data
+                        raise IOError(
+                            f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
+                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. "
+                            f"URL: {chunk_url}"
+                        )
+                    result_list[index] = content
+                except Exception as e:
+                    # Store the exception so the main thread can re-raise it
+                    # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}")
+                    result_list[index] = e
+
+            # Calculate chunk sizes and create/start threads
+            base_chunk_size = size // num_threads
+            remainder = size % num_threads
+            current_offset = start
+
+            for i in range(num_threads):
+                chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                if chunk_size == 0: # cannot happen while size >= threshold, but handle defensively
+                    results[i] = b"" # store empty bytes for this "chunk"
+                    continue
+
+                thread = threading.Thread(
+                    target=download_chunk,
+                    args=(url, current_offset, chunk_size, i, results, common_headers),
+                    daemon=True # let the main thread exit even if a worker is stuck (the join below normally prevents this)
+                )
+                threads.append(thread)
+                thread.start()
+                current_offset += chunk_size # move the offset to the next chunk
+
+            # Wait for all threads to complete
+            for thread in threads:
+                thread.join() # wait indefinitely for each thread
+
+            # Check results for errors and concatenate the chunks
+            final_data_parts = []
+            for i in range(num_threads):
+                result = results[i]
+                if isinstance(result, Exception):
+                    # re-raise the first exception encountered
+                    raise result
+                elif result is None:
+                    # a thread finished without setting its result or an exception (unexpected);
+                    # check whether it was supposed to download anything
+                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                    if expected_chunk_size > 0:
+                        raise RuntimeError(f"Thread {i} finished without providing data or an exception for a non-zero chunk.")
+                    else:
+                        final_data_parts.append(b"") # append empty bytes for a zero-size chunk
+                else:
+                    final_data_parts.append(result)
+
+            # Combine the byte chunks
+            final_data = b"".join(final_data_parts)
+
+            # Final validation: does the combined size match the requested size?
+            if len(final_data) != size:
+                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start + size - 1}")
+
+            return final_data
+
+        # --- Single-threaded Path ---
+        else:
+            # print(f"Using a single thread for size {size}")
+            headers = common_headers.copy()
+            if size > -1:
+                # the Range header end byte is inclusive
+                range_end = start + size - 1
+                headers["Range"] = f"bytes={start}-{range_end}"
+            elif start > 0:
+                # request from the start offset to the end of the file
+                headers["Range"] = f"bytes={start}-"
+            # if start == 0 and size == -1, no Range header is needed (fetch the whole file)
+
+            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)
+            response.raise_for_status()
+            content = response.content
+
+            # Validate the downloaded size if a specific size was requested
+            if size > -1 and len(content) != size:
+                # 206 Partial Content is the expected status for a successful range request
+                status_code = response.status_code
+                content_range = response.headers.get('Content-Range')
+                raise IOError(
+                    f"Single-threaded download size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
+                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
+                )
+
+            return content
 
     @classmethod
     def check_file_exist(cls, url: str) -> bool:
@@ -241,17 +358,13 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
         try:
-            headers = {"Range": "bytes=0-0"}
-            if os.environ.get("HF_TOKEN"):
-                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0" # request a single byte to check existence
             response = requests.head(url, allow_redirects=True, headers=headers)
             # Success (2xx) or redirect (3xx)
             return 200 <= response.status_code < 400
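
Review note: the reassembly step is only correct if the chunk partition tiles the requested range exactly, with no gap or overlap. A minimal standalone sketch of that arithmetic; `partition` is a hypothetical helper that mirrors the loop in `get_data_by_range`:

```python
def partition(start: int, size: int, num_threads: int) -> list[tuple[int, int]]:
    # Mirrors get_data_by_range: the first (size % num_threads) chunks each
    # take one extra byte, so the chunks tile [start, start + size) exactly.
    base, remainder = divmod(size, num_threads)
    chunks = []
    offset = start
    for i in range(num_threads):
        chunk_size = base + (1 if i < remainder else 0)
        chunks.append((offset, chunk_size))
        offset += chunk_size
    return chunks

chunks = partition(start=512, size=100 * 1024 * 1024 + 5, num_threads=8)
assert sum(s for _, s in chunks) == 100 * 1024 * 1024 + 5       # lossless coverage
assert chunks[0][0] == 512                                      # starts at the requested offset
assert chunks[-1][0] + chunks[-1][1] == 512 + 100 * 1024 * 1024 + 5  # contiguous to the end
```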
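Caller-side behavior is unchanged; only reads at or above `MULTITHREAD_THRESHOLD` take the threaded path. A usage sketch, assuming a placeholder model URL (any Hugging Face `resolve` URL works the same way; `HF_TOKEN` is picked up from the environment by `_get_request_headers`):

```python
from gguf.utility import SafetensorRemote

# Placeholder URL for illustration; substitute a real repo path.
url = SafetensorRemote.BASE_DOMAIN + "/some-org/some-model/resolve/main/model.safetensors"

# 1 KiB read: a single ranged GET (Range: bytes=0-1023).
header_bytes = SafetensorRemote.get_data_by_range(url, start=0, size=1024)

# 200 MB read: split across MULTITHREAD_COUNT threads and reassembled in order.
blob = SafetensorRemote.get_data_by_range(url, start=0, size=200 * 1024 * 1024)
assert len(blob) == 200 * 1024 * 1024
```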