vulkan/cuda: Fix im2col when KW!=KH (#14789)

The tid is decomposed into "ow + ky*OW + kx*OW*KH". Change "ksize" to match.
This commit is contained in:
Jeff Bolz
2025-07-21 06:35:40 -05:00
committed by GitHub
parent c82d48ec23
commit c2e058f1b4
3 changed files with 4 additions and 5 deletions

View File

@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
return;
}
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t ksize = OW * KH;
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;