feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
src/llama-hparams.cpp

@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
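To illustrate the per-layer gating these hunks introduce, here is a minimal standalone sketch. The struct below is a pared-down stand-in for llama_hparams (not llama.cpp's actual code, and the layer pattern is invented for the example): it sums recurrent state sizes across a hybrid stack, where attention layers contribute zero and only recurrent layers reserve SSM cache space.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t MAX_LAYERS = 8; // stand-in for LLAMA_MAX_LAYERS

struct hparams_sketch {
    uint32_t n_layer     = 4;
    uint32_t ssm_d_conv  = 4;
    uint32_t ssm_d_inner = 64;
    uint32_t ssm_d_state = 16;

    // true = recurrent (SSM) layer, false = attention layer
    std::array<bool, MAX_LAYERS> recurrent_layer_arr = {true, false, true, false};

    bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

    // per-layer rolling (conv) state size; zero for attention layers
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer(il)) return 0;
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // per-layer recurrent (SSM) state size; zero for attention layers
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer(il)) return 0;
        return ssm_d_state * ssm_d_inner;
    }
};

int main() {
    hparams_sketch hp;
    uint64_t total = 0;
    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        total += hp.n_embd_k_s(il) + hp.n_embd_v_s(il);
        std::printf("layer %u: k_s=%u v_s=%u\n", il, hp.n_embd_k_s(il), hp.n_embd_v_s(il));
    }
    std::printf("total recurrent state elements: %llu\n", (unsigned long long) total);
    return 0;
}

A cache allocator iterating this way sizes per-layer buffers exactly, instead of assuming every layer carries recurrent state.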
src/llama-hparams.h

@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
 
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
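Note the default argument uint32_t il = 0 on the new declarations: existing single-architecture call sites such as n_embd_k_s() keep compiling unchanged and resolve to layer 0. The sketch below shows how a loader might fill recurrent_layer_arr for a hybrid model; the every-4th-layer-attention pattern is purely illustrative and not taken from this commit, where the array would be populated from model metadata.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t MAX_LAYERS = 16; // stand-in for LLAMA_MAX_LAYERS

int main() {
    // Hypothetical pattern: every 4th layer is attention, the rest recurrent.
    // A real loader would derive this from the model's hyperparameters instead.
    std::array<bool, MAX_LAYERS> recurrent_layer_arr{};
    const uint32_t n_layer = 12;
    for (uint32_t il = 0; il < n_layer; ++il) {
        recurrent_layer_arr[il] = (il % 4 != 3); // layers 3, 7, 11 use attention
    }
    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %2u: %s\n", il,
                    recurrent_layer_arr[il] ? "recurrent" : "attention");
    }
    return 0;
}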