feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-05-09 15:04:36 -06:00
parent 5e2f2c3876
commit 05f1958080
2 changed files with 20 additions and 4 deletions

View File

@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}
uint32_t llama_hparams::n_embd_k_s() const {
uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
if (!recurrent_layer(il)) {
return 0;
}
if (wkv_head_size != 0) {
// for RWKV models
return token_shift_count * n_embd;
@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}
uint32_t llama_hparams::n_embd_v_s() const {
uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
if (!recurrent_layer(il)) {
return 0;
}
if (wkv_head_size != 0) {
// corresponds to RWKV's wkv_states size
return n_embd * wkv_head_size;
@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::recurrent_layer(uint32_t il) const {
return recurrent_layer_arr[il];
}
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return swa_layers[il];

View File

@ -115,6 +115,9 @@ struct llama_hparams {
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
// for hybrid state space models
// per-layer flag read by recurrent_layer(il); value-initialized so that every
// entry is a well-defined `false` until it is explicitly set
// NOTE(review): whichever code loads the model is presumably expected to set
// these flags for recurrent layers — confirm against the loader
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr = {};
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
@ -181,10 +184,13 @@ struct llama_hparams {
// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_k_s() const;
uint32_t n_embd_k_s(uint32_t il = 0) const;
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
uint32_t n_embd_v_s(uint32_t il = 0) const;
// whether or not the given layer is recurrent (for hybrid models)
bool recurrent_layer(uint32_t il) const;
bool is_swa(uint32_t il) const;
};