mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-18 05:56:00 -04:00
llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support * ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states * llama : support running Mamba-Codestral-7B-v0.1 * llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted * llama : remove unused variable * llama : add missing break * convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. * llama : avoid redundant state copy for Mamba 1 and 2 * metal : attempt to adapt SSM_SCAN for Mamba-2 * metal : fix SSM_SCAN pipeline scope * metal : use log and exp instead of log1pf and expf in SSM_SCAN * metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. * metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. * metal : fix SSM_SCAN state head offset * metal : fix wrong number of tokens per sequence in SSM_SCAN * ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. * ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks * convert : fix flake8 lint * metal : fix confusion between ; and , * metal : add missing args for nb references in ssm_scan_f32_group * metal : single-user mamba2 inference works * kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. * convert : avoid AutoConfig for Mamba and Mamba2 hparams * kv-cache : allow context shift for recurrent models * graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. * ggml : fix mamba2 ssm scan when compiled with SVE * ggml-cpu : reorder SVE FMA for consistency with other SIMD arches * cuda : implement ssm scan for Mamba2 There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2 * mamba : fix mismatched new and delete size for llm_build_mamba Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This otherwise would cause problems when runnning Mamba-(1|2) inference when compiled -DGGML_SANITIZE_ADDRESS=ON * cuda : graceful fallback for Mamba-1 models with weird embd size
This commit is contained in:
@@ -4837,7 +4837,6 @@ struct ggml_tensor * ggml_ssm_conv(
|
||||
const int64_t n_s = sx->ne[2];
|
||||
|
||||
// TODO: maybe support other strides than 1?
|
||||
// FIXME: this is always true?
|
||||
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
|
||||
GGML_ASSERT(sx->ne[1] == d_inner);
|
||||
GGML_ASSERT(n_t >= 0);
|
||||
@@ -4860,36 +4859,49 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||
struct ggml_tensor * dt,
|
||||
struct ggml_tensor * A,
|
||||
struct ggml_tensor * B,
|
||||
struct ggml_tensor * C) {
|
||||
struct ggml_tensor * C,
|
||||
struct ggml_tensor * ids) {
|
||||
GGML_ASSERT(ggml_is_contiguous(s));
|
||||
GGML_ASSERT(ggml_is_contiguous(x));
|
||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||
GGML_ASSERT(ggml_is_contiguous(A));
|
||||
GGML_ASSERT(ggml_is_matrix(A));
|
||||
GGML_ASSERT(ggml_is_3d(B));
|
||||
GGML_ASSERT(ggml_is_3d(s));
|
||||
GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
|
||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
||||
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||
GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
|
||||
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
|
||||
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
|
||||
GGML_ASSERT(ggml_are_same_shape(B, C));
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
{
|
||||
const int64_t d_state = s->ne[0];
|
||||
const int64_t d_inner = s->ne[1];
|
||||
const int64_t n_seq_tokens = x->ne[1];
|
||||
const int64_t n_seqs = x->ne[2];
|
||||
const int64_t head_dim = x->ne[0];
|
||||
const int64_t n_head = x->ne[1];
|
||||
const int64_t n_seq_tokens = x->ne[2];
|
||||
const int64_t n_seqs = x->ne[3];
|
||||
|
||||
GGML_ASSERT(s->ne[2] == n_seqs);
|
||||
GGML_ASSERT(x->ne[0] == d_inner);
|
||||
GGML_ASSERT(A->ne[0] == d_state);
|
||||
GGML_ASSERT(A->ne[1] == d_inner);
|
||||
GGML_ASSERT(dt->ne[0] == n_head);
|
||||
GGML_ASSERT(dt->ne[1] == n_seq_tokens);
|
||||
GGML_ASSERT(dt->ne[2] == n_seqs);
|
||||
GGML_ASSERT(ggml_is_3d(dt));
|
||||
GGML_ASSERT(s->ne[1] == head_dim);
|
||||
GGML_ASSERT(s->ne[2] == n_head);
|
||||
GGML_ASSERT(B->ne[0] == d_state);
|
||||
GGML_ASSERT(B->ne[1] == n_seq_tokens);
|
||||
GGML_ASSERT(B->ne[2] == n_seqs);
|
||||
GGML_ASSERT(B->ne[2] == n_seq_tokens);
|
||||
GGML_ASSERT(B->ne[3] == n_seqs);
|
||||
GGML_ASSERT(ids->ne[0] == n_seqs);
|
||||
GGML_ASSERT(ggml_is_vector(ids));
|
||||
GGML_ASSERT(A->ne[1] == n_head);
|
||||
GGML_ASSERT(ggml_is_matrix(A));
|
||||
|
||||
if (A->ne[0] != 1) {
|
||||
// Mamba-1 has more granular decay factors
|
||||
GGML_ASSERT(A->ne[0] == d_state);
|
||||
}
|
||||
}
|
||||
|
||||
// concatenated y + ssm_states
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
|
||||
|
||||
result->op = GGML_OP_SSM_SCAN;
|
||||
result->src[0] = s;
|
||||
@@ -4898,6 +4910,7 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||
result->src[3] = A;
|
||||
result->src[4] = B;
|
||||
result->src[5] = C;
|
||||
result->src[6] = ids;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
Reference in New Issue
Block a user