mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
llama : support multiple classifier outputs and labels (#13940)
This commit is contained in:
@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
uint32_t n_vocab = 0;
|
||||
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
|
||||
|
||||
// for classifier models
|
||||
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
|
||||
if (!classifier_labels.empty()) {
|
||||
hparams.n_cls_out = classifier_labels.size();
|
||||
}
|
||||
|
||||
// arch-specific KVs
|
||||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
@ -4362,6 +4367,15 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
||||
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
|
||||
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
|
||||
|
||||
if (!classifier_labels.empty()) {
|
||||
LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
|
||||
|
||||
size_t i = 0;
|
||||
for (auto label : classifier_labels) {
|
||||
LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
|
||||
@ -13602,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
|
||||
return model->hparams.n_swa;
|
||||
}
|
||||
|
||||
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
|
||||
return model->hparams.n_cls_out;
|
||||
}
|
||||
|
||||
const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
|
||||
if (i < model->classifier_labels.size()) {
|
||||
return model->classifier_labels[i].c_str();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_n_ctx_train(const llama_model * model) {
|
||||
return llama_model_n_ctx_train(model);
|
||||
|
Reference in New Issue
Block a user