mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-06-30 20:58:45 +00:00
feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
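As a hedged illustration of that motivation (not part of this commit; `llama_hparams`, `recurrent_layer_arr`, and `n_layer` are stand-in names for the actual hparams fields), the arch-level helper can be called while the hparams are still being filled in, before any llama_model object exists:

// Sketch only: no llama_model exists yet at this point, so the check has to
// go through the architecture enum rather than llama_model_is_recurrent().
static void mark_recurrent_layers(llama_hparams & hparams, llm_arch arch) {
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        hparams.recurrent_layer_arr[il] = llm_arch_is_recurrent(arch);
    }
}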
@@ -572,6 +572,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -1816,3 +1816,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    //       the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
@@ -439,3 +439,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch& arch);
+bool llm_arch_is_hybrid(const llm_arch& arch);
@@ -14377,14 +14377,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:      return true;
-        case LLM_ARCH_RWKV6:      return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        case LLM_ARCH_RWKV7:      return true;
-        case LLM_ARCH_ARWKV7:     return true;
-        default:                  return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
 }
 
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
+
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
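For completeness, a minimal usage sketch of the two public queries added to llama.h above. This is not from the commit; it assumes the standard llama.cpp loading API and a GGUF model path on the command line.

// Usage sketch (assumptions noted above): load a model and print the two
// capability flags exposed by llama_model_is_recurrent / llama_model_is_hybrid.
#include <cstdio>
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params params = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], params);
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    printf("recurrent: %s\n", llama_model_is_recurrent(model) ? "yes" : "no");
    printf("hybrid:    %s\n", llama_model_is_hybrid(model)    ? "yes" : "no");

    llama_model_free(model);
    llama_backend_free();
    return 0;
}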