Mirror of https://github.com/ggml-org/llama.cpp.git
feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
src/llama-hparams.cpp

@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
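For context on what the per-layer getters enable: with the early return for non-recurrent layers, hybrid-model code can ask each layer for its recurrent state size and get 0 back for attention layers. The sketch below is a minimal, self-contained illustration of that pattern, not llama.cpp code; my_hparams, MAX_LAYERS, the layer layout, and the SSM dimensions are made-up assumptions, and the RWKV (wkv_head_size) branch is omitted for brevity.

    // Illustrative sketch only -- not the actual llama.cpp cache code.
    // It mimics the patched per-layer API: state sizes are 0 for non-recurrent layers.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t MAX_LAYERS = 8; // hypothetical stand-in for LLAMA_MAX_LAYERS

    struct my_hparams {                // hypothetical stand-in for llama_hparams
        uint32_t n_layer     = 8;
        uint32_t ssm_d_conv  = 4;
        uint32_t ssm_d_inner = 256;
        uint32_t ssm_d_state = 16;
        std::array<bool, MAX_LAYERS> recurrent_layer_arr{}; // true = recurrent (e.g. Mamba) layer

        bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

        // same shape as the patched n_embd_k_s()/n_embd_v_s(): bail out early
        // for attention layers so they contribute nothing to the recurrent state
        uint32_t n_embd_k_s(uint32_t il = 0) const {
            if (!recurrent_layer(il)) { return 0; }
            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
        }
        uint32_t n_embd_v_s(uint32_t il = 0) const {
            if (!recurrent_layer(il)) { return 0; }
            return ssm_d_state * ssm_d_inner;
        }
    };

    int main() {
        my_hparams hp;
        // hypothetical hybrid layout: every 4th layer is attention, the rest are recurrent
        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            hp.recurrent_layer_arr[il] = (il % 4 != 3);
        }

        // recurrent state can now be sized per layer instead of globally
        uint64_t total_k = 0, total_v = 0;
        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            total_k += hp.n_embd_k_s(il);
            total_v += hp.n_embd_v_s(il);
        }
        std::printf("total recurrent state: k=%llu v=%llu elements\n",
                    (unsigned long long) total_k, (unsigned long long) total_v);
        return 0;
    }

The commit itself only adds the two per-layer getters and the recurrent_layer() lookup; how callers size or build a hybrid cache on top of them is outside this change.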
src/llama-hparams.h

@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
 
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
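One detail in the header worth noting: the new il parameter on n_embd_k_s() and n_embd_v_s() defaults to 0, so existing call sites for purely recurrent or purely attention-based models keep compiling unchanged, while hybrid code paths can pass an explicit layer index. The sketch below shows the kind of per-layer dispatch the recurrent_layer_arr flags make possible; build_attention_layer, build_recurrent_layer, and the layer layout are hypothetical placeholders, not functions from llama.cpp.

    // Hypothetical per-layer dispatch; the two helpers below are only
    // placeholders for an attention path and a recurrent (SSM-style) path.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    static void build_attention_layer(uint32_t il) { std::printf("layer %u: attention\n", il); }
    static void build_recurrent_layer(uint32_t il) { std::printf("layer %u: recurrent\n", il); }

    int main() {
        constexpr uint32_t n_layer = 6;
        // hypothetical hybrid layout; in a real model this would come from the model's metadata
        std::array<bool, n_layer> recurrent_layer_arr = {true, true, false, true, true, false};

        for (uint32_t il = 0; il < n_layer; ++il) {
            if (recurrent_layer_arr[il]) {
                build_recurrent_layer(il);   // recurrent block keeps conv/SSM states
            } else {
                build_attention_layer(il);   // attention block uses the KV cache
            }
        }
        return 0;
    }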