diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 1499eb08a..70a7114f3 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index b2bcb8b01..361459646 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
     // corresponds to Mamba's ssm_states size or RWKV's wkv_states size
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
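
Not part of the patch: a minimal, self-contained sketch of how the new per-layer accessors might be consumed when sizing state buffers for a hybrid model. The `toy_hparams` struct, its field values, and the loop below are hypothetical stand-ins rather than llama.cpp types; only the accessor logic mirrors the diff above, and only its Mamba-style branch (the RWKV `wkv_head_size` branch is omitted).

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_hparams, for illustration only.
struct toy_hparams {
    uint32_t n_layer = 4;
    std::vector<bool> recurrent_layer_arr{true, false, true, false}; // hybrid: SSM / attention / SSM / attention
    uint32_t ssm_d_conv  = 4;   // illustrative sizes, not real model values
    uint32_t ssm_d_inner = 8;
    uint32_t ssm_d_state = 16;

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // Simplified versions of the patched accessors (Mamba-style branch only):
    // non-recurrent (attention) layers contribute no recurrent-state storage.
    uint32_t n_embd_k_s(uint32_t il) const {
        return recurrent_layer(il) ? (ssm_d_conv - 1) * ssm_d_inner : 0;
    }
    uint32_t n_embd_v_s(uint32_t il) const {
        return recurrent_layer(il) ? ssm_d_state * ssm_d_inner : 0;
    }
};

int main() {
    toy_hparams hparams;
    // Per-layer sizing loop: attention layers report 0, so only the recurrent
    // layers would actually reserve conv/ssm state.
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        std::printf("layer %u: n_embd_k_s=%u n_embd_v_s=%u\n",
                il, hparams.n_embd_k_s(il), hparams.n_embd_v_s(il));
    }
    return 0;
}
```

The apparent benefit of returning 0 for non-recurrent layers is that a per-layer sizing loop like the one above can run uniformly over all layers of a hybrid model, allocating recurrent state only where a layer actually needs it.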