mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 03:55:20 +00:00
llama : add support for BertForSequenceClassification reranker (#13858)
* convert: add support for BertForSequenceClassification * add support for reranking using BertForSequenceClassification * merge checks of eos and sep * fix lint --------- Co-authored-by: dinhhuy <huy.dinh@brains-tech.co.jp>
This commit is contained in:
@ -903,13 +903,16 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
if (!has_eos && !has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
} else if (!has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
|
@ -3682,7 +3682,7 @@ class InternLM3Model(TextModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
|
||||
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification")
|
||||
class BertModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
@ -3745,6 +3745,13 @@ class BertModel(TextModel):
|
||||
if name.startswith("cls.seq_relationship"):
|
||||
return []
|
||||
|
||||
# For BertForSequenceClassification (direct projection layer)
|
||||
if name == "classifier.weight":
|
||||
name = "classifier.out_proj.weight"
|
||||
|
||||
if name == "classifier.bias":
|
||||
name = "classifier.out_proj.bias"
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def _xlmroberta_tokenizer_init(self) -> None:
|
||||
|
@ -1562,20 +1562,25 @@ void llm_graph_context::build_pooling(
|
||||
ggml_tensor * inp_cls = build_inp_cls();
|
||||
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
||||
|
||||
// classification head
|
||||
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||
GGML_ASSERT(cls != nullptr);
|
||||
GGML_ASSERT(cls_b != nullptr);
|
||||
if (cls != nullptr && cls_b != nullptr) {
|
||||
// classification head
|
||||
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
||||
cur = ggml_tanh(ctx0, cur);
|
||||
|
||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
||||
cur = ggml_tanh(ctx0, cur);
|
||||
|
||||
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
||||
if (cls_out) {
|
||||
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
||||
if (cls_out) {
|
||||
GGML_ASSERT(cls_out_b != nullptr);
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
||||
}
|
||||
} else if (cls_out) {
|
||||
// Single layer classification head (direct projection)
|
||||
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
||||
GGML_ASSERT(cls_out_b != nullptr);
|
||||
|
||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
|
||||
} else {
|
||||
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
|
@ -264,13 +264,19 @@ static size_t validate_utf8(const std::string& text) {
|
||||
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
|
||||
llama_tokens result;
|
||||
|
||||
// Get EOS token - use SEP token as fallback if EOS is not available
|
||||
llama_token eos_token = llama_vocab_eos(vocab);
|
||||
if (eos_token == LLAMA_TOKEN_NULL) {
|
||||
eos_token = llama_vocab_sep(vocab);
|
||||
}
|
||||
|
||||
result.reserve(doc.size() + query.size() + 4);
|
||||
result.push_back(llama_vocab_bos(vocab));
|
||||
result.insert(result.end(), query.begin(), query.end());
|
||||
result.push_back(llama_vocab_eos(vocab));
|
||||
result.push_back(eos_token);
|
||||
result.push_back(llama_vocab_sep(vocab));
|
||||
result.insert(result.end(), doc.begin(), doc.end());
|
||||
result.push_back(llama_vocab_eos(vocab));
|
||||
result.push_back(eos_token);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
Reference in New Issue
Block a user