mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
ggml : add repeat impl for i64
This commit is contained in:
@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
|
@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_repeat_i64(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_can_repeat(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
// guaranteed to be an integer due to the check in ggml_can_repeat
|
||||
const int nr0 = (int)(ne0/ne00);
|
||||
const int nr1 = (int)(ne1/ne01);
|
||||
const int nr2 = (int)(ne2/ne02);
|
||||
const int nr3 = (int)(ne3/ne03);
|
||||
|
||||
// TODO: support for transposed / permuted tensors
|
||||
GGML_ASSERT(nb0 == sizeof(int64_t));
|
||||
GGML_ASSERT(nb00 == sizeof(int64_t));
|
||||
|
||||
// TODO: maybe this is not optimal?
|
||||
for (int i3 = 0; i3 < nr3; i3++) {
|
||||
for (int k3 = 0; k3 < ne03; k3++) {
|
||||
for (int i2 = 0; i2 < nr2; i2++) {
|
||||
for (int k2 = 0; k2 < ne02; k2++) {
|
||||
for (int i1 = 0; i1 < nr1; i1++) {
|
||||
for (int k1 = 0; k1 < ne01; k1++) {
|
||||
for (int i0 = 0; i0 < nr0; i0++) {
|
||||
int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
|
||||
int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
y[i] = x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_repeat(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
|
||||
{
|
||||
ggml_compute_forward_repeat_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_I64:
|
||||
{
|
||||
ggml_compute_forward_repeat_i64(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
Reference in New Issue
Block a user