llama : allow custom list of swa_layers (#13726)

This commit is contained in:
Xuan-Son Nguyen
2025-05-23 17:07:04 +02:00
committed by GitHub
parent 9ecf3e66a3
commit 8a2afb7520
3 changed files with 54 additions and 23 deletions

View File

@ -2,6 +2,26 @@
#include "ggml.h" #include "ggml.h"
llama_hparams::llama_hparams() {
swa_layers.fill(false);
}
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
}
}
bool llama_hparams::is_swa_any() const {
for (uint32_t il = 0; il < n_layer; ++il) {
if (swa_layers[il]) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_head(uint32_t il) const { uint32_t llama_hparams::n_head(uint32_t il) const {
if (il < n_layer) { if (il < n_layer) {
return n_head_arr[il]; return n_head_arr[il];
@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
bool llama_hparams::is_swa(uint32_t il) const { bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) { if (il < n_layer) {
return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1)); return swa_layers[il];
} }
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");

View File

@ -102,20 +102,12 @@ struct llama_hparams {
// Sliding Window Attention (SWA) // Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0;
uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA) // if swa_layers[il] == true, then layer il is SWA
// by default n == 1, all layers are dense // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// note that if n_swa_pattern == 0, all layers are SWA // by default, all layers are dense
// example: n_swa_pattern = 3 std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// il == 0: swa
// il == 1: swa
// il == 2: dense
// il == 3: swa
// il == 4: swa
// il == 5: dense
// il == 6: swa
// etc ...
// for State Space Models // for State Space Models
uint32_t ssm_d_conv = 0; uint32_t ssm_d_conv = 0;
@ -153,6 +145,25 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
llama_hparams();
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
// note that if n_pattern == 0, all layers are SWA
// if n_pattern == 1, all layers are dense
// example: n_pattern = 3
// il == 0: swa
// il == 1: swa
// il == 2: dense
// il == 3: swa
// il == 4: swa
// il == 5: dense
// il == 6: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern);
// return true if one of the layers is SWA
bool is_swa_any() const;
uint32_t n_head(uint32_t il = 0) const; uint32_t n_head(uint32_t il = 0) const;
uint32_t n_head_kv(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const;

View File

@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
switch (hparams.n_expert) { switch (hparams.n_expert) {
case 16: type = LLM_TYPE_17B_16E; break; case 16: type = LLM_TYPE_17B_16E; break;
@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_swa = 0; hparams.n_swa = 0;
hparams.n_swa_pattern = 1; hparams.set_swa_pattern(1);
} }
} break; } break;
case LLM_ARCH_PHIMOE: case LLM_ARCH_PHIMOE:
@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
{ {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096; // default value of gemma 2 hparams.n_swa = 4096; // default value of gemma 2
hparams.n_swa_pattern = 2; hparams.set_swa_pattern(2);
hparams.attn_soft_cap = true; hparams.attn_soft_cap = true;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3:
{ {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 6; hparams.set_swa_pattern(6);
hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f; hparams.rope_freq_scale_train_swa = 1.0f;
@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_COHERE2: case LLM_ARCH_COHERE2:
{ {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa_pattern = 4; hparams.set_swa_pattern(4);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.n_swa_pattern != 1); GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa( res = new llama_kv_cache_unified_iswa(
*this, *this,
@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_batch, cparams.n_batch,
padding); padding);
} else { } else {
GGML_ASSERT(hparams.n_swa_pattern == 1); GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified( res = new llama_kv_cache_unified(
*this, *this,