vulkan: support L2_NORM with contiguous rows (#19604)

This commit is contained in:
Jeff Bolz
2026-02-13 21:42:04 -08:00
committed by GitHub
parent 53aef25a88
commit dbb023336b
3 changed files with 31 additions and 16 deletions
+12 -4
View File
@@ -5821,20 +5821,27 @@ struct test_l2_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
bool v;
std::string vars() override {
return VARS_TO_STR2(type, ne);
return VARS_TO_STR4(type, ne, eps, v);
}
test_l2_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 64, 320, 1},
float eps = 1e-12f)
: type(type), ne(ne), eps(eps) {}
float eps = 1e-12f,
bool v = false)
: type(type), ne(ne), eps(eps), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
}
ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
ggml_set_name(out, "out");
@@ -7596,7 +7603,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
}
}