diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 998d8723d..c8b548940 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4185,15 +4185,22 @@ class XLMRobertaModel(BertModel):
             if self._position_offset is not None:
                 data_torch = data_torch[self._position_offset:,:]
 
-        if name.endswith(".lora_A"):
-            # TODO: convert loras
-            return []
+        if name.endswith(".weight.0.lora_A") or name.endswith(".weight.0.lora_B"):
+            if name.startswith("pooler.dense"):
+                return
 
-        if name.endswith(".lora_B"):
-            # TODO: convert loras
-            return []
+            lora_name = self.hparams["lora_adaptations"]
+            num_loras = data_torch.size(0)
+            assert num_loras == len(lora_name)
 
-        return super().modify_tensors(data_torch, name, bid)
+            # Split out each LoRA in their own named tensors
+            # Remove "weight" from the name to not confuse quantize
+            for i in range(num_loras):
+                data_lora = data_torch[i, :, :]
+                yield (self.map_tensor_name(name[:-16]) + name[-16:].lower().replace("weight.0.", f"<{lora_name[i]}>"), data_lora)
+            return
+
+        yield from super().modify_tensors(data_torch, name, bid)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -4201,6 +4208,13 @@
         # jina-embeddings-v3
         if rotary_emb_base := self.hparams.get("rotary_emb_base"):
            self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+        if lora_alpha := self.hparams.get("lora_alpha"):
+            self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha)
+        if lora_names := self.hparams.get("lora_adaptations"):
+            self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_NAMES, lora_names)
+        if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+            assert lora_names and all(lora_name in lora_prompt_prefixes for lora_name in lora_names)
+            self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_PROMPT_PREFIXES, [lora_prompt_prefixes[lora_name] for lora_name in lora_names])
 
 
 @ModelBase.register("GemmaForCausalLM")
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 6d3d596d6..5b62dc3ca 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -227,8 +227,10 @@ class Keys:
         MIDDLE_ID            = "tokenizer.ggml.middle_token_id"
 
     class Adapter:
-        TYPE       = "adapter.type"
-        LORA_ALPHA = "adapter.lora.alpha"
+        TYPE                 = "adapter.type"
+        LORA_ALPHA           = "adapter.lora.alpha"
+        LORA_NAMES           = "adapter.lora.names"
+        LORA_PROMPT_PREFIXES = "adapter.lora.prompt_prefixes"
 
     class Clip:
         PROJECTOR_TYPE = "clip.projector_type"
diff --git a/src/llama-adapter.h b/src/llama-adapter.h
index 65824e972..1439ebdeb 100644
--- a/src/llama-adapter.h
+++ b/src/llama-adapter.h
@@ -66,6 +66,7 @@ struct llama_adapter_lora {
     std::vector<ggml_backend_buffer_ptr> bufs;
 
     float alpha;
+    std::string prompt_prefix;
 
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index cbec31954..223aea961 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -217,8 +217,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID,           "tokenizer.ggml.fim_rep_token_id"  },
     { LLM_KV_TOKENIZER_FIM_SEP_ID,           "tokenizer.ggml.fim_sep_token_id"  },
 
-    { LLM_KV_ADAPTER_TYPE,       "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE,                 "adapter.type"                 },
+    { LLM_KV_ADAPTER_LORA_ALPHA,           "adapter.lora.alpha"           },
+    { LLM_KV_ADAPTER_LORA_NAMES,           "adapter.lora.names"           },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, "adapter.lora.prompt_prefixes" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 89608ca63..657a3365a 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -215,6 +215,8 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+    LLM_KV_ADAPTER_LORA_NAMES,
+    LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e15a6cb83..6949f697a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1720,6 +1720,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
     ggml_backend_buffer_type_t first_moved_to_buft   = nullptr;
 
+    auto add_lora_tensors = [&](const std::string & lora_name, const std::string & tensor_name) -> void {
+        std::string base_name = tensor_name.substr(0, tensor_name.size() - 6);
+
+        ggml_tensor * lora_a = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_a").c_str());
+        ggml_tensor * lora_b = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_b").c_str());
+        loras[lora_name]->ab_map[tensor_name] = llama_adapter_lora_weight(lora_a, lora_b);
+
+        ml.n_created += 2;
+    };
+
     auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
         ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
 
@@ -2246,6 +2256,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_NOMIC_BERT_MOE:
             case LLM_ARCH_JINA_BERT_V3:
                 {
+                    std::vector<std::string> lora_names;
+
                     tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
                     type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
 
@@ -2262,6 +2274,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
                     tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
 
+                    if (arch == LLM_ARCH_JINA_BERT_V3) {
+                        float lora_alpha = 1.0f;
+                        std::vector<std::string> lora_prompt_prefixes;
+
+                        ml.get_key(LLM_KV_ADAPTER_LORA_ALPHA, lora_alpha, false);
+                        ml.get_arr(LLM_KV_ADAPTER_LORA_NAMES, lora_names, false);
+                        ml.get_arr(LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, lora_prompt_prefixes, false);
+                        GGML_ASSERT(lora_names.size() == lora_prompt_prefixes.size());
+
+                        for (size_t i = 0; i < lora_names.size(); ++i) {
+                            llama_adapter_lora * adapter = new llama_adapter_lora();
+                            std::string lora_name = lora_names[i];
+
+                            adapter->alpha = lora_alpha;
+                            adapter->prompt_prefix = lora_prompt_prefixes[i];
+                            loras[lora_name] = adapter;
+
+                            add_lora_tensors(lora_name, tok_embd->name);
+
+                            if (type_embd) {
+                                add_lora_tensors(lora_name, type_embd->name);
+                            }
+                        }
+                    }
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -2300,6 +2337,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             }
                         }
 
+                        if (arch == LLM_ARCH_JINA_BERT_V3) {
+                            GGML_ASSERT(layer.wqkv != nullptr);
+
+                            for (const auto & lora_name : lora_names) {
+                                add_lora_tensors(lora_name, layer.wqkv->name);
+                                add_lora_tensors(lora_name, layer.wo->name);
+                                add_lora_tensors(lora_name, layer.ffn_up->name);
+                                add_lora_tensors(lora_name, layer.ffn_down->name);
+                            }
+                        }
+
                         layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
                     }
diff --git a/src/llama-model.h b/src/llama-model.h
index 9a7d8727b..21f217aaf 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -7,6 +7,7 @@
 #include "llama-memory.h"
 #include "llama-vocab.h"
 
+#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -383,6 +384,9 @@ struct llama_model {
 
     llama_model_params params;
 
+    // built-in LoRAs
+    std::map<std::string, llama_adapter_lora *> loras;
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
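
Reviewer note, not part of the patch: below is a minimal sketch for sanity-checking a converted jina-embeddings-v3 GGUF against this change. It assumes gguf-py's GGUFReader is available and takes the GGUF path as a hypothetical command-line argument; the key strings mirror the new Keys.Adapter constants, everything else is illustrative only.

# inspect_adapter_loras.py - illustration only, not shipped with this change
import sys

from gguf import GGUFReader

reader = GGUFReader(sys.argv[1])

# scalar + string-array KV pairs written by XLMRobertaModel.set_gguf_parameters()
for key in ("adapter.lora.alpha", "adapter.lora.names", "adapter.lora.prompt_prefixes"):
    field = reader.get_field(key)
    if field is None:
        print(f"{key}: <missing>")
    elif key == "adapter.lora.alpha":
        # scalar field: field.data holds the index of the single value part
        print(f"{key}: {float(field.parts[field.data[0]][0])}")
    else:
        # string-array field: each index in field.data points at one UTF-8 string part
        print(f"{key}: {[bytes(field.parts[i]).decode('utf-8') for i in field.data]}")

# the split LoRA tensors embed the task name in angle brackets, which is what
# add_lora_tensors() reconstructs on the C++ side: base_name + "<" + lora_name + ">lora_a"/"lora_b"
for tensor in reader.tensors:
    if "<" in tensor.name:
        print(tensor.name, tensor.shape)

If the metadata and tensor names line up, llama_model::load_tensors() should resolve every per-task lora_a/lora_b pair via ml.get_tensor_meta(), and the ml.n_created += 2 bookkeeping keeps the loader's used-tensor accounting consistent.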