Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-07 17:24:18 -04:00)
feat: Support hybrid recurrent in llama-graph
NOTE: I intentionally did not add support for s_mask since it will be going away soon.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
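At a high level, this commit adds the graph-input plumbing for hybrid models that mix attention layers (backed by the unified KV cache) with recurrent layers (backed by the recurrent cache): a new llm_graph_input_attn_kv_hybrid_recurrent input class, a build_attn_inp_kv_hybrid_recurrent() builder, a build_attn() overload that forwards to the unified-KV path, and an optional recurrent-state argument for build_inp_s_copy(). A rough sketch of how a hybrid model's graph builder might consume the new API follows; the builder struct name and the get_state_recurrent() accessor are illustrative assumptions, not part of this diff.

    // Illustrative sketch only (not from this commit): a hybrid model's graph build.
    struct llm_build_hybrid_example : public llm_graph_context {
        llm_build_hybrid_example(const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
            // one shared attention-side input; asserts SWA is off and sizes the KQ mask
            // from the hybrid cache's attention sub-state
            auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

            // recurrent layers still need the s_copy input; with a hybrid cache the
            // recurrent sub-state is passed explicitly instead of casting mstate inside
            const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
            ggml_tensor * s_copy = build_inp_s_copy(kv_state->get_state_recurrent()); // accessor assumed

            // attention layers then go through the new build_attn overload, e.g.:
            //   cur = build_attn(inp_attn, gf, wo, wo_b, q_cur, k_cur, v_cur,
            //                    nullptr, nullptr, kq_scale, il);
        }
    };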
src/llama-graph.cpp

@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -403,6 +404,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 //
 // llm_graph_context
 //
@@ -961,8 +969,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
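Note the default argument: purely recurrent models keep the old behavior (mstate is cast internally), while a hybrid model can hand in the recurrent sub-state of its hybrid cache. A sketch of the two call patterns, where get_state_recurrent() is an assumed accessor on the hybrid state and not shown in this hunk:

    // pure recurrent model: unchanged call, kv_state is derived from mstate
    ggml_tensor * s_copy = build_inp_s_copy();

    // hybrid model: pass the recurrent sub-state of the hybrid cache explicitly
    const auto * hyb = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
    ggml_tensor * s_copy_hybrid = build_inp_s_copy(hyb->get_state_recurrent()); // accessor assumed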
@@ -1291,6 +1301,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
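Design note: because llm_graph_input_attn_kv_hybrid_recurrent derives from llm_graph_input_attn_kv_unified and is constructed against the attention sub-state (get_state_attn()), the new build_attn() overload only needs to upcast the input and delegate, so attention layers in a hybrid model reuse the unified-KV attention path unchanged. A per-layer call might look like the following (weight names and scale are illustrative, not taken from this diff):

    // il is the layer index; 1.0f/sqrtf(float(n_embd_head)) is the usual attention scale
    cur = build_attn(inp_attn, gf,
            model.layers[il].wo, model.layers[il].bo,
            q_cur, k_cur, v_cur, nullptr, nullptr,
            1.0f/sqrtf(float(n_embd_head)), il);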
src/llama-graph.h

@@ -22,6 +22,7 @@ struct llama_memory_state_i;
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
 class llama_kv_cache_recurrent_state;
+class llama_kv_cache_hybrid_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -242,7 +243,7 @@ public:
         cparams(cparams),
         kv_state(kv_state) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    virtual ~llm_graph_input_attn_kv_unified() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
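The destructor of llm_graph_input_attn_kv_unified becomes virtual because the class now has a subclass (declared in the next hunk); presumably this keeps destruction through a base-class pointer well defined, e.g.:

    // sketch: the derived hybrid input can be owned and destroyed via the base type
    std::unique_ptr<llm_graph_input_attn_kv_unified> inp =
        std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
    // ~llm_graph_input_attn_kv_hybrid_recurrent() runs when inp goes out of scope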
@@ -285,6 +286,16 @@ public:
     const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
+class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
+public:
+    llm_graph_input_attn_kv_hybrid_recurrent(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_hybrid_recurrent_state * kv_state);
+
+    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
+};
+
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -508,7 +519,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -574,6 +585,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_hybrid_recurrent * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(