From 98bb57916a727b0d6b059a3b50581d4ab9addb76 Mon Sep 17 00:00:00 2001
From: Reese Levine
Date: Tue, 28 Apr 2026 07:27:17 -0700
Subject: [PATCH] ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (#22456)

* Refactor buffer aliasing to be part of shader lib decisions

* cleanup

* formatting
---
 .../ggml-webgpu/ggml-webgpu-shader-lib.hpp    | 159 ++++++---
 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 337 +++++++++---------
 ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl |   6 +-
 .../wgsl-shaders/rms_norm_mul.wgsl            |   6 +-
 .../ggml-webgpu/wgsl-shaders/ssm_scan.wgsl    |  25 ++
 tests/test-backend-ops.cpp                    |  31 +-
 6 files changed, 326 insertions(+), 238 deletions(-)

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index fb2c9527f..34cbf3694 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -26,21 +26,21 @@
 // Matrix multiplication parameters
 // Register tiling parameters
-#define WEBGPU_MUL_MAT_TILE_M    4
-#define WEBGPU_MUL_MAT_TILE_N    4
-#define WEBGPU_MUL_MAT_WG_SIZE_M 8
-#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_M           4
+#define WEBGPU_MUL_MAT_TILE_N           4
+#define WEBGPU_MUL_MAT_WG_SIZE_M        8
+#define WEBGPU_MUL_MAT_WG_SIZE_N        8
 #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
 #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
 
 // Subgroup matrix parameters
 // The number of subgroups in the M dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_M        2
+#define WEBGPU_MUL_MAT_SUBGROUP_M            2
 // The number of subgroups in the N dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_N        4
+#define WEBGPU_MUL_MAT_SUBGROUP_N            4
 // The number of subgroup matrices each subgroup accumulates over
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M     4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N     2
 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
 
@@ -59,19 +59,32 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
     seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 }
 
+// Calculates base address of a tensor ignoring the fake base pointer
+inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (uintptr_t) base_tensor->data + tensor->view_offs;
+}
+
+inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
+}
+
+inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
+           ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
+}
+
 struct ggml_webgpu_shader_lib_context {
     ggml_tensor * src0;
     ggml_tensor * src1;
     ggml_tensor * src2;
     ggml_tensor * src3;
     ggml_tensor * src4;
+    ggml_tensor * src5;
     ggml_tensor * dst;
 
     uint32_t max_wg_size;
     size_t   wg_mem_limit_bytes = 0;
-    bool inplace     = false;
-    bool overlap     = false;
-    bool src_overlap = false;
     bool     supports_subgroups       = false;
     bool     supports_subgroup_matrix = false;
     uint32_t sg_mat_m                 = 0;
@@ -88,6 +101,14 @@ struct webgpu_pipeline {
 
 struct ggml_webgpu_generic_shader_decisions {
     uint32_t wg_size = 0;
+    bool     inplace = false;
+};
+
+struct ggml_webgpu_binary_shader_decisions {
+    uint32_t wg_size     = 0;
+    bool     inplace     = false;
+    bool     overlap     = false;
+    bool     src_overlap = false;
 };
 
 struct ggml_webgpu_processed_shader {
@@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
 };
 
 struct ggml_webgpu_ssm_scan_pipeline_key {
-    int type;
-    int d_state;
+    int  type;
+    int  d_state;
+    bool xbc_overlap;
 
     bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
-        return type == other.type && d_state == other.d_state;
+        return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
     }
 };
 
@@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
         size_t seed = 0;
         ggml_webgpu_hash_combine(seed, key.type);
         ggml_webgpu_hash_combine(seed, key.d_state);
+        ggml_webgpu_hash_combine(seed, key.xbc_overlap);
         return seed;
     }
 };
@@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
 struct ggml_webgpu_ssm_scan_shader_decisions {
     uint32_t wg_size;
     uint32_t tokens_per_tile;
+    bool     xbc_overlap = false;
 };
 
 /** Argsort **/
@@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
     }
 };
 
+struct ggml_webgpu_rms_norm_mul_shader_decisions {
+    uint32_t wg_size     = 0;
+    bool     inplace     = false;
+    bool     overlap     = false;
+    bool     src_overlap = false;
+};
+
 /** Pad **/
 struct ggml_webgpu_pad_pipeline_key {
     bool circular;
@@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
 };
 
 struct ggml_webgpu_flash_attn_decisions {
-    uint32_t path      = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
-    uint32_t q_tile    = 0;
-    uint32_t kv_tile   = 0;
-    uint32_t wg_size   = 0;
-    bool     kv_direct = false;
+    uint32_t path       = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
+    uint32_t q_tile     = 0;
+    uint32_t kv_tile    = 0;
+    uint32_t wg_size    = 0;
+    bool     kv_direct  = false;
+    bool     kv_overlap = false;
 };
 
 inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
@@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
     key.head_dim_qk = (uint32_t) context.src0->ne[0];
     key.head_dim_v  = (uint32_t) context.src2->ne[0];
     key.kv_direct   = kv_direct;
-    key.kv_overlap  = context.src_overlap;
+    key.kv_overlap  = ggml_webgpu_tensor_overlap(context.src1, context.src2);
     key.has_mask    = has_mask;
     key.has_sinks   = has_sinks;
     key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
@@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_row_norm_pipeline_key key = {};
         key.op      = context.dst->op;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = row_norm_pipelines.find(key);
         if (it != row_norm_pipelines.end()) {
@@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib {
         const uint32_t row_norm_wg_size = 128u;
         uint32_t       wg_size          = std::min(context.max_wg_size, row_norm_wg_size);
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-        auto processed          = preprocessor.preprocess(wgsl_row_norm, defines);
-        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = wg_size;
+        decisions->inplace = key.inplace;
+        row_norm_pipelines[key]         = ggml_webgpu_create_pipeline(device, processed, variant);
+        row_norm_pipelines[key].context = decisions;
         return row_norm_pipelines[key];
     }
@@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_set_pipeline_key key = {};
         key.type    = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = set_pipelines.find(key);
         if (it != set_pipelines.end()) {
@@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_set, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context   = decisions;
         set_pipelines[key] = pipeline;
@@ -1355,7 +1392,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_scale_pipeline_key key = {};
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = scale_pipelines.find(key);
         if (it != scale_pipelines.end()) {
@@ -1375,6 +1412,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_scale, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context     = decisions;
         scale_pipelines[key] = pipeline;
@@ -1468,6 +1506,8 @@ class ggml_webgpu_shader_lib {
         ggml_webgpu_ssm_scan_pipeline_key key = {};
         key.type    = context.dst->type;
         key.d_state = (int) context.src0->ne[0];
+        key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
+                          ggml_webgpu_tensor_overlap(context.src1, context.src5);
 
         auto it = ssm_scan_pipelines.find(key);
         if (it != ssm_scan_pipelines.end()) {
@@ -1499,12 +1539,17 @@ class ggml_webgpu_shader_lib {
             variant += "_wg_reduce";
         }
 
+        if (key.xbc_overlap) {
+            defines.push_back("XBC_OVERLAP");
+        }
+
         variant += "_d" + std::to_string(key.d_state);
 
         auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
         auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
         decisions->wg_size         = wg_size;
         decisions->tokens_per_tile = tokens_per_tile;
+        decisions->xbc_overlap     = key.xbc_overlap;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context        = decisions;
         ssm_scan_pipelines[key] = pipeline;
@@ -1764,11 +1809,9 @@ class ggml_webgpu_shader_lib {
         uint32_t tile_k;
         if (key.use_subgroup_matrix) {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
         } else {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
         }
 
         // Tiles
@@ -2001,9 +2044,8 @@ class ggml_webgpu_shader_lib {
         defines.push_back("SCALAR");
 
         // mul_mat_id is register-tile only.
-        const uint32_t tile_k = ggml_is_quantized(context.src0->type)
-                                    ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
-                                    : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
+        const uint32_t tile_k =
+            ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
 
         // Tiles
         defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
@@ -2039,8 +2081,8 @@ class ggml_webgpu_shader_lib {
         key.type     = context.dst->type;
         key.op       = op;
         key.is_unary = is_unary;
-        key.inplace  = context.inplace;
-        key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
+        key.inplace  = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
+        key.ttype    = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
 
         auto it = unary_pipelines.find(key);
         if (it != unary_pipelines.end()) {
@@ -2098,6 +2140,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_unary, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context     = decisions;
         unary_pipelines[key] = pipeline;
@@ -2106,9 +2149,9 @@ class ggml_webgpu_shader_lib {
 
     webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_rms_norm_mul_pipeline_key key = {};
-        key.inplace     = context.inplace;
-        key.overlap     = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace     = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap     = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
 
         auto it = rms_norm_mul_pipelines.find(key);
         if (it != rms_norm_mul_pipelines.end()) {
@@ -2132,12 +2175,15 @@ class ggml_webgpu_shader_lib {
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
-        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context            = decisions;
-        rms_norm_mul_pipelines[key] = pipeline;
+        auto processed          = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
+        auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
+        pipeline_decisions->wg_size     = context.max_wg_size;
+        pipeline_decisions->inplace     = key.inplace;
+        pipeline_decisions->overlap     = key.overlap;
+        pipeline_decisions->src_overlap = key.src_overlap;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context            = pipeline_decisions;
+        rms_norm_mul_pipelines[key] = pipeline;
         return rms_norm_mul_pipelines[key];
     }
@@ -2145,9 +2191,9 @@ class ggml_webgpu_shader_lib {
         ggml_webgpu_binary_pipeline_key key = {};
         key.type = context.dst->type;
         key.op   = context.dst->op;
-        key.inplace     = context.inplace;
-        key.overlap     = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace     = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap     = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
 
         auto it = binary_pipelines.find(key);
         if (it != binary_pipelines.end()) {
@@ -2186,11 +2232,15 @@ class ggml_webgpu_shader_lib {
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_binary, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
+        auto processed          = preprocessor.preprocess(wgsl_binary, defines);
+        auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
+        pipeline_decisions->wg_size     = context.max_wg_size;
+        pipeline_decisions->inplace     = key.inplace;
+        pipeline_decisions->overlap     = key.overlap;
+        pipeline_decisions->src_overlap = key.src_overlap;
+
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context      = decisions;
+        pipeline.context      = pipeline_decisions;
         binary_pipelines[key] = pipeline;
         return binary_pipelines[key];
     }
@@ -2351,7 +2401,8 @@ class ggml_webgpu_shader_lib {
             defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
         }
 
-        auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
+        auto pipeline_decisions        = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
+        pipeline_decisions->kv_overlap = key.kv_overlap;
         defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
         defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
         defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
@@ -2543,7 +2594,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_rope_pipeline_key key = {};
         key.type    = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
         key.has_ff  = (context.src2 != nullptr);
 
         auto it = rope_pipelines.find(key);
@@ -2582,6 +2633,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_rope, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context    = decisions;
         rope_pipelines[key] = pipeline;
@@ -2593,7 +2645,7 @@ class ggml_webgpu_shader_lib {
         key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
         key.has_mask  = (context.src1 != nullptr);
         key.has_sink  = (context.src2 != nullptr);
-        key.inplace   = context.inplace;
+        key.inplace   = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = soft_max_pipelines.find(key);
         if (it != soft_max_pipelines.end()) {
@@ -2634,6 +2686,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context        = decisions;
         soft_max_pipelines[key] = pipeline;
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 6d861c0c7..762d9f8d1 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
 // their locations.
 static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 
-// Always returns the base offset of a tensor, regardless of views.
-static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
-    if (tensor->view_src) {
-        return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
-    }
-    return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
+static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs;
 }
 
 /* Struct definitions */
@@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
     buffer = device.CreateBuffer(&buffer_desc);
 }
 
-static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
-    return webgpu_tensor_offset(tensor) + tensor->view_offs;
-}
-
 static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
     ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
     return ctx->buffer;
@@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
     return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
 }
 
-// Used to determine if two tensors are the same for in-place operations
-static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
-}
-
-// Used to determine if two tensors share the same buffer and their byte ranges overlap,
-static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
-           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
-}
-
-struct binary_overlap_flags {
-    bool inplace;      // src0 == dst
-    bool overlap;      // src1 == dst
-    bool src_overlap;
+struct ggml_webgpu_merged_binding_range {
+    size_t offset;
+    size_t size;
 };
 
-static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
-                                                              ggml_tensor * src1,
-                                                              ggml_tensor * dst) {
-    binary_overlap_flags flags = {};
-    flags.inplace     = ggml_webgpu_tensor_equal(src0, dst);
-    flags.overlap     = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
+static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range(
+    webgpu_context &                     ctx,
+    std::initializer_list<ggml_tensor *> tensors) {
+    size_t merged_offset = SIZE_MAX;
+    size_t merged_end    = 0;
 
-    return flags;
+    for (ggml_tensor * tensor : tensors) {
+        const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor);
+        const size_t bind_end    = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor);
+
+        merged_offset = std::min(merged_offset, bind_offset);
+        merged_end    = std::max(merged_end, bind_end);
+    }
+
+    return { merged_offset, merged_end - merged_offset };
+}
+
+static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor *                      tensor,
+                                                         const ggml_webgpu_merged_binding_range & merged_range) {
+    return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type));
 }
 
 static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding,
@@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
                                          ggml_tensor *    src0,
                                          ggml_tensor *    src1,
                                          ggml_tensor *    dst) {
-    const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0        = src0;
     shader_lib_ctx.src1        = src1;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = inplace;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    auto *     decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace   = decisions->inplace;
 
     const uint32_t ne            = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
     const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
@@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
                                               ggml_tensor *    dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0               = src0;
+    shader_lib_ctx.src1               = src1;
+    shader_lib_ctx.src4               = src4;
+    shader_lib_ctx.src5               = src5;
     shader_lib_ctx.dst                = dst;
     shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
     shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
 
-    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline    = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
+    auto *          decisions   = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get());
+    const bool      xbc_overlap = decisions->xbc_overlap;
+
+    uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type));
+    uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type));
+    size_t   xbc_bind_offset = 0;
+    size_t   xbc_bind_size   = 0;
+    if (xbc_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 });
+        xbc_bind_offset = merged_range.offset;
+        xbc_bind_size   = merged_range.size;
+        offset_x        = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
+        offset_B        = ggml_webgpu_tensor_merged_element_offset(src4, merged_range);
+        offset_C        = ggml_webgpu_tensor_merged_element_offset(src5, merged_range);
+    }
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_x,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
+        offset_B,
+        offset_C,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
     };
+    if (xbc_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst));
+    } else {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst));
+    }
 
     const uint32_t total_wg       = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
     const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@@ -1653,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0                     = Q;
+    shader_lib_ctx.src1                     = K;
+    shader_lib_ctx.src2                     = V;
+    shader_lib_ctx.src3                     = mask;
+    shader_lib_ctx.src4                     = sinks;
+    shader_lib_ctx.dst                      = dst;
+    shader_lib_ctx.supports_subgroups       = ctx->global_ctx->capabilities.supports_subgroups;
+    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
+    shader_lib_ctx.max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.wg_mem_limit_bytes       = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+    shader_lib_ctx.sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m;
+    shader_lib_ctx.sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n;
+    shader_lib_ctx.sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k;
+    shader_lib_ctx.max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size;
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
+        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
+
     const int has_mask  = (mask != nullptr);
     const int has_sinks = (sinks != nullptr);
 
-    const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
+    const bool kv_overlap = decisions->kv_overlap;
 
     uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
     uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
     size_t   kv_bind_offset = 0;
     size_t   kv_bind_size   = 0;
     if (kv_overlap) {
-        const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
-        const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
-        const size_t k_bind_end    = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
-        const size_t v_bind_end    = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
-        kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
-        kv_bind_size   = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
-        offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
-        offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
+        const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
+        kv_bind_offset = merged_range.offset;
+        kv_bind_size   = merged_range.size;
+        offset_k       = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
+        offset_v       = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
     }
 
     std::vector<uint32_t> params = {
@@ -1720,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     }
     entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.src0                     = Q;
-    shader_lib_ctx.src1                     = K;
-    shader_lib_ctx.src2                     = V;
-    shader_lib_ctx.src3                     = mask;
-    shader_lib_ctx.src4                     = sinks;
-    shader_lib_ctx.dst                      = dst;
-    shader_lib_ctx.src_overlap              = kv_overlap;
-    shader_lib_ctx.supports_subgroups       = ctx->global_ctx->capabilities.supports_subgroups;
-    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
-    shader_lib_ctx.max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.wg_mem_limit_bytes       = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-    shader_lib_ctx.sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m;
-    shader_lib_ctx.sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n;
-    shader_lib_ctx.sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k;
-    shader_lib_ctx.max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size;
-    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
-        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
-    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
-
     if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
         uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
         uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
@@ -1921,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
 
 static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
-    bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0        = src;
     shader_lib_ctx.src1        = nullptr;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = inplace;
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    auto *     decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace   = decisions->inplace;
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1994,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
                                                ggml_tensor *    src0,
                                                ggml_tensor *    src1,
                                                ggml_tensor *    dst) {
-    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0        = src0;
     shader_lib_ctx.src1        = src1;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = flags.inplace;
-    shader_lib_ctx.overlap     = flags.overlap;
-    shader_lib_ctx.src_overlap = flags.src_overlap;
 
-    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
-
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
     size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
     size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
 
-    uint32_t offset_merged_src0 = 0;
-    uint32_t offset_merged_src1 = 0;
-    if (flags.src_overlap) {
-        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
-        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+    uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
+    uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    size_t   merged_offset = 0;
+    size_t   merged_size   = 0;
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
+        merged_offset = merged_range.offset;
+        merged_size   = merged_range.size;
+        offset_src0   = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
+        offset_src1   = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
     }
 
     std::vector<uint32_t> params = {
         ne,
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_src0,
+        offset_src1,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        offset_merged_src0,
-        offset_merged_src1,
         (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
@@ -2048,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
 
     std::vector<wgpu::BindGroupEntry> entries;
 
-    if (flags.src_overlap) {
-        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        size_t merged_end    = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
-                                        src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset,
-                                                            merged_end - merged_offset));
+    if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     } else {
         entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0),
@@ -2062,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
         entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1),
                                                             src1_webgpu_tensor_align_offset,
                                                             ggml_webgpu_tensor_binding_size(ctx, src1)));
-        if (!flags.inplace && !flags.overlap) {
+        if (!decisions->inplace && !decisions->overlap) {
             entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
         }
     }
@@ -2168,29 +2177,15 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
         GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
     }
 
-    bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
-                   (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
-    bool inplace     = ggml_webgpu_tensor_equal(rn_src, dst);
-    bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
-
-    uint32_t offset_merged_rn_src  = 0;
-    uint32_t offset_merged_mul_src = 0;
-    size_t   rn_src_webgpu_tensor_align_offset  = ggml_webgpu_tensor_align_offset(ctx, rn_src);
-    size_t   mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
-
-    if (src_overlap) {
-        size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        offset_merged_rn_src =
-            (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
-        offset_merged_mul_src =
-            (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
-    }
+    uint32_t offset_rn_src  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type));
+    uint32_t offset_mul_src =
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type));
+    size_t merged_offset = 0;
+    size_t merged_size   = 0;
 
     std::vector<uint32_t> params = {
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
-        offset_merged_rn_src,
-        offset_merged_mul_src,
+        offset_rn_src,
+        offset_mul_src,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
         (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
@@ -2214,16 +2209,32 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
 
     std::vector<wgpu::BindGroupEntry> entries;
 
-    if (inplace || overlap) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0        = rn_src;
+    shader_lib_ctx.src1        = mul_src;
+    shader_lib_ctx.dst         = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get());
+
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src });
+        merged_offset  = merged_range.offset;
+        merged_size    = merged_range.size;
+        offset_rn_src  = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range);
+        offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range);
+        params[0]      = offset_rn_src;
+        params[1]      = offset_mul_src;
+    }
+
+    if (decisions->inplace || decisions->overlap) {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
-    } else if (src_overlap) {
-        size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        size_t merged_end =
-            std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
-                     mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
-                                                            merged_end - merged_offset));
+    } else if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     } else {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
@@ -2231,20 +2242,10 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
     }
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = inplace;
-    shader_lib_ctx.overlap     = overlap;
-    shader_lib_ctx.src_overlap = src_overlap;
-
-    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
-
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
 }
 
 static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);
-
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -2261,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor
         ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0))  // epsilon, treated as f32 in the shader
     };
 
-    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
-    if (!inplace) {
-        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
-    }
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0        = src;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = inplace;
 
-    webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
+    if (!decisions->inplace) {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
+    }
 
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src));
 }
@@ -2287,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx,
     shader_lib_ctx.src2        = src2;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = ggml_webgpu_tensor_equal(src0, dst);
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
-    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
-    const int has_freq_factor = (src2 != nullptr);
+    const bool inplace         = decisions->inplace;
+    const int  has_freq_factor = (src2 != nullptr);
 
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
@@ -2421,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
 }
 
 static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0        = src;
     shader_lib_ctx.src1        = nullptr;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = inplace;
 
     webgpu_pipeline pipeline  = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
     auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2454,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
 
     // bindgroups unchanged
     std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
-    if (!inplace) {
+    if (!decisions->inplace) {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     }
@@ -2473,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
     shader_lib_ctx.src2        = src2;
     shader_lib_ctx.dst         = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace     = ggml_webgpu_tensor_equal(src0, dst);
 
-    webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
-    const int inplace      = ggml_webgpu_tensor_equal(src0, dst);
-    const int has_mask     = (src1 != nullptr);
-    const int has_sink     = (src2 != nullptr);
-    float     max_bias     = ggml_get_op_params_f32(dst, 1);
-    float     n_head_log2  = float(1u << (uint32_t) floor(log2(src0->ne[2])));
-    float     m0           = powf(2.0f, -(max_bias) / n_head_log2);
-    float     m1           = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    const bool inplace     = decisions->inplace;
+    const int  has_mask    = (src1 != nullptr);
+    const int  has_sink    = (src2 != nullptr);
+    float      max_bias    = ggml_get_op_params_f32(dst, 1);
+    float      n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+    float      m0          = powf(2.0f, -(max_bias) / n_head_log2);
+    float      m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -3079,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
                                                  size_t size) {
     GGML_UNUSED(backend);
     auto * buf_ctx      = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
 
     // Write aligned portion
     buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -3161,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
                                                                  << ", " << offset << ", " << size << ")");
 
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
 
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
     uint32_t val32 = (uint32_t) value * 0x01010101;
@@ -3180,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                               << ", " << offset << ", " << size << ")");
 
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
 
     buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -3212,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                                 << ", " << offset << ", " << size << ")");
 
     wgpu::Device device       = buf_ctx->global_ctx->device;
-    size_t       total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t       total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
 
     size_t final_size = size;
     if (size % 4 != 0) {
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
index a748dc1b8..605de7aa7 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
@@ -7,8 +7,6 @@ struct Params {
     offset_src0: u32,
     offset_src1: u32,
     offset_dst: u32,
-    offset_merged_src0: u32,
-    offset_merged_src1: u32,
 
     stride_src0_0: u32,
     stride_src0_1: u32,
@@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
 @compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     if (gid.x < params.ne) {
-        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
-        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+        let src0_i = params.offset_src0 + src0_index(gid.x);
+        let src1_i = params.offset_src1 + src1_index(gid.x);
         update(params.offset_dst + gid.x, src0_i, src1_i);
     }
 }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl
index 74aaa2753..fd20a4e54 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl
@@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
 struct Params {
     offset_rn_src: u32,
     offset_mul_src: u32,
-    offset_merged_rn_src: u32,
-    offset_merged_mul_src: u32,
     offset_dst: u32,
 
     stride_rn_src1: u32,
@@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     i = i % (params.ne2 * params.ne1);
     let i2 = i / params.ne1;
     let i1 = i % params.ne1;
-    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
-    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
+    let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
+    let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
     let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
 
     let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
index 643247385..05761dec3 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl
@@ -45,6 +45,14 @@ struct Params {
 };
 
 @group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
+#ifdef XBC_OVERLAP
+@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>;
+@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
+@group(0) @binding(3) var<storage, read_write> A: array<f32>;
+@group(0) @binding(4) var<storage, read_write> ids: array<u32>;
+@group(0) @binding(5) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(6) var<uniform> params: Params;
+#else
 @group(0) @binding(1) var<storage, read_write> x: array<f32>;
 @group(0) @binding(2) var<storage, read_write> dt: array<f32>;
 @group(0) @binding(3) var<storage, read_write> A: array<f32>;
@@ -53,6 +61,7 @@ struct Params {
 @group(0) @binding(6) var<storage, read_write> ids: array<u32>;
 @group(0) @binding(7) var<storage, read_write> dst: array<f32>;
 @group(0) @binding(8) var<uniform> params: Params;
+#endif
 
 var<workgroup> shared_x_dt: array<f32, WG_SIZE>;
 var<workgroup> shared_dtsp: array<f32, WG_SIZE>;
@@ -98,7 +107,11 @@ fn main(
             let dt0 = dt[dt_idx];
             let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
             shared_dtsp[tid] = dtsp;
+#ifdef XBC_OVERLAP
+            shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp;
+#else
             shared_x_dt[tid] = x[x_idx] * dtsp;
+#endif
         }
     }
@@ -116,16 +129,28 @@ fn main(
         let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
         let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
 
+#ifdef XBC_OVERLAP
+        let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt;
+#else
         let s = s_prev * dA + B[b_idx] * x_dt;
+#endif
         s_prev = s;
 
 #ifdef USE_SUBGROUP_REDUCTION
+#ifdef XBC_OVERLAP
+        let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]);
+#else
         let subgroup_partial = subgroupAdd(s * C[c_idx]);
+#endif
         if (subgroup_invocation_id == 0u) {
             shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
         }
+#else
+#ifdef XBC_OVERLAP
+        shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx];
 #else
         shared_reduce[reduce_idx] = s * C[c_idx];
 #endif
+#endif
 
         workgroupBarrier();
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 716011316..b4fd2c4dc 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2984,7 +2984,7 @@ struct test_bin_bcast : public test_case {
     bool run_whole_graph() override { return nf > 1; }
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, nr, nf, perm1);
+        return VARS_TO_STR6(type, ne, nr, nf, perm1, src_overlap);
     }
 
     size_t op_size(ggml_tensor * t) override {
@@ -3589,9 +3589,10 @@ struct test_ssm_scan : public test_case {
     const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;
+    const bool    xbc_overlap;
 
     std::string vars() override {
-        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
+        return VARS_TO_STR8(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs, xbc_overlap);
     }
 
     test_ssm_scan(ggml_type type = GGML_TYPE_F32,
@@ -3600,16 +3601,31 @@ struct test_ssm_scan : public test_case {
                   int64_t n_head = 32,
                   int64_t n_group = 1,
                   int64_t n_seq_tokens = 32,
-                  int64_t n_seqs = 32)
-        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+                  int64_t n_seqs = 32,
+                  bool xbc_overlap = false)
+        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), xbc_overlap(xbc_overlap) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * s  = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
-        ggml_tensor * x  = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
         ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
         ggml_tensor * A  = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
-        ggml_tensor * B  = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-        ggml_tensor * C  = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * x;
+        ggml_tensor * B;
+        ggml_tensor * C;
+
+        if (xbc_overlap) {
+            ggml_tensor * xbc = ggml_new_tensor_4d(ctx, type, d_state, n_head, n_seq_tokens, 2 * n_seqs);
+            x = ggml_view_4d(ctx, xbc, head_dim, n_head, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], xbc->nb[3]);
+            B = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 0);
+            C = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 2 * xbc->nb[3]);
+        } else {
+            x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+            B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+            C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        }
         ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
 
         ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
         return out;
@@ -7964,6 +7980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4));       // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4));       // Mamba-2
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4));        // Falcon-H1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 128, 4, 4, 16, 2, true)); // x/B/C overlap
 
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
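
Background on the merged-binding approach used throughout this patch: WebGPU validation rejects a bind group in which two storage-buffer entries cover overlapping byte ranges of the same buffer, so whenever two logical tensors alias (x/B/C in ssm_scan, K/V in flash_attn, src0/src1 in binary ops), the backend binds a single range spanning all of them and rebases each tensor's offset, in elements, against that merged range. The following is a minimal standalone C++ sketch of that computation; the Range, merge_ranges, and element_offset names are illustrative only and not part of the patch:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>

struct Range {
    size_t offset;  // byte offset of a binding within the buffer
    size_t size;    // byte size of the binding
};

// Merge several (possibly overlapping) binding ranges of one buffer into a
// single range that covers them all, mirroring the idea behind
// ggml_webgpu_tensor_merged_binding_range() above.
static Range merge_ranges(std::initializer_list<Range> ranges) {
    size_t lo = SIZE_MAX;
    size_t hi = 0;
    for (const Range & r : ranges) {
        lo = std::min(lo, r.offset);
        hi = std::max(hi, r.offset + r.size);
    }
    return { lo, hi - lo };  // one binding covering every aliased tensor
}

// Offset of a tensor that starts at byte `tensor_off`, expressed in elements
// of size `elem_size` relative to the merged binding's base.
static uint32_t element_offset(size_t tensor_off, const Range & merged, size_t elem_size) {
    return (uint32_t) ((tensor_off - merged.offset) / elem_size);
}

With offsets rebased this way, a single shader-side array (such as x_B_C_merged in the XBC_OVERLAP variant of ssm_scan.wgsl) can stand in for all of the aliased tensors, which is why the overlap variants collapse several bindings into one entry and renumber the remaining bindings.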