feat: Add llama_model_is_hybrid API call

Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-05-09 15:21:29 -06:00
parent c46503014d
commit ec8fe17b1a
4 changed files with 33 additions and 8 deletions

View File

@ -572,6 +572,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,

View File

@ -1816,3 +1816,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
switch (arch) {
default:
return false;
}
}

View File

@ -439,3 +439,6 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch& arch);
bool llm_arch_is_hybrid(const llm_arch& arch);

View File

@ -14377,14 +14377,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
}
bool llama_model_is_recurrent(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA: return true;
case LLM_ARCH_RWKV6: return true;
case LLM_ARCH_RWKV6QWEN2: return true;
case LLM_ARCH_RWKV7: return true;
case LLM_ARCH_ARWKV7: return true;
default: return false;
}
return llm_arch_is_recurrent(model->arch);
}
bool llama_model_is_hybrid(const llama_model * model) {
return llm_arch_is_hybrid(model->arch);
}
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {