multithreaded download

Xuan Son Nguyen
2025-04-08 18:05:48 +02:00
parent 4c0170e206
commit 42fc895ace


@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Any
 
 import os
 import json
+import requests
+import threading
+from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@ -110,6 +113,10 @@ class SafetensorRemote:
BASE_DOMAIN = "https://huggingface.co" BASE_DOMAIN = "https://huggingface.co"
ALIGNMENT = 8 # bytes ALIGNMENT = 8 # bytes
# start using multithread download for files larger than 100MB
MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024
MULTITHREAD_COUNT = 8 # number of threads
@classmethod @classmethod
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
""" """
@@ -211,29 +218,139 @@ class SafetensorRemote:
             except json.JSONDecodeError as e:
                 raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
+
     @classmethod
     def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         """
-        Get raw byte data from a remote file by range.
-        If size is not specified, it will read the entire file.
-        """
-        import requests
-        from urllib.parse import urlparse
+        Get raw byte data from a remote file by range using single or multi-threaded download.
+
+        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
+        If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads.
+        Otherwise, it uses a single request.
+        """
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
-        headers = {}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        if size > -1:
-            headers["Range"] = f"bytes={start}-{start + size}"
-        response = requests.get(url, allow_redirects=True, headers=headers)
-        response.raise_for_status()
+        common_headers = cls._get_request_headers()
 
-        # Get raw byte data
-        return response.content[:size]
+        # --- Multithreading Path ---
+        if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1:
+            # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024*1024):.2f} MB")
+            num_threads = cls.MULTITHREAD_COUNT
+            results: list[Any] = [None] * num_threads  # Store results or exceptions
+            threads: list[threading.Thread] = []
+
+            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
+                """Worker function for thread."""
+                thread_headers = headers.copy()
+                # Range header is inclusive end byte
+                range_end = chunk_start + chunk_size - 1
+                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
+                try:
+                    # Using stream=False should make requests wait for content download
+                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)  # Added timeout
+                    response.raise_for_status()  # Check for HTTP errors
+                    content = response.content
+                    if len(content) != chunk_size:
+                        # This is a critical check
+                        raise IOError(
+                            f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
+                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
+                        )
+                    result_list[index] = content
+                except Exception as e:
+                    # Store exception to be raised by the main thread
+                    # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}")  # Optional debug print
+                    result_list[index] = e
+
+            # Calculate chunk sizes and create/start threads
+            base_chunk_size = size // num_threads
+            remainder = size % num_threads
+            current_offset = start
+
+            for i in range(num_threads):
+                chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                if chunk_size == 0:  # Should not happen if size >= threshold but handle defensively
+                    results[i] = b""  # Store empty bytes for this "chunk"
+                    continue
+
+                thread = threading.Thread(
+                    target=download_chunk,
+                    args=(url, current_offset, chunk_size, i, results, common_headers),
+                    daemon=True  # Allow main thread to exit even if daemon threads are stuck (though join prevents this)
+                )
+                threads.append(thread)
+                thread.start()
+                current_offset += chunk_size  # Move offset for the next chunk
+
+            # Wait for all threads to complete
+            for i, thread in enumerate(threads):
+                thread.join()  # Wait indefinitely for each thread
+
+            # Check results for errors and concatenate chunks
+            final_data_parts = []
+            for i in range(num_threads):
+                result = results[i]
+                if isinstance(result, Exception):
+                    # Raise the first exception encountered
+                    raise result
+                elif result is None:
+                    # This indicates a thread finished without setting its result or exception (unexpected)
+                    # Check if it was supposed to download anything
+                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                    if expected_chunk_size > 0:
+                        raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.")
+                    else:
+                        final_data_parts.append(b"")  # Append empty bytes for zero-size chunk
+                else:
+                    final_data_parts.append(result)
+
+            # Combine the byte chunks
+            final_data = b"".join(final_data_parts)
+
+            # Final validation: Does the combined size match the requested size?
+            if len(final_data) != size:
+                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}")
+
+            return final_data
+
+        # --- Single-threaded Path ---
+        else:
+            # print(f"Using single thread for size {size}")  # Optional debug print
+            headers = common_headers.copy()
+            if size > -1:
+                # Range header uses inclusive end byte
+                range_end = start + size - 1
+                headers["Range"] = f"bytes={start}-{range_end}"
+            elif start > 0:
+                # Request from start offset to the end of the file
+                headers["Range"] = f"bytes={start}-"
+            # If start=0 and size=-1, no Range header is needed (get full file)
+
+            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)  # Added timeout
+            response.raise_for_status()
+            content = response.content
+
+            # Validate downloaded size if a specific size was requested
+            if size > -1 and len(content) != size:
+                # Check status code - 206 Partial Content is expected for successful range requests
+                status_code = response.status_code
+                content_range = response.headers.get('Content-Range')
+                raise IOError(
+                    f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
+                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
                )
+            return content
 
     @classmethod
     def check_file_exist(cls, url: str) -> bool:
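For reference, a hypothetical call into the new method (the URL and offsets below are placeholders, not from the commit). Since HTTP Range end offsets are inclusive, fetching size bytes starting at start sends Range: bytes=start-(start+size-1):

    # Hypothetical usage; URL and numbers are placeholders.
    url = "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"
    data = SafetensorRemote.get_data_by_range(url, start=8, size=256)  # small: single request
    assert len(data) == 256
    # A 150 MB request exceeds MULTITHREAD_THREDSHOLD (100 MB) and takes the threaded path.
    big = SafetensorRemote.get_data_by_range(url, start=0, size=150 * 1024 * 1024)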
@@ -241,17 +358,13 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
         try:
-            headers = {"Range": "bytes=0-0"}
-            if os.environ.get("HF_TOKEN"):
-                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0"  # Request a small range to check existence
             response = requests.head(url, allow_redirects=True, headers=headers)
             # Success (2xx) or redirect (3xx)
             return 200 <= response.status_code < 400
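And a hypothetical existence check through the refactored helper (placeholder URL); the bytes=0-0 range keeps the probe as small as possible:

    # Hypothetical usage; the URL is a placeholder.
    url = "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"
    if SafetensorRemote.check_file_exist(url):
        print("remote tensor file is reachable")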