llama : support multiple classifier outputs and labels (#13940)

This commit is contained in:
Sigbjørn Skjæret
2025-06-06 09:03:25 +02:00
committed by GitHub
parent 1caae7fc6c
commit d17a809ef0
6 changed files with 101 additions and 24 deletions

View File

@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
uint32_t n_vocab = 0;
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
// for classifier models
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
if (!classifier_labels.empty()) {
hparams.n_cls_out = classifier_labels.size();
}
// arch-specific KVs
switch (arch) {
case LLM_ARCH_LLAMA:
@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
switch (hparams.n_layer) {
case 3:
@ -4362,6 +4367,15 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
if (!classifier_labels.empty()) {
LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
size_t i = 0;
for (auto label : classifier_labels) {
LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
}
}
}
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@ -13602,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
return model->hparams.n_swa;
}
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
return model->hparams.n_cls_out;
}
const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
if (i < model->classifier_labels.size()) {
return model->classifier_labels[i].c_str();
}
return nullptr;
}
// deprecated
int32_t llama_n_ctx_train(const llama_model * model) {
return llama_model_n_ctx_train(model);