mirror of https://github.com/ggml-org/llama.cpp.git

multithreaded download
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Any
 
 import os
 import json
+import requests
+import threading
+from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -110,6 +113,10 @@ class SafetensorRemote:
     BASE_DOMAIN = "https://huggingface.co"
     ALIGNMENT = 8 # bytes
 
+    # start using multithread download for files larger than 100MB
+    MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024
+    MULTITHREAD_COUNT = 8 # number of threads
+
     @classmethod
     def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
         """
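For intuition on the two constants above: at the 100 MiB threshold, each of the 8 workers is handed roughly 12.5 MiB. A quick illustrative calculation (names copied from the diff; the snippet itself is not part of the commit):

    MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024  # 100 MiB
    MULTITHREAD_COUNT = 8

    size = MULTITHREAD_THREDSHOLD        # smallest size that takes the threaded path
    per_thread = size // MULTITHREAD_COUNT
    print(per_thread)                    # 13107200 bytes = 12.5 MiB per worker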
@@ -211,29 +218,139 @@ class SafetensorRemote:
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
+
     @classmethod
     def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         """
-        Get raw byte data from a remote file by range.
-        If size is not specified, it will read the entire file.
-        """
-        import requests
-        from urllib.parse import urlparse
+        Get raw byte data from a remote file by range using single or multi-threaded download.
 
+        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
+        If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads.
+        Otherwise, it uses a single request.
+        """
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
-        headers = {}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        if size > -1:
-            headers["Range"] = f"bytes={start}-{start + size}"
-        response = requests.get(url, allow_redirects=True, headers=headers)
-        response.raise_for_status()
+        common_headers = cls._get_request_headers()
 
-        # Get raw byte data
-        return response.content[:size]
+        # --- Multithreading Path ---
+        if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1:
+            # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024*1024):.2f} MB")
+            num_threads = cls.MULTITHREAD_COUNT
+            results: list[Any] = [None] * num_threads  # Store results or exceptions
+            threads: list[threading.Thread] = []
+
+            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
+                """Worker function for thread."""
+                thread_headers = headers.copy()
+                # Range header is inclusive end byte
+                range_end = chunk_start + chunk_size - 1
+                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
+                try:
+                    # Using stream=False should make requests wait for content download
+                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)  # Added timeout
+                    response.raise_for_status()  # Check for HTTP errors
+
+                    content = response.content
+                    if len(content) != chunk_size:
+                        # This is a critical check
+                        raise IOError(
+                            f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
+                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
+                        )
+                    result_list[index] = content
+                except Exception as e:
+                    # Store exception to be raised by the main thread
+                    # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}")  # Optional debug print
+                    result_list[index] = e
+
+            # Calculate chunk sizes and create/start threads
+            base_chunk_size = size // num_threads
+            remainder = size % num_threads
+            current_offset = start
+
+            for i in range(num_threads):
+                chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                if chunk_size == 0:  # Should not happen if size >= threshold but handle defensively
+                    results[i] = b""  # Store empty bytes for this "chunk"
+                    continue
+
+                thread = threading.Thread(
+                    target=download_chunk,
+                    args=(url, current_offset, chunk_size, i, results, common_headers),
+                    daemon=True  # Allow main thread to exit even if daemon threads are stuck (though join prevents this)
+                )
+                threads.append(thread)
+                thread.start()
+                current_offset += chunk_size  # Move offset for the next chunk
+
+            # Wait for all threads to complete
+            for i, thread in enumerate(threads):
+                thread.join()  # Wait indefinitely for each thread
+
+            # Check results for errors and concatenate chunks
+            final_data_parts = []
+            for i in range(num_threads):
+                result = results[i]
+                if isinstance(result, Exception):
+                    # Raise the first exception encountered
+                    raise result
+                elif result is None:
+                    # This indicates a thread finished without setting its result or exception (unexpected)
+                    # Check if it was supposed to download anything
+                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                    if expected_chunk_size > 0:
+                        raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.")
+                    else:
+                        final_data_parts.append(b"")  # Append empty bytes for zero-size chunk
+                else:
+                    final_data_parts.append(result)
+
+            # Combine the byte chunks
+            final_data = b"".join(final_data_parts)
+
+            # Final validation: Does the combined size match the requested size?
+            if len(final_data) != size:
+                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}")
+
+            return final_data
+
+        # --- Single-threaded Path ---
+        else:
+            # print(f"Using single thread for size {size}")  # Optional debug print
+            headers = common_headers.copy()
+            if size > -1:
+                # Range header uses inclusive end byte
+                range_end = start + size - 1
+                headers["Range"] = f"bytes={start}-{range_end}"
+            elif start > 0:
+                # Request from start offset to the end of the file
+                headers["Range"] = f"bytes={start}-"
+            # If start=0 and size=-1, no Range header is needed (get full file)
+
+            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)  # Added timeout
+            response.raise_for_status()
+            content = response.content
+
+            # Validate downloaded size if a specific size was requested
+            if size > -1 and len(content) != size:
+                # Check status code - 206 Partial Content is expected for successful range requests
+                status_code = response.status_code
+                content_range = response.headers.get('Content-Range')
+                raise IOError(
+                    f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
+                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
+                )
+
+            return content
+
     @classmethod
     def check_file_exist(cls, url: str) -> bool:
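The chunk split in the threaded path gives each worker size // num_threads bytes and spreads the remainder one extra byte over the first workers, so the chunks tile [start, start + size) with no gap or overlap; the "- 1" in the Range header reflects that HTTP byte ranges are inclusive on both ends. A self-contained sketch of that arithmetic (split_range is a hypothetical name, not part of this commit):

    def split_range(start: int, size: int, num_threads: int) -> list[tuple[int, int]]:
        # Same arithmetic as the loop in get_data_by_range.
        base, remainder = divmod(size, num_threads)
        chunks, offset = [], start
        for i in range(num_threads):
            n = base + (1 if i < remainder else 0)
            chunks.append((offset, n))  # (chunk_start, chunk_size)
            offset += n
        return chunks

    parts = split_range(0, 100 * 1024 * 1024 + 5, 8)
    assert sum(n for _, n in parts) == 100 * 1024 * 1024 + 5  # every byte covered exactly once
    # each part maps to the header: Range: bytes={chunk_start}-{chunk_start + chunk_size - 1}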
@@ -241,17 +358,13 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
         try:
-            headers = {"Range": "bytes=0-0"}
-            if os.environ.get("HF_TOKEN"):
-                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0"  # Request a small range to check existence
             response = requests.head(url, allow_redirects=True, headers=headers)
             # Success (2xx) or redirect (3xx)
             return 200 <= response.status_code < 400
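Taken together, a caller might exercise the new paths roughly like this (illustrative only: the URL is a placeholder, and size=8 corresponds to the 8-byte little-endian length prefix at the start of a safetensors file):

    url = "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"

    if SafetensorRemote.check_file_exist(url):
        # Small read -> single request with an inclusive Range header.
        prefix = SafetensorRemote.get_data_by_range(url, start=0, size=8)
        # Reads of MULTITHREAD_THREDSHOLD (100 MiB) or more would instead be
        # split across MULTITHREAD_COUNT threads and reassembled in order.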