diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml
index 0efe87716..52624a46d 100644
--- a/.github/workflows/build-self-hosted.yml
+++ b/.github/workflows/build-self-hosted.yml
@@ -97,6 +97,36 @@ jobs:
           vulkaninfo --summary
           GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
 
+  # TODO: investigate slight precision issues in some operations for test-backend-ops on the WebGPU backend.
+  #ggml-ci-nvidia-webgpu:
+  #  runs-on: [self-hosted, Linux, NVIDIA]
+
+  #  steps:
+  #    - name: Clone
+  #      id: checkout
+  #      uses: actions/checkout@v6
+
+  #    - name: Dawn Dependency
+  #      id: dawn-depends
+  #      run: |
+  #        DAWN_VERSION="v20260317.182325"
+  #        DAWN_OWNER="google"
+  #        DAWN_REPO="dawn"
+  #        DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release"
+  #        echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
+  #        curl -L -o artifact.tar.gz \
+  #          "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
+  #        mkdir dawn
+  #        tar -xvf artifact.tar.gz -C dawn --strip-components=1
+
+  #    - name: Test
+  #      id: ggml-ci
+  #      run: |
+  #        GG_BUILD_WEBGPU=1 \
+  #        GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
+  #        GG_BUILD_WEBGPU_DAWN_DIR="$GITHUB_WORKSPACE/dawn/lib64/cmake/Dawn" \
+  #        bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
+
   # TODO: provision AMX-compatible machine
   #ggml-ci-cpu-amx:
   #  runs-on: [self-hosted, Linux, CPU, AMX]
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 3de6258c7..7d9a4403f 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key {
     bool       has_mask;
     bool       has_sinks;
     bool       uses_logit_softcap;
-    bool       use_vec;
 
     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
                kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
+               uses_logit_softcap == other.uses_logit_softcap;
     }
 };
 
@@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
-        ggml_webgpu_hash_combine(seed, key.use_vec);
         return seed;
     }
 };
 
-struct ggml_webgpu_flash_attn_shader_lib_context {
-    ggml_webgpu_flash_attn_pipeline_key key;
-    uint32_t                            sg_mat_m;
-    uint32_t                            sg_mat_n;
-    uint32_t                            sg_mat_k;
-    size_t                              wg_mem_limit_bytes;
-    uint32_t                            max_subgroup_size;
-};
-
-struct ggml_webgpu_flash_attn_shader_decisions {
+struct ggml_webgpu_flash_attn_decisions {
     uint32_t q_tile  = 0;
     uint32_t kv_tile = 0;
     uint32_t wg_size = 0;
 };
 
-inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
-    // Keep conservative defaults unless this is the f16 vec-split shape family.
-    if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
-        return 1u;
-    }
+struct ggml_webgpu_flash_attn_vec_decisions {
+    uint32_t kv_tile = 0;
+    uint32_t wg_size = 0;
+};
 
-    // Head-dim specializations used by the tuned vec f16 path.
-    switch (key.head_dim_qk) {
-        case 64:
-            return 2u;
-        case 96:
-            return 4u;
-        case 128:
-            return 1u;
-        case 192:
-            return 2u;
-        case 576:
-            return 2u;
-        default:
-            return 1u;
-    }
+inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
+    const ggml_webgpu_shader_lib_context & context) {
+    const bool has_mask  = context.src3 != nullptr;
+    const bool has_sinks = context.src4 != nullptr;
+    const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
+                           (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+    ggml_webgpu_flash_attn_pipeline_key key = {};
+    key.kv_type            = context.src1->type;
+    key.head_dim_qk        = (uint32_t) context.src0->ne[0];
+    key.head_dim_v         = (uint32_t) context.src2->ne[0];
+    key.kv_direct          = kv_direct;
+    key.has_mask           = has_mask;
+    key.has_sinks          = has_sinks;
+    key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
+    return key;
 }
 
 struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
@@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh
     return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
 }
 
-struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
-    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
-    uint32_t                                       max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
-    pre_wgsl::Preprocessor &                                     preprocessor,
-    const char *                                                 shader_src,
-    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "flash_attn_vec_reduce";
-
-    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
-    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-    variant += std::string("_wg") + std::to_string(context.max_wg_size);
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl    = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    return result;
-}
-
 struct ggml_webgpu_flash_attn_blk_pipeline_key {
-    uint32_t q_tile;
     uint32_t kv_tile;
 
-    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
-        return q_tile == other.q_tile && kv_tile == other.kv_tile;
-    }
+    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
 };
 
 struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
     size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
         size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.q_tile);
         ggml_webgpu_hash_combine(seed, key.kv_tile);
         return seed;
     }
 };
 
-struct ggml_webgpu_flash_attn_blk_shader_lib_context {
-    ggml_webgpu_flash_attn_blk_pipeline_key key;
-    uint32_t                                max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
-    pre_wgsl::Preprocessor &                              preprocessor,
-    const char *                                          shader_src,
-    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "flash_attn_vec_blk";
-
-    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
-    variant += std::string("_qt") + std::to_string(context.key.q_tile);
-
-    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
-    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
-
-    uint32_t wg_size = 1;
-    while ((wg_size << 1) <= context.max_wg_size) {
-        wg_size <<= 1;
-    }
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    variant += std::string("_wg") + std::to_string(wg_size);
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl    = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    return result;
-}
-
 // This is exposed because it's necessary in supports_op
 inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                   uint32_t kv_tile,
@@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
     return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
 }
 
+inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context &      context,
+                                                   const ggml_webgpu_flash_attn_pipeline_key & key) {
+    const size_t limit_bytes  = context.wg_mem_limit_bytes;
+    const size_t q_tile       = context.sg_mat_m;
+    const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+    size_t bytes_per_kv = 0;
+    if (!key.kv_direct) {
+        bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
+    }
+    if (key.has_mask) {
+        bytes_per_kv += q_tile;
+    }
+    bytes_per_kv += q_tile;
+    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
+    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+}
+
+inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
+    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
+    const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
+    uint32_t       kv_tile     = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
+    kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
+
+    if (key.kv_direct) {
+        kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
+        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+            kv_tile -= context.sg_mat_n;
+        }
+    }
+
+    return kv_tile;
+}
+
 /** Matrix Multiplication **/
 
 struct ggml_webgpu_legacy_mul_mat_pipeline_key {
@@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib {
         repeat_pipelines;  // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
+    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
+        flash_attn_vec_pipelines;
     std::unordered_map
     std::unordered_map
@@ -849,10 +816,9 @@
     }
 
     webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_row_norm_pipeline_key key = {
-            .op      = context.dst->op,
-            .inplace = context.inplace,
-        };
+        ggml_webgpu_row_norm_pipeline_key key = {};
+        key.op      = context.dst->op;
+        key.inplace = context.inplace;
 
         auto it = row_norm_pipelines.find(key);
         if (it != row_norm_pipelines.end()) {
@@ -908,9 +874,10 @@
     }
 
     webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
-                                                  .vec4     = context.src0->ne[0] % 4 == 0,
-                                                  .i64_idx  = context.src1->type == GGML_TYPE_I64 };
+        ggml_webgpu_set_rows_pipeline_key key = {};
+        key.dst_type = context.dst->type;
+        key.vec4     = context.src0->ne[0] % 4 == 0;
+        key.i64_idx  = context.src1->type == GGML_TYPE_I64;
 
         auto it = set_rows_pipelines.find(key);
         if (it != set_rows_pipelines.end()) {
@@ -955,7 +922,9 @@
     }
 
     webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
+        ggml_webgpu_set_pipeline_key key = {};
+        key.type    = context.dst->type;
+        key.inplace = context.inplace;
 
         auto it = set_pipelines.find(key);
         if (it != set_pipelines.end()) {
@@ -1062,10 +1031,9 @@
     webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
         const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
 
-        ggml_webgpu_get_rows_pipeline_key key = {
-            .src_type   = context.src0->type,
-            .vectorized = (int) vectorized,
-        };
+        ggml_webgpu_get_rows_pipeline_key key = {};
+        key.src_type   = context.src0->type;
+        key.vectorized = (int) vectorized;
 
         auto it = get_rows_pipelines.find(key);
         if (it != get_rows_pipelines.end()) {
@@ -1115,8 +1083,7 @@
         std::string type_upper = type_str;
         std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
 
-        switch (key.src_type)
-        {
+        switch (key.src_type) {
             case GGML_TYPE_Q4_0:
             case GGML_TYPE_Q5_0:
            case GGML_TYPE_Q8_0:
@@ -1136,9 +1103,9 @@
                     break;
                 }
             default:
-            {
-                defines.push_back(std::string("SRC_TYPE=") + type_str);
-            }
+                {
+                    defines.push_back(std::string("SRC_TYPE=") + type_str);
+                }
         }
 
         defines.push_back("BYTE_HELPERS");
@@ -1181,7 +1148,8 @@
     }
 
     webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
+        ggml_webgpu_scale_pipeline_key key = {};
+        key.inplace = context.inplace;
 
         auto it = scale_pipelines.find(key);
         if (it != scale_pipelines.end()) {
@@ -1208,11 +1176,10 @@
     }
 
     webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_solve_tri_pipeline_key key = {
-            .type = context.dst->type,
-            .n    = (int) context.src0->ne[0],
-            .k    = (int) context.src1->ne[0],
-        };
+        ggml_webgpu_solve_tri_pipeline_key key = {};
+        key.type = context.dst->type;
+        key.n    = (int) context.src0->ne[0];
+        key.k    = (int) context.src1->ne[0];
 
         auto it = solve_tri_pipelines.find(key);
         if (it != solve_tri_pipelines.end()) {
@@ -1250,10 +1217,9 @@
     }
 
     webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_ssm_conv_pipeline_key key = {
-            .type       = context.dst->type,
-            .vectorized = context.src1->ne[0] == 4,
-        };
+        ggml_webgpu_ssm_conv_pipeline_key key = {};
+        key.type       = context.dst->type;
+        key.vectorized = context.src1->ne[0] == 4;
 
         auto it = ssm_conv_pipelines.find(key);
         if (it != ssm_conv_pipelines.end()) {
@@ -1293,11 +1259,10 @@
     }
 
     webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_gated_delta_net_pipeline_key key = {
-            .type = context.dst->type,
-            .s_v  = (int) context.src2->ne[0],
-            .kda  = context.src3->ne[0] == context.src2->ne[0],
-        };
+        ggml_webgpu_gated_delta_net_pipeline_key key = {};
+        key.type = context.dst->type;
+        key.s_v  = (int) context.src2->ne[0];
+        key.kda  = context.src3->ne[0] == context.src2->ne[0];
 
         auto it = gated_delta_net_pipelines.find(key);
         if (it != gated_delta_net_pipelines.end()) {
@@ -1330,7 +1295,8 @@
     }
 
     webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
+        ggml_webgpu_pad_pipeline_key key = {};
+        key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
 
         auto it = pad_pipelines.find(key);
         if (it != pad_pipelines.end()) {
@@ -1357,15 +1323,13 @@
     }
 
     webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_mul_mat_vec_pipeline_key key = {
-            .src0_type = context.src0->type,
-            .src1_type = context.src1->type,
-            // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
-            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
-                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
-                              1 :
-                              0,
-        };
+        ggml_webgpu_mul_mat_vec_pipeline_key key = {};
+        key.src0_type  = context.src0->type;
+        key.src1_type  = context.src1->type;
+        key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
+                          (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                             1 :
+                             0;
 
         auto it = mul_mat_vec_pipelines.find(key);
         if (it != mul_mat_vec_pipelines.end()) {
@@ -1451,15 +1415,14 @@
     }
 
     webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_mul_mat_pipeline_key key = {
-            .src0_type = context.src0->type,
-            .src1_type = context.src1->type,
-            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
-                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
-                              1 :
-                              0,
-            .use_subgroup_matrix = context.supports_subgroup_matrix
-        };
+        ggml_webgpu_mul_mat_pipeline_key key = {};
+        key.src0_type  = context.src0->type;
+        key.src1_type  = context.src1->type;
+        key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
+                          (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                             1 :
+                             0;
+        key.use_subgroup_matrix = context.supports_subgroup_matrix;
 
         auto it = mul_mat_fast_pipelines.find(key);
         if (it != mul_mat_fast_pipelines.end()) {
@@ -1578,8 +1541,9 @@
     }
 
     webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
-                                                        .src1_type = context.src1->type };
+        ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
+        key.src0_type = context.src0->type;
+        key.src1_type = context.src1->type;
 
         auto it = mul_mat_legacy_pipelines.find(key);
         if (it != mul_mat_legacy_pipelines.end()) {
@@ -1621,8 +1585,7 @@
         std::string type_upper = src0_name;
         std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
 
-        switch (context.src0->type)
-        {
+        switch (context.src0->type) {
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q5_0:
            case GGML_TYPE_Q8_0:
@@ -1642,9 +1605,9 @@
                     break;
                 }
             default:
-            {
-                defines.push_back(std::string("SRC0_TYPE=") + src0_name);
-            }
+                {
+                    defines.push_back(std::string("SRC0_TYPE=") + src0_name);
+                }
         }
 
         defines.push_back("BYTE_HELPERS");
@@ -1689,10 +1652,9 @@
     }
 
     webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_mul_mat_id_pipeline_key key = {
-            .src0_type = context.src0->type,
-            .src1_type = context.src1->type,
-        };
+        ggml_webgpu_mul_mat_id_pipeline_key key = {};
+        key.src0_type = context.src0->type;
+        key.src1_type = context.src1->type;
 
         auto it = mul_mat_id_pipelines.find(key);
         if (it != mul_mat_id_pipelines.end()) {
@@ -1782,13 +1744,12 @@
     }
 
     webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
         const bool is_unary = context.dst->op == GGML_OP_UNARY;
         const int  op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
 
-        ggml_webgpu_unary_pipeline_key key = {
-            .type     = context.dst->type,
-            .op       = op,
-            .is_unary = is_unary,
-            .inplace  = context.inplace,
-            .ttype    = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
-        };
+        ggml_webgpu_unary_pipeline_key key = {};
+        key.type     = context.dst->type;
+        key.op       = op;
+        key.is_unary = is_unary;
+        key.inplace  = context.inplace;
+        key.ttype    = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
 
         auto it = unary_pipelines.find(key);
         if (it != unary_pipelines.end()) {
@@ -1853,13 +1814,12 @@
     }
 
     webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_binary_pipeline_key key = {
-            .type        = context.dst->type,
-            .op          = context.dst->op,
-            .inplace     = context.inplace,
-            .overlap     = context.overlap,
-            .src_overlap = context.src_overlap,
-        };
+        ggml_webgpu_binary_pipeline_key key = {};
+        key.type        = context.dst->type;
+        key.op          = context.dst->op;
+        key.inplace     = context.inplace;
+        key.overlap     = context.overlap;
+        key.src_overlap = context.src_overlap;
 
         auto it = binary_pipelines.find(key);
         if (it != binary_pipelines.end()) {
@@ -1908,9 +1868,8 @@
     }
 
     webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_concat_pipeline_key key = {
-            .type = context.dst->type,
-        };
+        ggml_webgpu_concat_pipeline_key key = {};
+        key.type = context.dst->type;
 
         auto it = concat_pipelines.find(key);
         if (it != concat_pipelines.end()) {
@@ -1945,9 +1904,8 @@
     }
 
     webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_repeat_pipeline_key key = {
-            .type = context.dst->type,
-        };
+        ggml_webgpu_repeat_pipeline_key key = {};
+        key.type = context.dst->type;
 
         auto it = repeat_pipelines.find(key);
         if (it != repeat_pipelines.end()) {
@@ -1985,16 +1943,16 @@
         return repeat_pipelines[key];
     }
 
-    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
-        auto it = flash_attn_pipelines.find(context.key);
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
+        auto it = flash_attn_pipelines.find(key);
         if (it != flash_attn_pipelines.end()) {
             return it->second;
         }
-
         std::vector<std::string> defines;
         std::string              variant = "flash_attn";
 
-        switch (context.key.kv_type) {
+        switch (key.kv_type) {
             case GGML_TYPE_F32:
                 defines.push_back("KV_F32");
                 break;
@@ -2010,111 +1968,206 @@
             default:
                 GGML_ABORT("Unsupported KV type for flash attention shader");
         }
-        variant += std::string("_") + ggml_type_name(context.key.kv_type);
+        variant += std::string("_") + ggml_type_name(key.kv_type);
 
-        if (context.key.has_mask) {
+        if (key.has_mask) {
             defines.push_back("MASK");
             variant += "_mask";
         }
-        if (context.key.has_sinks) {
+        if (key.has_sinks) {
             defines.push_back("SINKS");
             variant += "_sinks";
         }
-        if (context.key.uses_logit_softcap) {
+        if (key.uses_logit_softcap) {
             defines.push_back("LOGIT_SOFTCAP");
             variant += "_lgsc";
         }
-        if (context.key.kv_direct) {
+        if (key.kv_direct) {
             defines.push_back("KV_DIRECT");
             variant += "_kvdirect";
         }
-        if (context.key.has_mask && context.key.use_vec) {
-            defines.push_back("BLK");
-            variant += "_blk";
-        }
 
-        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
-        variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
 
-        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
-        variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
 
         defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
         defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
         defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
 
-        uint32_t q_tile  = context.sg_mat_m;
-        uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
-                                    context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-        if (context.key.use_vec) {
-            q_tile  = 1;
-            kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
-            kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
-            const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
-            defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
-        }
-        if (context.key.kv_direct) {
-            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+        auto decisions    = std::make_shared<ggml_webgpu_flash_attn_decisions>();
+        decisions->q_tile = context.sg_mat_m;
+
+        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
+        uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+
+        if (key.kv_direct) {
+            kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
             while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                 kv_tile -= context.sg_mat_n;
             }
         }
-        defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
-        defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+        decisions->kv_tile = kv_tile;
+        decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
 
-        uint32_t wg_size = 0;
-        if (context.key.use_vec) {
-            wg_size = std::max(1u, std::min(32u, context.max_subgroup_size));
-        } else {
-            wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
-        }
-        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
+        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
 
-        const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
         webgpu_pipeline pipeline =
-            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
-        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
-        decisions->q_tile  = q_tile;
-        decisions->kv_tile = kv_tile;
-        decisions->wg_size = wg_size;
-        pipeline.context   = decisions;
-        flash_attn_pipelines[context.key] = pipeline;
-        return flash_attn_pipelines[context.key];
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
+        pipeline.context          = decisions;
+        flash_attn_pipelines[key] = pipeline;
+        return flash_attn_pipelines[key];
     }
 
-    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
-        auto it = flash_attn_blk_pipelines.find(context.key);
+    webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
+        auto it = flash_attn_vec_pipelines.find(key);
+        if (it != flash_attn_vec_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "flash_attn_vec";
+
+        switch (key.kv_type) {
+            case GGML_TYPE_F32:
+                defines.push_back("KV_F32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("KV_F16");
+                break;
+            case GGML_TYPE_Q4_0:
+                defines.push_back("KV_Q4_0");
+                break;
+            case GGML_TYPE_Q8_0:
+                defines.push_back("KV_Q8_0");
+                break;
+            default:
+                GGML_ABORT("Unsupported KV type for flash attention shader");
+        }
+        variant += std::string("_") + ggml_type_name(key.kv_type);
+
+        if (key.has_mask) {
+            defines.push_back("MASK");
+            defines.push_back("BLK");
+            variant += "_mask_blk";
+        }
+        if (key.has_sinks) {
+            defines.push_back("SINKS");
+            variant += "_sinks";
+        }
+        if (key.uses_logit_softcap) {
+            defines.push_back("LOGIT_SOFTCAP");
+            variant += "_lgsc";
+        }
+        if (key.kv_direct) {
+            defines.push_back("KV_DIRECT");
+            variant += "_kvdirect";
+        }
+
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+
+        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+        defines.push_back("Q_TILE=1");
+
+        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
+        decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
+        decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size));
+        uint32_t vec_ne    = 1u;
+
+        // Keep conservative defaults unless this is the f16 vec-split shape family.
+        if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
+            switch (key.head_dim_qk) {
+                case 64:
+                case 192:
+                case 576:
+                    vec_ne = 2u;
+                    break;
+                case 96:
+                    vec_ne = 4u;
+                    break;
+                default:
+                    break;
+            }
+        }
+
+        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
+        defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
+
+        webgpu_pipeline pipeline =
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
+        pipeline.context              = decisions;
+        flash_attn_vec_pipelines[key] = pipeline;
+        return flash_attn_vec_pipelines[key];
+    }
+
+    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_flash_attn_blk_pipeline_key key = {};
+        key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
+        auto it     = flash_attn_blk_pipelines.find(key);
         if (it != flash_attn_blk_pipelines.end()) {
             return it->second;
         }
 
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
-        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
-        flash_attn_blk_pipelines[context.key] = pipeline;
-        return flash_attn_blk_pipelines[context.key];
+        std::vector<std::string> defines;
+        std::string              variant = "flash_attn_vec_blk";
+
+        defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
+        variant += std::string("_kvt") + std::to_string(key.kv_tile);
+
+        uint32_t wg_size = 1;
+        while ((wg_size << 1) <= context.max_wg_size) {
+            wg_size <<= 1;
+        }
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        variant += std::string("_wg") + std::to_string(wg_size);
+
+        webgpu_pipeline pipeline =
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
+        flash_attn_blk_pipelines[key] = pipeline;
+        return flash_attn_blk_pipelines[key];
     }
 
-    webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
-        const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
-        auto it = flash_attn_vec_reduce_pipelines.find(context.key);
+    webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
+        key.head_dim_v = (uint32_t) context.src2->ne[0];
+        key.wg_size    = context.max_wg_size;
+        auto it        = flash_attn_vec_reduce_pipelines.find(key);
         if (it != flash_attn_vec_reduce_pipelines.end()) {
             return it->second;
         }
 
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
-        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
-        flash_attn_vec_reduce_pipelines[context.key] = pipeline;
-        return flash_attn_vec_reduce_pipelines[context.key];
+        std::vector<std::string> defines;
+        std::string              variant = "flash_attn_vec_reduce";
+
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+        variant += std::string("_wg") + std::to_string(context.max_wg_size);
+
+        webgpu_pipeline pipeline =
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
+        flash_attn_vec_reduce_pipelines[key] = pipeline;
+        return flash_attn_vec_reduce_pipelines[key];
     }
 
     webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_cpy_pipeline_key key = {
-            .src_type = context.src0->type,
-            .dst_type = context.dst->type,
-        };
+        ggml_webgpu_cpy_pipeline_key key = {};
+        key.src_type = context.src0->type;
+        key.dst_type = context.dst->type;
 
         auto it = cpy_pipelines.find(key);
         if (it != cpy_pipelines.end()) {
@@ -2166,11 +2219,10 @@
     }
 
     webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_glu_pipeline_key key = {
-            .glu_op = ggml_get_glu_op(context.dst),
-            .type   = context.dst->type,
-            .split  = (context.src1 != nullptr),
-        };
+        ggml_webgpu_glu_pipeline_key key = {};
+        key.glu_op = ggml_get_glu_op(context.dst);
+        key.type   = context.dst->type;
+        key.split  = (context.src1 != nullptr);
 
         auto it = glu_pipelines.find(key);
         if (it != glu_pipelines.end()) {
@@ -2239,11 +2291,10 @@
     }
 
     webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_rope_pipeline_key key = {
-            .type    = context.dst->type,
-            .inplace = context.inplace,
-            .has_ff  = (context.src2 != nullptr),
-        };
+        ggml_webgpu_rope_pipeline_key key = {};
+        key.type    = context.dst->type;
+        key.inplace = context.inplace;
+        key.has_ff  = (context.src2 != nullptr);
 
         auto it = rope_pipelines.find(key);
         if (it != rope_pipelines.end()) {
@@ -2288,12 +2339,11 @@
     }
 
     webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        ggml_webgpu_soft_max_pipeline_key key = {
-            .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
-            .has_mask  = (context.src1 != nullptr),
-            .has_sink  = (context.src2 != nullptr),
-            .inplace   = context.inplace,
-        };
+        ggml_webgpu_soft_max_pipeline_key key = {};
+        key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
+        key.has_mask  = (context.src1 != nullptr);
+        key.has_sink  = (context.src2 != nullptr);
+        key.inplace   = context.inplace;
 
         auto it = soft_max_pipelines.find(key);
         if (it != soft_max_pipelines.end()) {
@@ -2359,25 +2409,6 @@
         pipeline_desc.layout = nullptr;  // nullptr means auto layout
         return { device.CreateComputePipeline(&pipeline_desc), label };
     }
-
-    static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
-        const size_t limit_bytes = context.wg_mem_limit_bytes;
-        const size_t q_tile      = context.sg_mat_m;
-        const size_t base_q_bytes =
-            (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
-            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
-        size_t bytes_per_kv = 0;
-        if (!context.key.kv_direct) {
-            bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
-        }
-        if (context.key.has_mask) {
-            bytes_per_kv += q_tile;
-        }
-        bytes_per_kv += q_tile;
-        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
-        const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
-        return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
-    }
 };
 
 #endif  // GGML_WEBGPU_SHADER_LIB_HPP
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 01637e2dd..e7bda817a 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -41,6 +41,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
     wg_x = CEIL_DIV(total_wg, wg_y);
 }
 
+static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
+    uint32_t bits;
+    memcpy(&bits, &value, sizeof(bits));
+    return bits;
+}
+
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
 #    define WEBGPU_DEBUG_BUF_ELEMS 512
@@ -369,6 +375,96 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
     buffer = device.CreateBuffer(&buffer_desc);
 }
 
+static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
+    return webgpu_tensor_offset(tensor) + tensor->view_offs;
+}
+
+static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
+    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
+    return ctx->buffer;
+}
+
+static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
+    size_t offset = ggml_webgpu_tensor_offset(t);
+    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
+}
+
+static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
+                                           const ggml_tensor *     Q,
+                                           const ggml_tensor *     K,
+                                           const ggml_tensor *     V) {
+    const size_t   alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+    const uint32_t k_offset_elems =
+        (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems =
+        (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
+    const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
+    const bool kv_vec_type_supported =
+        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+
+    return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+           (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
+}
+
+static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
+    size_t offset = ggml_webgpu_tensor_offset(t);
+    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
+}
+
+static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
+    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+}
+
+// Used to determine if two tensors are the same for in-place operations
+static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
+}
+
+// Used to determine if two tensors share the same buffer and their byte ranges overlap,
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
+struct binary_overlap_flags {
+    bool inplace;  // src0 == dst
+    bool overlap;  // src1 == dst
+    bool src_overlap;
+};
+
+static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
+                                                              ggml_tensor * src1,
+                                                              ggml_tensor * dst) {
+    binary_overlap_flags flags = {};
+    flags.inplace     = ggml_webgpu_tensor_equal(src0, dst);
+    flags.overlap     = ggml_webgpu_tensor_overlap(src1, dst);
+    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
+
+    return flags;
+}
+
+static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t     binding,
+                                                              wgpu::Buffer buffer,
+                                                              uint64_t     offset,
+                                                              uint64_t     size) {
+    wgpu::BindGroupEntry entry = {};
+    entry.binding = binding;
+    entry.buffer  = std::move(buffer);
+    entry.offset  = offset;
+    entry.size    = size;
+    return entry;
+}
+
+static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx,
+                                                                     uint32_t         binding,
+                                                                     ggml_tensor *    tensor) {
+    return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor),
+                                             ggml_webgpu_tensor_align_offset(ctx, tensor),
+                                             ggml_webgpu_tensor_binding_size(ctx, tensor));
+}
+
 /** End WebGPU object initializations */
 
 /** WebGPU Actions */
@@ -480,10 +576,8 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
         std::vector<wgpu::BindGroupEntry> entries = dispatch.bind_group_entries;
 
         uint32_t params_binding_num = entries.size();
-        entries.push_back({ .binding = params_binding_num,
-                            .buffer  = ctx->param_arena.buffer,
-                            .offset  = param_offset,
-                            .size    = ctx->param_arena.slot_size });
+        entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset,
+                                                            ctx->param_arena.slot_size));
 
         wgpu::BindGroupDescriptor bind_group_desc;
         bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0);
@@ -502,13 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context &
 #ifdef GGML_WEBGPU_GPU_PROFILE
     for (size_t i = 0; i < dispatches.size(); i++) {
         GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT);
-        const uint32_t query_begin = ctx->profile_timestamp_query_count++;
-        const uint32_t query_end   = ctx->profile_timestamp_query_count++;
-        wgpu::PassTimestampWrites ts_writes = { .querySet                  = ctx->profile_timestamp_query_set,
-                                                .beginningOfPassWriteIndex = query_begin,
-                                                .endOfPassWriteIndex       = query_end };
-        wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
-        wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc);
+        const uint32_t query_begin = ctx->profile_timestamp_query_count++;
+        const uint32_t query_end   = ctx->profile_timestamp_query_count++;
+
+        wgpu::PassTimestampWrites ts_writes = {};
+        ts_writes.querySet                  = ctx->profile_timestamp_query_set;
+        ts_writes.beginningOfPassWriteIndex = query_begin;
+        ts_writes.endOfPassWriteIndex       = query_end;
+        wgpu::ComputePassDescriptor pass_desc = {};
+        pass_desc.timestampWrites             = &ts_writes;
+
+        wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc);
 
         pass.SetPipeline(dispatches[i].pipeline.pipeline);
         pass.SetBindGroup(0, bind_groups[i]);
@@ -544,17 +642,19 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
                                               uint32_t                value,
                                               size_t                  offset,
                                               size_t                  size) {
-    std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
-    std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
-    };
-    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
-    uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
+    std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };
+    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
+    size_t                            bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
+    uint32_t                          wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
 
     ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
-    entries.push_back(
-        { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES });
+    wgpu::BindGroupEntry params_entry = {};
+    params_entry.binding = 1;
+    params_entry.buffer  = ctx->memset_params_buf;
+    params_entry.offset  = 0;
+    params_entry.size    = WEBGPU_PARAMS_BUF_SIZE_BYTES;
+    entries.push_back(params_entry);
 
     wgpu::BindGroupDescriptor bind_group_desc;
     bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0);
@@ -632,65 +732,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     delete backend;
 }
 
-static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
-    return webgpu_tensor_offset(tensor) + tensor->view_offs;
-}
-
-static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
-    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
-    return ctx->buffer;
-}
-
-static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
-    size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
-}
-
-static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
-    size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
-}
-
-static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
-    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
-}
-
-// Used to determine if two tensors are the same for in-place operations
-static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
-}
-
-// Used to determine if two tensors share the same buffer and their byte ranges overlap,
-static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
-           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
-}
-
-struct binary_overlap_flags {
-    bool inplace;  // src0 == dst
-    bool overlap;  // src1 == dst
-    bool src_overlap;
-};
-
-static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
-                                                              ggml_tensor * src1,
-                                                              ggml_tensor * dst) {
-    binary_overlap_flags flags = {};
-    flags.inplace     = ggml_webgpu_tensor_equal(src0, dst);
-    flags.overlap     = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
-
-    return flags;
-}
-
 static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src,
-        .dst         = dst,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
 
@@ -712,14 +758,8 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
     };
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
@@ -732,13 +772,12 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
                                          ggml_tensor *    dst) {
     const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src0,
-        .src1        = src1,
-        .dst         = dst,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-        .inplace     = inplace,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src0;
+    shader_lib_ctx.src1        = src1;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.inplace     = inplace;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
 
@@ -772,29 +811,21 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
     std::vector<wgpu::BindGroupEntry> entries;
     uint32_t binding_index = 0;
     if (!inplace) {
-        entries.push_back({ .binding = 0,
-                            .buffer  = ggml_webgpu_tensor_buf(src0),
-                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-                            .size    = ggml_webgpu_tensor_binding_size(ctx, src0) });
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
         binding_index++;
     }
-    entries.push_back({ .binding = binding_index,
-                        .buffer  = ggml_webgpu_tensor_buf(src1),
-                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-                        .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
-    entries.push_back({ .binding = binding_index + 1,
-                        .buffer  = ggml_webgpu_tensor_buf(dst),
-                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1));
+    entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst));
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
 }
 
 static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
 
@@ -832,14 +863,8 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
     };
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
@@ -850,13 +875,12 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
                                                ggml_tensor *    src0,
                                                ggml_tensor *    src1,
                                                ggml_tensor *    dst) {
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0               = src0,
-        .src1               = src1,
-        .dst                = dst,
-        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0               = src0;
+    shader_lib_ctx.src1               = src1;
+    shader_lib_ctx.dst                = dst;
+    shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
 
@@ -888,18 +912,9 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src0),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(src1),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
-        { .binding = 2,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
     };
 
     const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
@@ -911,12 +926,11 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
                                               ggml_tensor *    src0,
                                               ggml_tensor *    src1,
                                               ggml_tensor *    dst) {
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src0,
-        .src1        = src1,
-        .dst         = dst,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src0;
+    shader_lib_ctx.src1        = src1;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
 
     webgpu_pipeline pipeline  = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
     auto *          decisions = static_cast(pipeline.context.get());
@@ -944,18 +958,9 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src0),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(src1),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
-        { .binding = 2,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
     };
 
     const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
@@ -971,15 +976,14 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
                                                      ggml_tensor *    src4,
                                                      ggml_tensor *    src5,
                                                      ggml_tensor *    dst) {
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src0,
-        .src1        = src1,
-        .src2        = src2,
-        .src3        = src3,
-        .src4        = src4,
-        .dst         = dst,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src0;
+    shader_lib_ctx.src1        = src1;
+    shader_lib_ctx.src2        = src2;
+    shader_lib_ctx.src3        = src3;
+    shader_lib_ctx.src4        = src4;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
 
@@ -1015,34 +1019,10 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src0),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(src1),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
-        { .binding = 2,
-          .buffer  = ggml_webgpu_tensor_buf(src2),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src2) },
-        { .binding = 3,
-          .buffer  = ggml_webgpu_tensor_buf(src3),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src3),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src3) },
-        { .binding = 4,
-          .buffer  = ggml_webgpu_tensor_buf(src4),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src4),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src4) },
-        { .binding = 5,
-          .buffer  = ggml_webgpu_tensor_buf(src5),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src5),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src5) },
-        { .binding = 6,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst),
     };
 
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs);
@@ -1058,12 +1038,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct
         return std::nullopt;
     }
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src,
-        .src1        = idx,
-        .dst         = dst,
-        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src;
+    shader_lib_ctx.src1        = idx;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
 
@@ -1086,25 +1065,14 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(idx),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
-        { .binding = 2,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
     };
 
     if (decisions->i64_idx) {
-        entries.push_back({ .binding = 3,
-                            .buffer  = ctx->set_rows_dev_error_buf,
-                            .offset  = 0,
-                            .size    = ctx->set_rows_dev_error_buf.GetSize() });
+        entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0,
+                                                            ctx->set_rows_dev_error_buf.GetSize()));
    }
 
     uint32_t threads;
@@ -1131,12 +1099,11 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
                                               ggml_tensor *    dst) {
     const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src,
-        .src1        = nullptr,
-        .dst         = dst,
-        .max_wg_size = WEBGPU_MAX_WG_SIZE,
-    };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = src;
+    shader_lib_ctx.src1        = nullptr;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE;
 
     webgpu_pipeline pipeline  = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
     auto *          decisions = static_cast(pipeline.context.get());
@@ -1160,20 +1127,9 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
                                    (uint32_t) (idx->ne[1]),
                                    (uint32_t) (idx->ne[2]) };
 
-    std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-          .buffer  = ggml_webgpu_tensor_buf(src),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
-        { .binding = 1,
-          .buffer  = ggml_webgpu_tensor_buf(idx),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
-        { .binding = 2,
-          .buffer  = ggml_webgpu_tensor_buf(dst),
-          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); @@ -1225,17 +1181,16 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1270,18 +1225,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1333,13 +1279,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; @@ -1380,22 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id_gather.wgsl std::vector gather_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = 
ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1427,30 +1364,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id.wgsl std::vector main_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; // Calculate workgroup dimensions @@ -1486,11 +1411,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, 
sizeof(float)); + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1522,86 +1445,53 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; uint32_t binding_index = 3; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = 
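
Two accessors carry this hunk: ggml_get_op_params_f32, which already exists in ggml's internal headers, and ggml_webgpu_u32_from_f32, which replaces the strict-aliasing puns of the form *(uint32_t *) &scale. The latter's definition is not in the diff; it is presumably the usual memcpy bit-cast, roughly:

    #include <cstdint>
    #include <cstring>

    // Bit-copies an f32 into a u32 so it can travel in the u32 params vector.
    // Unlike *(uint32_t *) &f, memcpy is well-defined under strict aliasing.
    static inline uint32_t ggml_webgpu_u32_from_f32(float f) {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        return u;
    }
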
ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); + webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : + ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + if (!use_vec) { + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = use_vec, - }; - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - - auto * decisions = static_cast(pipeline.context.get()); - - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1609,197 +1499,162 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec) { - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - const bool use_vec_reduce = nwg > 1u; - GGML_ASSERT(nrows <= UINT32_MAX); + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * 
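
The doubling loop above decides how many workgroups (nwg) cooperate on each query row in the vec path. Restated as a hypothetical standalone helper, with a worked example:

    #include <algorithm>
    #include <cstdint>

    // Double nwg while one KV tile per workgroup would still cover less than
    // half of the KV axis, then clamp to the subgroup-derived cap.
    static uint32_t pick_nwg(uint64_t seq_len_kv, uint32_t kv_tile, uint32_t cap) {
        uint32_t nwg = 1u;
        const uint64_t kv_span = std::max(1u, kv_tile);
        while (2u * nwg * kv_span < seq_len_kv && nwg < cap) {
            nwg <<= 1;
        }
        return std::min(nwg, cap);
    }

    // e.g. pick_nwg(4096, 32, 32) == 32: each of 32 workgroups walks ~128 KV
    // positions in 32-wide tiles, and the reduce pass merges the partials.
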
Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; - const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); - if (use_vec_reduce) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - tmp_stats_base = tmp_data_elems; - tmp_size_bytes = - ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = scratch_offset; - tmp_bind_size = tmp_size_bytes; - scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); - } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); - } - - webgpu_pipeline blk_pipeline; - std::vector blk_params; - std::vector blk_entries; - if (use_blk) { - GGML_ASSERT(has_mask); - - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? 
(uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = - { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); - - blk_params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) K->ne[1], // seq_len_kv - stride_mask3, // stride_mask3 - blk_nblk0, // nblk0 - blk_nblk1, // nblk1 - }; - blk_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, - }; - scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); - } - - std::vector split_params = params; - if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 - } - split_params.push_back(0u); // tmp_data_base - split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg - - std::vector split_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, - }; - uint32_t split_binding_index = 3; - if (has_mask) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = blk_entries[1].offset, - .size = blk_size_bytes }); - } - split_entries.push_back( - { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; - std::vector reduce_entries; - if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, - std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = - { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }, - .max_wg_size = reduce_wg_size, - 
}; - reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); - - reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; - - reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - } - - const uint64_t split_wg_total = (uint64_t) wg_x * nwg; - GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector dispatches; - - if (use_blk) { - dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } - dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } - }); - if (use_vec_reduce) { - dispatches.push_back({ - reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } - - return ggml_backend_webgpu_build_multi(ctx, dispatches); + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? 
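
For nwg > 1 every row stages nwg partial outputs plus two per-workgroup softmax statistics (apparently the running max and exp-sum) in the dst buffer, past the tensor's own bytes. The sizing above, restated as an illustrative function:

    #include <cstdint>

    // Illustrative restatement of the scratch sizing; bind_mult stands in for
    // WEBGPU_STORAGE_BUF_BINDING_MULT (a power of two).
    static uint64_t flash_attn_vec_scratch_bytes(uint64_t nrows, uint64_t head_dim_v,
                                                 uint32_t nwg, uint64_t bind_mult) {
        const uint64_t tmp_data_elems  = nrows * head_dim_v * nwg; // partial outputs
        const uint64_t tmp_stats_elems = nrows * 2u * nwg;         // two stats per wg
        const uint64_t bytes = (tmp_data_elems + tmp_stats_elems) * sizeof(float);
        return (bytes + bind_mult - 1) / bind_mult * bind_mult;    // ROUNDUP_POW2
    }
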
(uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V)), + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + 
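
The blk pre-pass writes one u32 flag per (KV tile, query row, batch); judging from the WGSL kernel at the end of this diff it screens the mask so that fully masked tiles can be skipped, and with Q_TILE removed nblk1 is simply seq_len_q. One possible linearization of that grid (the authoritative layout lives in the shaders):

    #include <cstdint>

    // Hypothetical index into the blk flag array; the ordering is an assumption.
    static uint64_t blk_flag_index(uint32_t kv_blk, uint32_t q_row, uint32_t batch,
                                   uint32_t nblk0, uint32_t nblk1) {
        return ((uint64_t) batch * nblk1 + q_row) * nblk0 + kv_blk;
    }
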
(uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } + + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector dispatches; + + if (has_mask) { + dispatches.push_back({ + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); + } + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #endif // __EMSCRIPTEN__ @@ -1807,13 +1662,12 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); @@ -1844,10 +1698,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1856,25 +1710,19 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - 
.buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -1887,15 +1735,14 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = flags.overlap, - .src_overlap = flags.src_overlap, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = flags.inplace; + shader_lib_ctx.overlap = flags.overlap; + shader_lib_ctx.src_overlap = flags.src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1944,38 +1791,18 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2012,26 +1839,16 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = 
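
When src0 and src1 alias the same buffer, the code above binds one merged range instead of binding the buffer twice, which WebGPU's aliasing validation may otherwise reject. The merged range as a standalone computation:

    #include <algorithm>
    #include <cstddef>

    struct merged_range { size_t offset; size_t size; };

    // One binding covering two possibly-overlapping ranges of the same buffer.
    static merged_range merge_binding_ranges(size_t off0, size_t size0,
                                             size_t off1, size_t size1) {
        const size_t begin = std::min(off0, off1);
        const size_t end   = std::max(off0 + size0, off1 + size1);
        return { begin, end - begin };
    }

The shader then indexes both operands relative to this common base.
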
ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2059,21 +1876,14 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2097,28 +1907,19 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = 
ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); @@ -2129,14 +1930,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); @@ -2187,41 +1987,27 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2232,12 +2018,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); @@ 
-2265,29 +2050,20 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2296,13 +2072,12 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2321,23 +2096,15 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2349,25 +2116,23 @@ static webgpu_encoded_op 
ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2389,39 +2154,29 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[2], has_mask ? (uint32_t) src1->ne[2] : 0, has_mask ? 
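
m0 and m1 above are the ALiBi slope bases in ggml's standard formulation; the shader derives the per-head slope from them. For reference (nothing here is new in this patch):

    #include <cmath>
    #include <cstdint>

    // Standard ggml ALiBi slope for head h:
    //   h <  n_head_log2: m0^(h+1)
    //   h >= n_head_log2: m1^(2*(h - n_head_log2) + 1)
    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        return h < n_head_log2 ? std::pow(m0, (float) (h + 1))
                               : std::pow(m1, (float) (2 * (h - n_head_log2) + 1));
    }

    // e.g. max_bias = 8, n_head = 32: n_head_log2 = 32, m0 = 2^-0.25, m1 = 2^-0.125.
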
(uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; if (has_mask) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); @@ -2432,20 +2187,13 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); @@ -2455,13 +2203,12 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -2527,11 +2274,8 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; dispatches.push_back({ @@ -2580,12 +2324,9 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; @@ -2607,23 +2348,14 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); @@ -2641,20 +2373,13 @@ static webgpu_encoded_op 
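
The argsort merge passes above ping-pong between two scratch regions of the dst buffer (align_in/size_in versus align_out/size_out), and appear to halve the number of sorted runs per pass; under that assumption the pass count is:

    #include <cstdint>

    // Illustrative only; assumes each merge pass pairs up adjacent sorted runs.
    static uint32_t merge_pass_count(uint32_t runs) {
        uint32_t passes = 0;
        while (runs > 1) {
            runs = (runs + 1) / 2;  // ceil(runs / 2) runs survive each pass
            ++passes;
        }
        return passes;              // == ceil(log2(initial run count))
    }
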
ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); @@ -3133,40 +2858,24 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * mask = tensor->src[3]; const ggml_tensor * sinks = tensor->src[4]; if (Q && K && V) { - GGML_UNUSED(sinks); - const bool kv_direct = (K->type == GGML_TYPE_F16) && - (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - kv_vec_type_supported && (V->type == K->type); - if (use_vec) { - const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - const size_t limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!kv_direct) { - bytes_per_kv += std::max(Q->ne[0], V->ne[0]); - } - if (mask != nullptr) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; - if (kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= sg_mat_n; - } - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + 
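
ggml_webgpu_flash_attn_use_vec and ggml_webgpu_flash_attn_vec_get_kv_tile are new helpers whose definitions are not in this diff. Judging from the inline criteria they replace (the removed lines above, plus the encode-path check that also required 4-element alignment of f16 K/V offsets), the predicate plausibly looks like:

    #include "ggml.h"

    // Hypothetical reconstruction from the removed inline criteria; the real
    // helper presumably also folds in capability checks from the global context.
    static bool flash_attn_use_vec_sketch(const ggml_tensor * Q, const ggml_tensor * K,
                                          const ggml_tensor * V) {
        const bool kv_type_ok = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 ||
                                K->type == GGML_TYPE_Q8_0;
        return (Q->ne[1] < 20)       // few query rows: matrix tiles would be underfilled
            && (Q->ne[0] % 32 == 0)  // QK head dim a multiple of 32
            && (V->ne[0] % 4 == 0)   // V head dim vec4-friendly
            && kv_type_ok
            && (V->type == K->type);
    }
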
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
+
+                if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
+                    const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
                     const uint32_t vec_nwg_cap = std::max(
                         1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
@@ -3271,8 +2980,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct
 }
 
 static ggml_guid_t ggml_backend_webgpu_guid(void) {
-    static const char * guid_str = "__ggml_webgpu :)";
-    return reinterpret_cast<ggml_guid_t>((void *) guid_str);
+    static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51,
+                              0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a };
+    return &guid;
 }
 
 static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
@@ -3931,20 +3641,23 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
 
 ggml_backend_reg_t ggml_backend_webgpu_reg() {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
-    static ggml_backend_webgpu_reg_context ctx;
-    static ggml_backend_reg reg = {
+    // Intentionally leak the global registry context to avoid crashing inside
+    // Dawn/Vulkan static teardown during process exit.
+    static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context();
+
+    static ggml_backend_reg reg = {
         /* .api_version = */ GGML_BACKEND_API_VERSION,
         /* .iface       = */ ggml_backend_webgpu_reg_i,
-        /* .context     = */ &ctx,
+        /* .context     = */ ctx,
     };
 
-    ctx.name         = GGML_WEBGPU_NAME;
-    ctx.device_count = 0;
+    ctx->name         = GGML_WEBGPU_NAME;
+    ctx->device_count = 0;
 
     // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
     // registry. Recreating it on repeated registry lookups can invalidate
     // adapter/device references that are still held by the backend/device layer.
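
Two lifetime/identity fixes land in this region: the registry context is heap-allocated and deliberately never freed (a function-local static would be destroyed during Dawn/Vulkan static teardown, where it crashes), and the backend GUID becomes a real 16-byte ggml_guid instead of a pointer-punned string literal. GUID comparison in ggml is effectively bytewise:

    #include <cstdint>
    #include <cstring>

    // ggml_guid is a 16-byte array; ggml_guid_matches compares it bytewise.
    static bool guid_equal(const uint8_t (&a)[16], const uint8_t (&b)[16]) {
        return std::memcmp(a, b, sizeof(a)) == 0;
    }
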
-    if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
+    if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) {
         return &reg;
     }
@@ -3961,17 +3674,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     instance_descriptor.nextInChain = &instanceTogglesDesc;
 #endif
 
-    wgpu::Instance inst             = wgpu::CreateInstance(&instance_descriptor);
-    ctx.webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());
-    ctx.webgpu_global_ctx->instance = std::move(inst);
+    wgpu::Instance inst              = wgpu::CreateInstance(&instance_descriptor);
+    ctx->webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());
+    ctx->webgpu_global_ctx->instance = std::move(inst);
 
     // Probe for adapter support
     wgpu::Adapter adapter;
-    if (ctx.webgpu_global_ctx->instance != nullptr) {
+    if (ctx->webgpu_global_ctx->instance != nullptr) {
         wgpu::RequestAdapterOptions options = {};
-        ctx.webgpu_global_ctx->instance.WaitAny(
-            ctx.webgpu_global_ctx->instance.RequestAdapter(
+        // probe for adapter support
+        ctx->webgpu_global_ctx->instance.WaitAny(
+            ctx->webgpu_global_ctx->instance.RequestAdapter(
                 &options, wgpu::CallbackMode::AllowSpontaneous,
                 [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
                     if (status != wgpu::RequestAdapterStatus::Success) {
@@ -3984,7 +3698,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     }
 
     if (adapter != nullptr) {
-        ctx.device_count = 1;
+        ctx->device_count = 1;
     }
 
     return &reg;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
index 82d072be7..61107c6a9 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
@@ -1,7 +1,6 @@
 diagnostic(off, subgroup_uniformity);
 enable f16;
 
-#define Q_TILE 1
 #define KV_TILE 32
 #define WG_SIZE 32
 
@@ -11,7 +10,7 @@ struct Params {
     seq_len_kv: u32,
     stride_mask3: u32,
     // Number of KV blocks and Q blocks per batch.
-    // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
+    // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
     nblk0: u32,
     nblk1: u32,
 };
@@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         return;
     }
 
-    let q_start = q_blk * Q_TILE;
+    let q_start = q_blk;
     let k_start = kv_blk * KV_TILE;
 
     let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
@@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     var local_max = -MASK_MAX;
     var local_any = 0u;
 
-    for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
-        let q_row = q_start + q_rel;
-        if (q_row >= params.seq_len_q) {
-            continue;
-        }
+    let q_row = q_start;
+    if (q_row < params.seq_len_q) {
         let row_base = mask_batch_base + q_row * params.seq_len_kv;
         for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
             let k_col = k_start + k_rel;