diff --git a/include/llama.h b/include/llama.h
index 635508b10..168059cdc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -572,6 +572,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index de8d289cf..d20f2bf26 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -1816,3 +1816,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    //  the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 3e8a61da3..0c248f72d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -439,3 +439,6 @@
 const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid(const llm_arch & arch);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a5eb122f9..bac1a07c4 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -14377,14 +14377,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:      return true;
-        case LLM_ARCH_RWKV6:      return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        case LLM_ARCH_RWKV7:      return true;
-        case LLM_ARCH_ARWKV7:     return true;
-        default:                  return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
+}
+
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
 }
 
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
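
For context, a minimal caller-side sketch of how the new predicate composes with the existing `llama_model_is_recurrent` (this is not part of the patch; `model.gguf` is a placeholder path, and it assumes the current `llama_model_load_from_file` / `llama_model_free` loader API):

```c
#include "llama.h"

#include <stdio.h>

int main(void) {
    llama_backend_init();

    struct llama_model_params params = llama_model_default_params();

    // "model.gguf" is a placeholder path for illustration only
    struct llama_model * model = llama_model_load_from_file("model.gguf", params);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // the two predicates are independent: llm_arch_is_hybrid() is kept
    // separate from llm_arch_is_recurrent(), so a hybrid arch would not
    // automatically report as recurrent and callers may need to check both
    if (llama_model_is_hybrid(model)) {
        printf("hybrid-recurrent model (e.g. Jamba, Bamba)\n");
    } else if (llama_model_is_recurrent(model)) {
        printf("purely recurrent model (e.g. Mamba, RWKV)\n");
    } else {
        printf("attention-only model\n");
    }

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

Keeping the arch-level checks in `llama-arch.cpp` means future cache-selection code inside the library can branch on `llm_arch_is_hybrid(arch)` directly, without needing a `llama_model` instance.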