ggml : use WARP_SIZE/2 for argmax reduction offset (#18092)

This commit is contained in:
Aadeshveer Singh
2025-12-17 09:17:01 +05:30
committed by GitHub
parent 2973a65ecb
commit 58062860af
+2 -2
View File
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
argmax = shared_argmax[lane_id];
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {