llama : add option to override model tensor buffers (#11397)

* llama : add option to override tensor buffers

* ggml : fix possible underflow in ggml_nbytes
This commit is contained in:
Diego Devesa
2025-04-02 14:52:01 +02:00
committed by GitHub
parent a10b36c91a
commit e0e912f49b
12 changed files with 108 additions and 9 deletions

View File

@ -17,6 +17,7 @@
#include <cmath>
#include <functional>
#include <map>
#include <regex>
#include <sstream>
#include <stdexcept>
@ -378,9 +379,12 @@ struct llama_model::impl {
layer_dev dev_input = {};
layer_dev dev_output = {};
std::vector<layer_dev> dev_layer;
bool has_tensor_overrides;
};
llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
}
llama_model::~llama_model() {}
@ -1571,9 +1575,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
}
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
ggml_backend_buffer_type_t buft = nullptr;
// check overrides
if (ml.tensor_buft_overrides) {
std::string tensor_name = tn.str();
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
std::regex pattern(overrides->pattern);
if (std::regex_search(tensor_name, pattern)) {
LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
buft = overrides->buft;
break;
}
}
}
if (!buft) {
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
buft = select_weight_buft(hparams, t_meta, op, *buft_list);
if (!buft) {
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
}
}
// avoid using a host buffer when using mmap
@ -4151,6 +4172,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
});
}
bool llama_model::has_tensor_overrides() const {
return pimpl->has_tensor_overrides;
}
const ggml_tensor * llama_model::get_tensor(const char * name) const {
auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
[name](const std::pair<std::string, ggml_tensor *> & it) {
@ -12319,6 +12344,7 @@ llm_graph_result_ptr llama_model::build_graph(
llama_model_params llama_model_default_params() {
llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,