CUDA: Factor out and re-use block_reduce function (#18785)
* CUDA: Refactor and expose two_stage_warp_reduce_* function * Use `two_stage_warp_reduce` also in softmax kernel, move smem out of it Moving smem out of `__device__` function to `__global__` function allows for explicit smem reuse, as either compiler or cuda rt seem to not free it afterwards (`cudaFuncSetAttribute` fails when not accounting for it once for each call to two_stage_warp_reduce) * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Aman Gupta <amangupta052@gmail.com> * Use two_stage_warp_reduce in group_norm_f32 * Use two_stage_warp_reduce in rms_norm_f32 * Fix smem calculation which expects bytes * Make `two_stage_warp_reduce` accept all values warp_reduce accepts Also integrate it into norm_f32 function * Use two_stage_warp_reduce in l2_norm_f32 * Use type traits for block reduction for better legibility Also adresss other requests by @am17an such as variable renaming * Make norm tests cover all cuda paths * Mark columns % WARP_SIZE !=0 as supported for RMS_NORM_BACK Unit-tests passed locally, let's see if they pass in the CI as well * Use `enum class` for `block_reduce_method` This is more type-safe than plain enum * Rename variables as suggested in code review by @am17an * Rename two_stage_warp_reduce -> block_reduce * Fix trailing whitespace in common.cuh * Make condition of static_assert type-dependent This delays evaluation until the template is actually instantiated. Otherwise, some compilers may evaluate the assert when parsing the template, resulting in build errors as observed here: https://github.com/ggml-org/llama.cpp/actions/runs/20960323123/job/60235530068?pr=18785 * Inline definitions --------- Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
+17
-16
@@ -7482,25 +7482,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
|
||||
test_cases.emplace_back(new test_silu_back());
|
||||
|
||||
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
||||
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
|
||||
for (uint32_t n : { 64, 1025 }) {
|
||||
for (bool v : { false, true }) {
|
||||
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||
}
|
||||
|
||||
// in-place tests
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
|
||||
|
||||
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
|
||||
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
|
||||
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
|
||||
for (uint32_t n : { 64, 1025 }) {
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
}
|
||||
}
|
||||
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
|
||||
for (bool multi_add : {false, true}) {
|
||||
@@ -7524,9 +7528,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||
|
||||
for (int64_t d_conv : {3, 4, 9}) {
|
||||
for (int64_t d_inner: {1024, 1536, 2048}) {
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
|
||||
|
||||
Reference in New Issue
Block a user