Files
llama.cpp/tests/test-tokenizers-remote.cpp
Sigbjørn Skjæret f9a27178e5 download in batches
2025-05-31 22:35:26 +02:00

165 lines
6.6 KiB
C++

#include "arg.h"
#include "common.h"
#include <string>
#include <fstream>
#include <vector>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
#undef NDEBUG
#include <cassert>
// Base URL of the Hugging Face endpoint and the vocab repository exercised by
// this test. Declared static const: they are only ever read, and const gives
// them internal linkage so they cannot clash with other translation units.
static const std::string endpoint = "https://huggingface.co/";
static const std::string repo     = "ggml-org/vocabs";
// Persist `content` verbatim to `fname`.
// Best-effort helper: if the file cannot be opened, it returns silently
// (callers treat a missing cache file as a soft failure).
static void write_file(const std::string & fname, const std::string & content) {
    std::ofstream file(fname);
    if (!file) {
        return;
    }
    file << content;
    // stream is flushed and closed by the ofstream destructor (RAII)
}
// Query the Hugging Face API for the file tree of a repository.
//
// hf_repo_with_branch - "owner/repo", optionally suffixed with ":branch" (defaults to "main")
// recursive           - when true, request the tree recursively (?recursive=true)
// repo_path           - subdirectory to list; empty string lists the repo root
// bearer_token        - optional HF token, sent as "Authorization: Bearer ..." when non-empty
//
// A 200 response is cached on disk; on any request failure the cached response
// (if present) is parsed instead, so a previously successful run works offline.
// Returns a default-constructed (null) json when both request and cache fail.
static json get_hf_repo_dir(const std::string & hf_repo_with_branch, bool recursive, const std::string & repo_path, const std::string & bearer_token) {
// split off the optional ":branch" suffix; everything before it is the repo id
auto parts = string_split<std::string>(hf_repo_with_branch, ':');
std::string branch = parts.size() > 1 ? parts.back() : "main";
std::string hf_repo = parts[0];
std::string url = endpoint + "api/models/" + hf_repo + "/tree/" + branch;
std::string path = repo_path;
if (!path.empty()) {
// FIXME: path should be properly url-encoded!
// (only '/' is escaped here, so other reserved characters would break the URL)
string_replace_all(path, "/", "%2F");
url += "/" + path;
}
if (recursive) {
url += "?recursive=true";
}
// headers
std::vector<std::string> headers;
headers.push_back("Accept: application/json");
if (!bearer_token.empty()) {
headers.push_back("Authorization: Bearer " + bearer_token);
}
// we use "=" to avoid clashing with other component, while still being allowed on windows
std::string cached_response_fname = "test_vocab=" + hf_repo + "/" + repo_path + "=" + branch + ".json";
string_replace_all(cached_response_fname, "/", "_");
std::string cached_response_path = fs_get_cache_file(cached_response_fname);
// make the request
common_remote_params params;
params.headers = headers;
json res_data;
try {
// TODO: For pagination links we need response headers, which is not provided by common_remote_get_content()
auto res = common_remote_get_content(url, params);
long res_code = res.first;
std::string res_str = std::string(res.second.data(), res.second.size());
if (res_code == 200) {
// success: cache the raw response so later (offline) runs can reuse it
write_file(cached_response_path, res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
} catch (const std::exception & e) {
// a network/API failure is not fatal here; fall through to the cache below
fprintf(stderr, "error: failed to get repo tree: %s\n", e.what());
fprintf(stderr, "try reading from cache\n");
}
// try to read from cache
// (always re-parsed from disk — either the file just written or one left by a previous run)
try {
std::ifstream f(cached_response_path);
res_data = json::parse(f);
} catch (const std::exception & e) {
fprintf(stderr, "error: failed to get repo tree (check your internet connection)\n");
}
return res_data;
}
// Integration test: list the ggml-org/vocabs HF repo, download every vocab
// file (.gguf plus its .inp/.out companions) into the local cache in batches,
// then run ./test-tokenizer-0 against each complete vocab. Prints a message
// and skips when curl, the repo listing, or the downloads are unavailable.
int main(void) {
if (common_has_curl()) {
// recursive listing of the whole repo, with no token (public repo)
json tree = get_hf_repo_dir(repo, true, {}, {});
if (!tree.empty()) {
// collect (download-url, local-cache-path) pairs for all vocab-related files
std::vector<std::pair<std::string, std::string>> files;
for (const auto & item : tree) {
if (item.at("type") == "file") {
std::string path = item.at("path");
if (string_ends_with(path, ".gguf") || string_ends_with(path, ".gguf.inp") || string_ends_with(path, ".gguf.out")) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filepath = repo + "_" + path;
// to make sure we don't have any slashes in the filename
string_replace_all(filepath, "/", "_");
// to make sure we don't have any quotes in the filename
string_replace_all(filepath, "'", "_");
filepath = fs_get_cache_file(filepath);
files.push_back({endpoint + repo + "/resolve/main/" + path, filepath});
}
}
}
if (!files.empty()) {
bool downloaded = false;
// download in fixed-size batches of 6 to bound concurrent transfers
const size_t batch_size = 6;
// ceiling division: number of batches needed to cover every file
size_t batches = (files.size() + batch_size - 1) / batch_size;
for (size_t i = 0; i < batches; i++) {
size_t batch_pos = (i * batch_size);
size_t batch_step = batch_pos + batch_size;
auto batch_begin = files.begin() + batch_pos;
// clamp the final (possibly partial) batch to the end of the list
auto batch_end = batch_step >= files.size() ? files.end() : files.begin() + batch_step;
std::vector<std::pair<std::string, std::string>> batch(batch_begin, batch_end);
// stop at the first batch that fails to download
if (!(downloaded = common_download_file_multiple(batch, {}, false))) {
break;
}
}
if (downloaded) {
std::string dir_sep(1, DIRECTORY_SEPARATOR);
// run the tokenizer test for every .gguf that has both .inp and .out companions
for (auto const & item : files) {
std::string filepath = item.second;
if (string_ends_with(filepath, ".gguf")) {
std::string vocab_inp = filepath + ".inp";
std::string vocab_out = filepath + ".out";
auto matching_inp = std::find_if(files.begin(), files.end(), [&vocab_inp](const auto & p) {
return p.second == vocab_inp;
});
auto matching_out = std::find_if(files.begin(), files.end(), [&vocab_out](const auto & p) {
return p.second == vocab_out;
});
if (matching_inp != files.end() && matching_out != files.end()) {
// NDEBUG is #undef'd at the top of the file, so this assert always fires on failure
std::string test_command = "." + dir_sep + "test-tokenizer-0 '" + filepath + "'";
assert(std::system(test_command.c_str()) == 0);
} else {
printf("test-tokenizers-remote: %s found without .inp/out vocab files, skipping...\n", filepath.c_str());
}
}
}
} else {
printf("test-tokenizers-remote: failed to download files, unable to perform tests...\n");
}
}
} else {
printf("test-tokenizers-remote: failed to retrieve repository info, unable to perform tests...\n");
}
} else {
printf("test-tokenizers-remote: no curl, unable to perform tests...\n");
}
}