ggml : add repeat impl for i64

This commit is contained in:
Georgi Gerganov
2025-06-21 09:07:25 +03:00
parent f2cd962fe2
commit 695b6b7025
2 changed files with 52 additions and 0 deletions

View File

@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
} else if (type == GGML_TYPE_F32) { } else if (type == GGML_TYPE_F32) {
v = *(float *) &data[i]; v = *(float *) &data[i];
} else if (type == GGML_TYPE_I64) {
v = (float) *(int64_t *) &data[i];
} else if (type == GGML_TYPE_I32) { } else if (type == GGML_TYPE_I32) {
v = (float) *(int32_t *) &data[i]; v = (float) *(int32_t *) &data[i];
} else if (type == GGML_TYPE_I16) { } else if (type == GGML_TYPE_I16) {

View File

@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
} }
} }
static void ggml_compute_forward_repeat_i64(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
GGML_ASSERT(ggml_can_repeat(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS
// guaranteed to be an integer due to the check in ggml_can_repeat
const int nr0 = (int)(ne0/ne00);
const int nr1 = (int)(ne1/ne01);
const int nr2 = (int)(ne2/ne02);
const int nr3 = (int)(ne3/ne03);
// TODO: support for transposed / permuted tensors
GGML_ASSERT(nb0 == sizeof(int64_t));
GGML_ASSERT(nb00 == sizeof(int64_t));
// TODO: maybe this is not optimal?
for (int i3 = 0; i3 < nr3; i3++) {
for (int k3 = 0; k3 < ne03; k3++) {
for (int i2 = 0; i2 < nr2; i2++) {
for (int k2 = 0; k2 < ne02; k2++) {
for (int i1 = 0; i1 < nr1; i1++) {
for (int k1 = 0; k1 < ne01; k1++) {
for (int i0 = 0; i0 < nr0; i0++) {
int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
for (int i = 0; i < ne00; ++i) {
y[i] = x[i];
}
}
}
}
}
}
}
}
}
void ggml_compute_forward_repeat( void ggml_compute_forward_repeat(
const ggml_compute_params * params, const ggml_compute_params * params,
ggml_tensor * dst) { ggml_tensor * dst) {
@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
{ {
ggml_compute_forward_repeat_f32(params, dst); ggml_compute_forward_repeat_f32(params, dst);
} break; } break;
case GGML_TYPE_I64:
{
ggml_compute_forward_repeat_i64(params, dst);
} break;
default: default:
{ {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");