CUDA: Factor out and re-use block_reduce function (#18785)

* CUDA: Refactor and expose two_stage_warp_reduce_* function

* Use `two_stage_warp_reduce` also in softmax kernel, move smem out of it

Moving smem out of `__device__` function to `__global__` function
allows for explicit smem reuse, as either compiler or cuda rt seem to not
free it afterwards (`cudaFuncSetAttribute` fails when not accounting for
it once for each call to two_stage_warp_reduce)

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* Use two_stage_warp_reduce in group_norm_f32

* Use two_stage_warp_reduce in rms_norm_f32

* Fix smem calculation which expects bytes

* Make `two_stage_warp_reduce` accept all values warp_reduce accepts

Also integrate it into norm_f32 function

* Use two_stage_warp_reduce in l2_norm_f32

* Use type traits for block reduction for better legibility

Also adresss other requests by @am17an such as variable renaming

* Make norm tests cover all cuda paths

* Mark columns % WARP_SIZE !=0 as supported for RMS_NORM_BACK

Unit-tests passed locally, let's see if they pass in the CI as well

* Use `enum class` for `block_reduce_method`

This is more type-safe than plain enum

* Rename variables as suggested in code review by @am17an

* Rename two_stage_warp_reduce -> block_reduce

* Fix trailing whitespace in common.cuh

* Make condition of static_assert type-dependent

This delays evaluation until the template is actually instantiated.
Otherwise, some compilers may evaluate the assert when parsing the
template, resulting in build errors as observed here:

https://github.com/ggml-org/llama.cpp/actions/runs/20960323123/job/60235530068?pr=18785

* Inline definitions

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Oliver Simons
2026-01-15 03:44:54 +01:00
committed by GitHub
parent d98b548120
commit 36f0132464
6 changed files with 125 additions and 191 deletions
+17 -16
View File
@@ -7482,25 +7482,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
for (uint32_t n : { 64, 1025 }) {
for (bool v : { false, true }) {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
// in-place tests
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
for (uint32_t n : { 64, 1025 }) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
}
}
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
for (bool multi_add : {false, true}) {
@@ -7524,9 +7528,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));