llama : allow custom list of swa_layers (#13726)
@@ -2,6 +2,26 @@
 
 #include "ggml.h"
 
+llama_hparams::llama_hparams() {
+    swa_layers.fill(false);
+}
+
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];
@@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
+        return swa_layers[il];
     }
 
     GGML_ABORT("fatal error");
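Taken together, the two hunks above move the SWA pattern arithmetic out of is_swa() and into a precomputed per-layer flag: set_swa_pattern() fills swa_layers once, and is_swa(il) becomes a plain lookup. Below is a small standalone sketch (plain C++ with a simplified stand-in struct, not the real llama_hparams) that reproduces the formula and checks it against the expression the old is_swa() evaluated on every call:

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// simplified stand-in for llama_hparams, for illustration only
struct hparams_sketch {
    uint32_t n_layer = 12;
    std::array<bool, 12> swa_layers{};

    // same formula as llama_hparams::set_swa_pattern in the hunk above
    void set_swa_pattern(uint32_t n_pattern) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
        }
    }
};

int main() {
    const uint32_t n_pattern = 3;

    hparams_sketch hp;
    hp.set_swa_pattern(n_pattern);

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        // the expression the old is_swa() evaluated on every call
        const bool old_is_swa = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));

        // the precomputed flag the new is_swa() returns
        assert(hp.swa_layers[il] == old_is_swa);

        std::printf("il == %2u: %s\n", il, hp.swa_layers[il] ? "swa" : "dense");
    }

    return 0;
}

For n_pattern = 3 this prints the layout documented in the header hunk further down: two SWA layers followed by one dense layer, repeating.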
@@ -102,20 +102,12 @@ struct llama_hparams {
 
     // Sliding Window Attention (SWA)
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
-                                // by default n == 1, all layers are dense
-                                // note that if n_swa_pattern == 0, all layers are SWA
-                                // example: n_swa_pattern = 3
-                                //   il == 0: swa
-                                //   il == 1: swa
-                                //   il == 2: dense
-                                //   il == 3: swa
-                                //   il == 4: swa
-                                //   il == 5: dense
-                                //   il == 6: swa
-                                //   etc ...
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
@@ -153,6 +145,25 @@ struct llama_hparams {
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    llama_hparams();
+
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    // if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
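The header now stores an explicit per-layer flag array (swa_layers) instead of the single n_swa_pattern value, which is what allows a custom, non-periodic list of SWA layers. The patch itself only adds the periodic helper set_swa_pattern(); the sketch below is a hypothetical illustration (set_swa_layers, custom_swa, and the simplified types are not part of this commit) of how a loader could fill the array from an arbitrary list of layer indices instead:

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr size_t LLAMA_MAX_LAYERS_SKETCH = 512; // stand-in for LLAMA_MAX_LAYERS

// simplified stand-in for llama_hparams, for illustration only
struct hparams_sketch {
    uint32_t n_layer = 0;
    std::array<bool, LLAMA_MAX_LAYERS_SKETCH> swa_layers{};
};

// hypothetical helper: mark exactly the listed layers as SWA, all others as dense
void set_swa_layers(hparams_sketch & hp, const std::vector<uint32_t> & custom_swa) {
    hp.swa_layers.fill(false);
    for (uint32_t il : custom_swa) {
        if (il < hp.n_layer) {
            hp.swa_layers[il] = true;
        }
    }
}

int main() {
    hparams_sketch hp;
    hp.n_layer = 8;

    // hypothetical irregular layout: only layers 1, 2, 5 and 6 use SWA
    set_swa_layers(hp, {1, 2, 5, 6});

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        std::printf("il == %u: %s\n", il, hp.swa_layers[il] ? "swa" : "dense");
    }
    return 0;
}

Because is_swa(il) is now just a lookup into swa_layers, such an irregular layout needs no additional pattern arithmetic on the attention side.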
@@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
                 hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
+                hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
 
                 switch (hparams.n_expert) {
                     case 16: type = LLM_TYPE_17B_16E; break;
@@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
 
                     hparams.n_swa = 0;
-                    hparams.n_swa_pattern = 1;
+                    hparams.set_swa_pattern(1);
                 }
             } break;
         case LLM_ARCH_PHIMOE:
@@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
-                hparams.n_swa_pattern = 2;
+                hparams.set_swa_pattern(2);
                 hparams.attn_soft_cap = true;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA3:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 6;
+                hparams.set_swa_pattern(6);
 
                 hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_COHERE2:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 4;
+                hparams.set_swa_pattern(4);
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
         LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
-        LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
+        LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
         LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
         LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
         LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
             if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                GGML_ASSERT(hparams.n_swa_pattern != 1);
+                GGML_ASSERT(hparams.is_swa_any());
 
                 res = new llama_kv_cache_unified_iswa(
                         *this,
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.n_batch,
                         padding);
             } else {
-                GGML_ASSERT(hparams.n_swa_pattern == 1);
+                GGML_ASSERT(!hparams.is_swa_any());
 
                 res = new llama_kv_cache_unified(
                         *this,
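The two GGML_ASSERT changes above encode the invariant in llama_model::create_memory after this patch: the SWA-aware cache (llama_kv_cache_unified_iswa) is only built when at least one layer is actually marked SWA, and the plain unified cache only when none is. A minimal sketch of that gating, using stand-in enums rather than the real llama.cpp types:

#include <cassert>
#include <cstdio>

// simplified stand-ins for the real llama.cpp types, for illustration only
enum swa_type_sketch { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };
enum cache_kind { CACHE_UNIFIED, CACHE_UNIFIED_ISWA };

// mirrors the branching shown in the two hunks above: the SWA-aware cache is
// chosen iff the model declares a SWA type, and in that case at least one
// layer must actually be SWA (is_swa_any)
cache_kind choose_cache(swa_type_sketch swa_type, bool is_swa_any) {
    if (swa_type != SWA_NONE) {
        assert(is_swa_any);  // was: hparams.n_swa_pattern != 1
        return CACHE_UNIFIED_ISWA;
    }
    assert(!is_swa_any);     // was: hparams.n_swa_pattern == 1
    return CACHE_UNIFIED;
}

int main() {
    std::printf("%d\n", choose_cache(SWA_STANDARD, true)); // CACHE_UNIFIED_ISWA
    std::printf("%d\n", choose_cache(SWA_NONE, false));    // CACHE_UNIFIED
    return 0;
}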