mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-18 14:18:50 -04:00
Add ggml_roll
(ggml/1274)
* ggml : add ggml_roll * use set/get_op_params & std::min
This commit is contained in:
@@ -955,6 +955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"UPSCALE",
|
||||
"PAD",
|
||||
"PAD_REFLECT_1D",
|
||||
"ROLL",
|
||||
"ARANGE",
|
||||
"TIMESTEP_EMBEDDING",
|
||||
"ARGSORT",
|
||||
@@ -985,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"upscale(x)",
|
||||
"pad(x)",
|
||||
"pad_reflect_1d(x)",
|
||||
"roll(x)",
|
||||
"arange(start, stop, step)",
|
||||
"timestep_embedding(timesteps, dim, max_period)",
|
||||
"argsort(x)",
|
||||
@@ -1080,7 +1082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -4341,6 +4343,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_roll
|
||||
|
||||
struct ggml_tensor * ggml_roll(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int shift0,
|
||||
int shift1,
|
||||
int shift2,
|
||||
int shift3) {
|
||||
GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
|
||||
GGML_ASSERT(abs(shift0) < a->ne[0]);
|
||||
GGML_ASSERT(abs(shift1) < a->ne[1]);
|
||||
GGML_ASSERT(abs(shift2) < a->ne[2]);
|
||||
GGML_ASSERT(abs(shift3) < a->ne[3]);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, shift0);
|
||||
ggml_set_op_params_i32(result, 1, shift1);
|
||||
ggml_set_op_params_i32(result, 2, shift2);
|
||||
ggml_set_op_params_i32(result, 3, shift3);
|
||||
|
||||
result->op = GGML_OP_ROLL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_arange
|
||||
|
||||
struct ggml_tensor * ggml_arange(
|
||||
|
Reference in New Issue
Block a user