CUDA: better coalesce data-access for contiguous concat (#22330)

Also, distribute all elements across CTAs evenly instead of launching one CTA per dim.
Oliver Simons
2026-04-26 09:21:45 +02:00
committed by GitHub
parent 0c6ee1cade
commit b1a5bd4e0c
+56 -73
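The core of the change is visible in the kernel below: instead of a 3D grid with one CTA row per (i1, i2) slice, the rewritten kernel flattens dst to a single 1D index and walks it with a grid-stride loop, so consecutive threads touch consecutive addresses (coalesced) and work is spread evenly across CTAs. For readers unfamiliar with the pattern, a minimal sketch using a hypothetical copy kernel (not part of this commit):

#include <cstdint>

// Grid-stride loop: each thread starts at its global index and advances by
// the total number of threads in the grid, so any grid size covers all n
// elements and no CTA sits idle while another walks a long row.
static __global__ void copy_f32(const float * src, float * dst, const int64_t n) {
    for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += (int64_t) blockDim.x * gridDim.x) {
        dst[i] = src[i];
    }
}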
@@ -1,96 +1,79 @@
#include "concat.cuh" #include "concat.cuh"
// contiguous kernels // contiguous kernels
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { template <int dim>
int nidx = threadIdx.x + blockIdx.x * blockDim.x; static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x,
if (nidx >= ne0) { const float * y,
return; float * dst,
} int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]");
int offset_dst = const int64_t n = ne0 * ne1 * ne2;
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00) { // src0 for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
int offset_src = if constexpr (dim == 0) {
nidx + const int64_t row = i / ne0;
blockIdx.y * ne00 + const int64_t i0 = i - row * ne0;
blockIdx.z * ne00 * gridDim.y;
dst[offset_dst] = x[offset_src]; if (i0 < ne00) {
dst[i] = x[row * ne00 + i0];
} else { } else {
int offset_src = dst[i] = y[row * (ne0 - ne00) + (i0 - ne00)];
(nidx - ne00) +
blockIdx.y * (ne0 - ne00) +
blockIdx.z * (ne0 - ne00) * gridDim.y;
dst[offset_dst] = y[offset_src];
}
} }
} else if constexpr (dim == 1) {
const int64_t dst_plane = ne0 * ne1;
const int64_t src0_plane = ne0 * ne01;
const int64_t src1_plane = dst_plane - src0_plane;
const int64_t i2 = i / dst_plane;
const int64_t i01 = i - i2 * dst_plane;
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { if (i01 < src0_plane) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x; dst[i] = x[i2 * src0_plane + i01];
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.y < (unsigned)ne01) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * ne01;
dst[offset_dst] = x[offset_src];
} else { } else {
int offset_src = dst[i] = y[i2 * src1_plane + (i01 - src0_plane)];
nidx +
(blockIdx.y - ne01) * ne0 +
blockIdx.z * ne0 * (gridDim.y - ne01);
dst[offset_dst] = y[offset_src];
} }
}
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.z < (unsigned)ne02) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
} else { } else {
int offset_src = const int64_t src0_size = ne0 * ne1 * ne02;
nidx +
blockIdx.y * ne0 + if (i < src0_size) {
(blockIdx.z - ne02) * ne0 * gridDim.y; dst[i] = x[i];
dst[offset_dst] = y[offset_src]; } else {
dst[i] = y[i - src0_size];
}
}
} }
} }
static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { static void concat_f32_cuda(const float * x,
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; const float * y,
dim3 gridDim(num_blocks, ne1, ne2); float * dst,
int64_t ne00,
int64_t ne01,
int64_t ne02,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int dim,
cudaStream_t stream) {
const int64_t n = ne0 * ne1 * ne2;
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
if (dim == 0) { if (dim == 0) {
concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00); concat_f32_cont<0>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return; return;
} }
if (dim == 1) { if (dim == 1) {
concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01); concat_f32_cont<1>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return; return;
} }
concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02); concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
} }
// non-contiguous kernel (slow) // non-contiguous kernel (slow)
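As a sanity check on the dim == 0 index math above, here is a hypothetical standalone host program (not part of the commit) that prints which source element each flat dst index maps to. Consecutive i land on consecutive addresses in dst and, within each row segment, in x and y, which is what makes the accesses coalesce:

#include <cstdint>
#include <cstdio>

int main() {
    // src0 rows are ne00 = 4 floats wide, src1 rows are ne0 - ne00 = 2 wide
    const int64_t ne00 = 4, ne0 = 6, nrows = 3;

    for (int64_t i = 0; i < ne0 * nrows; ++i) {
        const int64_t row = i / ne0;       // which dst row the flat index falls in
        const int64_t i0  = i - row * ne0; // position within that row
        if (i0 < ne00) {
            printf("dst[%2lld] = x[%2lld]\n", (long long) i, (long long) (row * ne00 + i0));
        } else {
            printf("dst[%2lld] = y[%2lld]\n", (long long) i, (long long) (row * (ne0 - ne00) + (i0 - ne00)));
        }
    }
    return 0;
}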