diff --git a/common/arg.cpp b/common/arg.cpp
index 0657553e4..de173159f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -162,6 +162,10 @@ struct common_hf_file_res {
 
 #ifdef LLAMA_USE_CURL
 
+bool common_has_curl() {
+    return true;
+}
+
 #ifdef __linux__
 #include <linux/limits.h>
 #elif defined(_WIN32)
@@ -527,6 +531,50 @@ static bool common_download_model(
     return true;
 }
 
+std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    std::vector<char> res_buffer;
+
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+        auto data_vec = static_cast<std::vector<char> *>(data);
+        data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
+        return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
+#if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+    if (params.timeout > 0) {
+        curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
+    }
+    if (params.max_size > 0) {
+        curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
+    }
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
+    for (const auto & header : params.headers) {
+        http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
+    }
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    CURLcode res = curl_easy_perform(curl.get());
+
+    if (res != CURLE_OK) {
+        std::string error_msg = curl_easy_strerror(res);
+        throw std::runtime_error("error: cannot make GET request: " + error_msg);
+    }
+
+    long res_code;
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+
+    return { res_code, std::move(res_buffer) };
+}
+
 /**
  * Allow getting the HF file from the HF repo with tag (like ollama), for example:
  * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
@@ -546,45 +594,26 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
         throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
     }
 
-    // fetch model info from Hugging Face Hub API
-    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
-    curl_slist_ptr http_headers;
-    std::string res_str;
+    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
 
-    std::string model_endpoint = get_model_endpoint();
-
-    std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
-    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
-    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
-    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
-    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
-        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
-        return size * nmemb;
-    };
-    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
-    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
-#if defined(_WIN32)
-    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
-#endif
+    // headers
+    std::vector<std::string> headers;
+    headers.push_back("Accept: application/json");
application/json"); if (!bearer_token.empty()) { - std::string auth_header = "Authorization: Bearer " + bearer_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + headers.push_back("Authorization: Bearer " + bearer_token); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + // User-Agent header is already set in common_remote_get_content, no need to set it here - CURLcode res = curl_easy_perform(curl.get()); + // make the request + common_remote_params params; + params.headers = headers; + auto res = common_remote_get_content(url, params); + long res_code = res.first; + std::string res_str(res.second.data(), res.second.size()); + std::string ggufFile; + std::string mmprojFile; - if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); - } - - long res_code; - std::string ggufFile = ""; - std::string mmprojFile = ""; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); if (res_code == 200) { // extract ggufFile.rfilename in json, using regex { @@ -618,6 +647,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ #else +bool common_has_curl() { + return false; +} + static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; @@ -640,6 +673,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s return {}; } +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { + throw std::runtime_error("error: built without CURL, cannot download model from the internet"); +} + #endif // LLAMA_USE_CURL // diff --git a/common/arg.h b/common/arg.h index 49ab8667b..70bea100f 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,3 +78,12 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); +bool common_has_curl(); + +struct common_remote_params { + std::vector headers; + long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout + long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB +}; +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 537fc63a4..21dbd5404 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -126,6 +126,53 @@ int main(void) { assert(params.cpuparams.n_threads == 1010); #endif // _WIN32 + if (common_has_curl()) { + printf("test-arg-parser: test curl-related functions\n\n"); + const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md"; + const char * BAD_URL = "https://www.google.com/404"; + const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin"; + + { + printf("test-arg-parser: test good URL\n\n"); + auto res = common_remote_get_content(GOOD_URL, {}); + assert(res.first == 200); + 
+            assert(res.second.size() > 0);
+            std::string str(res.second.data(), res.second.size());
+            assert(str.find("llama.cpp") != std::string::npos);
+        }
+
+        {
+            printf("test-arg-parser: test bad URL\n\n");
+            auto res = common_remote_get_content(BAD_URL, {});
+            assert(res.first == 404);
+        }
+
+        {
+            printf("test-arg-parser: test max size error\n");
+            common_remote_params params;
+            params.max_size = 1;
+            try {
+                common_remote_get_content(GOOD_URL, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+
+        {
+            printf("test-arg-parser: test timeout error\n");
+            common_remote_params params;
+            params.timeout = 1;
+            try {
+                common_remote_get_content(BIG_FILE, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+    } else {
+        printf("test-arg-parser: no curl, skipping curl-related functions\n");
+    }
 
     printf("test-arg-parser: all tests OK\n\n");
 }
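A minimal usage sketch of the new API, not part of the patch itself: it shows how the common_remote_get_content() helper declared in common/arg.h above might be called from tooling code. The function name fetch_small_json and the concrete timeout/size values are hypothetical; common_remote_params, its headers/timeout/max_size fields, the {http_code, raw body} return pair, and the throw-on-failure behavior are taken from the diff.

// Illustrative sketch only; assumes the common/arg.h additions from this diff.
#include "arg.h"

#include <cstdio>
#include <exception>
#include <string>

// Hypothetical helper: fetch a small JSON document with a timeout and a
// response-size cap, returning an empty string on any failure.
static std::string fetch_small_json(const std::string & url) {
    common_remote_params params;
    params.headers.push_back("Accept: application/json");
    params.timeout  = 10;          // seconds (CURLOPT_TIMEOUT)
    params.max_size = 1024 * 1024; // bytes (CURLOPT_MAXFILESIZE)
    try {
        auto res = common_remote_get_content(url, params); // {http_code, raw body}
        if (res.first != 200) {
            return "";
        }
        return std::string(res.second.data(), res.second.size());
    } catch (const std::exception & e) {
        // thrown on transport errors, timeouts, or when the response exceeds max_size
        fprintf(stderr, "remote fetch failed: %s\n", e.what());
        return "";
    }
}

The error-path tests in tests/test-arg-parser.cpp above rely on the same behavior: a max_size of 1 byte and a 1-second timeout are expected to surface as thrown exceptions rather than as HTTP status codes.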