common : add common_remote_get_content (#13123)

* common : add common_remote_get_content

* support max size and timeout

* add tests
This commit is contained in:
Xuan-Son Nguyen
2025-04-26 22:58:12 +02:00
committed by GitHub
parent 4753791e70
commit 2d451c8059
3 changed files with 126 additions and 33 deletions

View File

@ -162,6 +162,10 @@ struct common_hf_file_res {
#ifdef LLAMA_USE_CURL
// Variant compiled when LLAMA_USE_CURL is defined: reports that CURL support is available.
bool common_has_curl() {
    const bool curl_available = true;
    return curl_available;
}
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
@ -527,6 +531,50 @@ static bool common_download_model(
return true;
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::vector<char> res_buffer;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
auto data_vec = static_cast<std::vector<char> *>(data);
data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (params.timeout > 0) {
curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
}
if (params.max_size > 0) {
curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
}
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
for (const auto & header : params.headers) {
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
std::string error_msg = curl_easy_strerror(res);
throw std::runtime_error("error: cannot make GET request: " + error_msg);
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
return { res_code, std::move(res_buffer) };
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
@ -546,45 +594,26 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
std::string model_endpoint = get_model_endpoint();
std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// headers
std::vector<std::string> headers;
headers.push_back("Accept: application/json");
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
headers.push_back("Authorization: Bearer " + bearer_token);
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
// User-Agent header is already set in common_remote_get_content, no need to set it here
CURLcode res = curl_easy_perform(curl.get());
// make the request
common_remote_params params;
params.headers = headers;
auto res = common_remote_get_content(url, params);
long res_code = res.first;
std::string res_str(res.second.data(), res.second.size());
std::string ggufFile;
std::string mmprojFile;
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
std::string ggufFile = "";
std::string mmprojFile = "";
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
// extract ggufFile.rfilename in json, using regex
{
@ -618,6 +647,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
#else
// Variant compiled when LLAMA_USE_CURL is not defined: reports that CURL support is unavailable.
bool common_has_curl() {
    const bool curl_available = false;
    return curl_available;
}
static bool common_download_file_single(const std::string &, const std::string &, const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from internet\n");
return false;
@ -640,6 +673,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s
return {};
}
// Stub used when built without CURL: no request can be made, so every call throws std::runtime_error.
// Parameters are intentionally left unnamed (like the other non-CURL stubs) to avoid unused-parameter warnings.
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & /*url*/, const common_remote_params & /*params*/) {
    throw std::runtime_error("error: built without CURL, cannot download model from the internet");
}
#endif // LLAMA_USE_CURL
//

View File

@ -78,3 +78,12 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
// function to be used by test-arg-parser
// print_usage is optional and defaults to nullptr
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
// returns true when built with CURL support (LLAMA_USE_CURL defined), false otherwise
bool common_has_curl();
// Options accepted by common_remote_get_content
struct common_remote_params {
    // extra HTTP request headers, each one a complete line such as "Accept: application/json"
    std::vector<std::string> headers;

    // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
    long timeout{0};

    // max size of the response ; unlimited if 0 ; max is 2GB
    long max_size{0};
};
// get remote file content, returns <http_code, raw_response_body>
// throws std::runtime_error if the request fails, or if built without CURL
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);

View File

@ -126,6 +126,53 @@ int main(void) {
assert(params.cpuparams.n_threads == 1010);
#endif // _WIN32
if (common_has_curl()) {
printf("test-arg-parser: test curl-related functions\n\n");
const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md";
const char * BAD_URL = "https://www.google.com/404";
const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
{
printf("test-arg-parser: test good URL\n\n");
auto res = common_remote_get_content(GOOD_URL, {});
assert(res.first == 200);
assert(res.second.size() > 0);
std::string str(res.second.data(), res.second.size());
assert(str.find("llama.cpp") != std::string::npos);
}
{
printf("test-arg-parser: test bad URL\n\n");
auto res = common_remote_get_content(BAD_URL, {});
assert(res.first == 404);
}
{
printf("test-arg-parser: test max size error\n");
common_remote_params params;
params.max_size = 1;
try {
common_remote_get_content(GOOD_URL, params);
assert(false && "it should throw an error");
} catch (std::exception & e) {
printf(" expected error: %s\n\n", e.what());
}
}
{
printf("test-arg-parser: test timeout error\n");
common_remote_params params;
params.timeout = 1;
try {
common_remote_get_content(BIG_FILE, params);
assert(false && "it should throw an error");
} catch (std::exception & e) {
printf(" expected error: %s\n\n", e.what());
}
}
} else {
printf("test-arg-parser: no curl, skipping curl-related functions\n");
}
printf("test-arg-parser: all tests OK\n\n");
}