ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (#22456)

* Refactor buffer aliasing to be part of shader lib decisions

* cleanup

* formatting
Author: Reese Levine
Date: 2026-04-28 07:27:17 -07:00
Committed by: GitHub
Parent: f42e29fdf1
Commit: 98bb57916a
6 changed files with 326 additions and 238 deletions
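The heart of the change: instead of each encoder computing `inplace`/`overlap` flags and passing them through `ggml_webgpu_shader_lib_context`, the shader lib derives them from the tensors themselves via the new `ggml_webgpu_tensor_equal` / `ggml_webgpu_tensor_overlap` helpers when it builds a pipeline key, and caches the results in per-op decision structs hung off the pipeline. A minimal standalone sketch of the two predicates, using plain byte ranges in place of ggml tensors (the `Range` type and `main` driver are illustrative, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a tensor's placement: which buffer it lives in,
// its base address within that buffer, and its size in bytes.
struct Range {
    int      buffer;  // identifies the backing buffer
    uint64_t addr;    // base address (view_src base + view_offs in the patch)
    uint64_t nbytes;  // ggml_nbytes(tensor) in the patch
};

// Mirrors ggml_webgpu_tensor_equal: same buffer and same base address.
static bool range_equal(const Range & a, const Range & b) {
    return a.buffer == b.buffer && a.addr == b.addr;
}

// Mirrors ggml_webgpu_tensor_overlap: same buffer and intersecting half-open
// byte intervals [addr, addr + nbytes).
static bool range_overlap(const Range & a, const Range & b) {
    return a.buffer == b.buffer && a.addr < b.addr + b.nbytes && b.addr < a.addr + a.nbytes;
}

int main() {
    Range x = { 0, 0, 64 };   // bytes [0, 64) of buffer 0
    Range y = { 0, 32, 64 };  // bytes [32, 96) of buffer 0: overlaps x
    Range z = { 1, 0, 64 };   // different buffer: never aliases x
    assert(!range_equal(x, y) && range_overlap(x, y));
    assert(!range_overlap(x, z));
    assert(range_equal(x, x));  // "inplace": src and dst share a base address
    return 0;
}
```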
+106 -53
@@ -26,21 +26,21 @@
 // Matrix multiplication parameters
 // Register tiling parameters
 #define WEBGPU_MUL_MAT_TILE_M 4
 #define WEBGPU_MUL_MAT_TILE_N 4
 #define WEBGPU_MUL_MAT_WG_SIZE_M 8
 #define WEBGPU_MUL_MAT_WG_SIZE_N 8
 #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
 #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
 // Subgroup matrix parameters
 // The number of subgroups in the M dimension
 #define WEBGPU_MUL_MAT_SUBGROUP_M 2
 // The number of subgroups in the N dimension
 #define WEBGPU_MUL_MAT_SUBGROUP_N 4
 // The number of subgroup matrices each subgroup accumulates over
 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
@@ -59,19 +59,32 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
     seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 }
 
+// Calculates base address of a tensor ignoring the fake base pointer
+inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (uintptr_t) base_tensor->data + tensor->view_offs;
+}
+
+inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
+}
+
+inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
+           ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
+}
+
 struct ggml_webgpu_shader_lib_context {
     ggml_tensor * src0;
     ggml_tensor * src1;
     ggml_tensor * src2;
     ggml_tensor * src3;
     ggml_tensor * src4;
+    ggml_tensor * src5;
     ggml_tensor * dst;
     uint32_t max_wg_size;
     size_t wg_mem_limit_bytes = 0;
-    bool inplace = false;
-    bool overlap = false;
-    bool src_overlap = false;
     bool supports_subgroups = false;
     bool supports_subgroup_matrix = false;
     uint32_t sg_mat_m = 0;
@@ -88,6 +101,14 @@ struct webgpu_pipeline {
 struct ggml_webgpu_generic_shader_decisions {
     uint32_t wg_size = 0;
+    bool inplace = false;
+};
+
+struct ggml_webgpu_binary_shader_decisions {
+    uint32_t wg_size = 0;
+    bool inplace = false;
+    bool overlap = false;
+    bool src_overlap = false;
 };
 
 struct ggml_webgpu_processed_shader {
@@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
 };
 
 struct ggml_webgpu_ssm_scan_pipeline_key {
     int type;
     int d_state;
+    bool xbc_overlap;
 
     bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
-        return type == other.type && d_state == other.d_state;
+        return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
     }
 };
@@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
         size_t seed = 0;
         ggml_webgpu_hash_combine(seed, key.type);
         ggml_webgpu_hash_combine(seed, key.d_state);
+        ggml_webgpu_hash_combine(seed, key.xbc_overlap);
         return seed;
     }
 };
@@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
 struct ggml_webgpu_ssm_scan_shader_decisions {
     uint32_t wg_size;
     uint32_t tokens_per_tile;
+    bool xbc_overlap = false;
 };
 
 /** Argsort **/
@@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
     }
 };
 
+struct ggml_webgpu_rms_norm_mul_shader_decisions {
+    uint32_t wg_size = 0;
+    bool inplace = false;
+    bool overlap = false;
+    bool src_overlap = false;
+};
+
 /** Pad **/
 struct ggml_webgpu_pad_pipeline_key {
     bool circular;
@@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
 };
 
 struct ggml_webgpu_flash_attn_decisions {
     uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
     uint32_t q_tile = 0;
     uint32_t kv_tile = 0;
     uint32_t wg_size = 0;
     bool kv_direct = false;
+    bool kv_overlap = false;
 };
 
 inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
@@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
     key.head_dim_qk = (uint32_t) context.src0->ne[0];
     key.head_dim_v = (uint32_t) context.src2->ne[0];
     key.kv_direct = kv_direct;
-    key.kv_overlap = context.src_overlap;
+    key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
     key.has_mask = has_mask;
     key.has_sinks = has_sinks;
     key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
@@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_row_norm_pipeline_key key = {};
         key.op = context.dst->op;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = row_norm_pipelines.find(key);
         if (it != row_norm_pipelines.end()) {
@@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib {
         const uint32_t row_norm_wg_size = 128u;
         uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
         auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
-        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = wg_size;
+        decisions->inplace = key.inplace;
+        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        row_norm_pipelines[key].context = decisions;
         return row_norm_pipelines[key];
     }
@@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_set_pipeline_key key = {};
         key.type = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = set_pipelines.find(key);
         if (it != set_pipelines.end()) {
@@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_set, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
         set_pipelines[key] = pipeline;
@@ -1355,7 +1392,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_scale_pipeline_key key = {};
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = scale_pipelines.find(key);
         if (it != scale_pipelines.end()) {
@@ -1375,6 +1412,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_scale, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
         scale_pipelines[key] = pipeline;
@@ -1468,6 +1506,8 @@ class ggml_webgpu_shader_lib {
         ggml_webgpu_ssm_scan_pipeline_key key = {};
         key.type = context.dst->type;
         key.d_state = (int) context.src0->ne[0];
+        key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
+                          ggml_webgpu_tensor_overlap(context.src1, context.src5);
 
         auto it = ssm_scan_pipelines.find(key);
         if (it != ssm_scan_pipelines.end()) {
@@ -1499,12 +1539,17 @@ class ggml_webgpu_shader_lib {
variant += "_wg_reduce"; variant += "_wg_reduce";
} }
if (key.xbc_overlap) {
defines.push_back("XBC_OVERLAP");
}
variant += "_d" + std::to_string(key.d_state); variant += "_d" + std::to_string(key.d_state);
auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>(); auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
decisions->wg_size = wg_size; decisions->wg_size = wg_size;
decisions->tokens_per_tile = tokens_per_tile; decisions->tokens_per_tile = tokens_per_tile;
decisions->xbc_overlap = key.xbc_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions; pipeline.context = decisions;
ssm_scan_pipelines[key] = pipeline; ssm_scan_pipelines[key] = pipeline;
@@ -1764,11 +1809,9 @@ class ggml_webgpu_shader_lib {
         uint32_t tile_k;
         if (key.use_subgroup_matrix) {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
         } else {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
         }
 
         // Tiles
@@ -2001,9 +2044,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("SCALAR"); defines.push_back("SCALAR");
// mul_mat_id is register-tile only. // mul_mat_id is register-tile only.
const uint32_t tile_k = ggml_is_quantized(context.src0->type) const uint32_t tile_k =
? WEBGPU_MUL_MAT_REG_TILE_K_QUANT ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
// Tiles // Tiles
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
@@ -2039,8 +2081,8 @@ class ggml_webgpu_shader_lib {
         key.type = context.dst->type;
         key.op = op;
         key.is_unary = is_unary;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
         key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
 
         auto it = unary_pipelines.find(key);
         if (it != unary_pipelines.end()) {
@@ -2098,6 +2140,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_unary, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
         unary_pipelines[key] = pipeline;
@@ -2106,9 +2149,9 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_rms_norm_mul_pipeline_key key = {};
-        key.inplace = context.inplace;
-        key.overlap = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
 
         auto it = rms_norm_mul_pipelines.find(key);
         if (it != rms_norm_mul_pipelines.end()) {
@@ -2132,12 +2175,15 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
decisions->wg_size = context.max_wg_size; pipeline_decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline_decisions->inplace = key.inplace;
pipeline.context = decisions; pipeline_decisions->overlap = key.overlap;
rms_norm_mul_pipelines[key] = pipeline; pipeline_decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = pipeline_decisions;
rms_norm_mul_pipelines[key] = pipeline;
return rms_norm_mul_pipelines[key]; return rms_norm_mul_pipelines[key];
} }
@@ -2145,9 +2191,9 @@ class ggml_webgpu_shader_lib {
         ggml_webgpu_binary_pipeline_key key = {};
         key.type = context.dst->type;
         key.op = context.dst->op;
-        key.inplace = context.inplace;
-        key.overlap = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
 
         auto it = binary_pipelines.find(key);
         if (it != binary_pipelines.end()) {
@@ -2186,11 +2232,15 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_binary, defines); auto processed = preprocessor.preprocess(wgsl_binary, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
decisions->wg_size = context.max_wg_size; pipeline_decisions->wg_size = context.max_wg_size;
pipeline_decisions->inplace = key.inplace;
pipeline_decisions->overlap = key.overlap;
pipeline_decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions; pipeline.context = pipeline_decisions;
binary_pipelines[key] = pipeline; binary_pipelines[key] = pipeline;
return binary_pipelines[key]; return binary_pipelines[key];
} }
@@ -2351,7 +2401,8 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
} }
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
pipeline_decisions->kv_overlap = key.kv_overlap;
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
@@ -2543,7 +2594,7 @@ class ggml_webgpu_shader_lib {
     webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
         ggml_webgpu_rope_pipeline_key key = {};
         key.type = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
         key.has_ff = (context.src2 != nullptr);
 
         auto it = rope_pipelines.find(key);
@@ -2582,6 +2633,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_rope, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
         rope_pipelines[key] = pipeline;
@@ -2593,7 +2645,7 @@ class ggml_webgpu_shader_lib {
         key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
         key.has_mask = (context.src1 != nullptr);
         key.has_sink = (context.src2 != nullptr);
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
 
         auto it = soft_max_pipelines.find(key);
         if (it != soft_max_pipelines.end()) {
@@ -2634,6 +2686,7 @@ class ggml_webgpu_shader_lib {
         auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
         auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
         decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
         soft_max_pipelines[key] = pipeline;
+167 -170
@@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
 // their locations.
 static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 
-// Always returns the base offset of a tensor, regardless of views.
-static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
-    if (tensor->view_src) {
-        return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
-    }
-    return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
+static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs;
 }
 
 /* Struct definitions */
@@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
     buffer = device.CreateBuffer(&buffer_desc);
 }
 
-static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
-    return webgpu_tensor_offset(tensor) + tensor->view_offs;
-}
-
 static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
     ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
     return ctx->buffer;
@@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
     return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
 }
 
-// Used to determine if two tensors are the same for in-place operations
-static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
-}
-
-// Used to determine if two tensors share the same buffer and their byte ranges overlap,
-static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
-           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
-}
-
-struct binary_overlap_flags {
-    bool inplace;  // src0 == dst
-    bool overlap;  // src1 == dst
-    bool src_overlap;
+struct ggml_webgpu_merged_binding_range {
+    size_t offset;
+    size_t size;
 };
 
-static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
-                                                              ggml_tensor * src1,
-                                                              ggml_tensor * dst) {
-    binary_overlap_flags flags = {};
-    flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
-    flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
-    return flags;
+static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range(
+    webgpu_context & ctx,
+    std::initializer_list<ggml_tensor *> tensors) {
+    size_t merged_offset = SIZE_MAX;
+    size_t merged_end = 0;
+    for (ggml_tensor * tensor : tensors) {
+        const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor);
+        const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor);
+        merged_offset = std::min(merged_offset, bind_offset);
+        merged_end = std::max(merged_end, bind_end);
+    }
+    return { merged_offset, merged_end - merged_offset };
+}
+
+static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor,
+                                                         const ggml_webgpu_merged_binding_range & merged_range) {
+    return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type));
 }
 
 static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding,
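The merged-range helper above is the mechanism that lets several aliasing tensors share one buffer binding: it takes the union of their aligned binding windows, and each tensor's shader-side offset is then rebased to that union. A small worked example, with bare byte offsets and sizes standing in for tensors; the `Binding`, `MergedRange`, `merge`, and `element_offset` names are illustrative, and the example assumes the offsets are already alignment multiples so the aligned binding offset and the raw tensor offset coincide (in the real code they can differ by a misalignment term):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Illustrative stand-in for one tensor's binding window within a buffer.
struct Binding {
    size_t offset;     // byte offset of the window (and of the tensor data here)
    size_t size;       // byte size of the window
    size_t elem_size;  // ggml_type_size(tensor->type)
};

struct MergedRange {
    size_t offset;
    size_t size;
};

// Union of all binding windows, as in ggml_webgpu_tensor_merged_binding_range.
static MergedRange merge(std::initializer_list<Binding> bindings) {
    size_t lo = SIZE_MAX, hi = 0;
    for (const Binding & b : bindings) {
        lo = std::min(lo, b.offset);
        hi = std::max(hi, b.offset + b.size);
    }
    return { lo, hi - lo };
}

// Element offset of one tensor inside the merged window, as in
// ggml_webgpu_tensor_merged_element_offset.
static uint32_t element_offset(const Binding & b, const MergedRange & m) {
    return (uint32_t) ((b.offset - m.offset) / b.elem_size);
}

int main() {
    Binding k = { 256, 512, 4 };  // f32 tensor whose window starts at byte 256
    Binding v = { 512, 512, 4 };  // f32 tensor at byte 512; windows overlap
    MergedRange m = merge({ k, v });
    assert(m.offset == 256 && m.size == 768);  // union covers bytes [256, 1024)
    assert(element_offset(k, m) == 0);         // K starts at element 0 of the window
    assert(element_offset(v, m) == 64);        // V starts 256 bytes = 64 f32s in
    return 0;
}
```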
@@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
                                           ggml_tensor *   src0,
                                           ggml_tensor *   src1,
                                           ggml_tensor *   dst) {
-    const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src0;
     shader_lib_ctx.src1 = src1;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
     webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace = decisions->inplace;
 
     const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
     const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
@@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
                                              ggml_tensor *   dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src0;
+    shader_lib_ctx.src1 = src1;
+    shader_lib_ctx.src4 = src4;
+    shader_lib_ctx.src5 = src5;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
     shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
     webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get());
+    const bool xbc_overlap = decisions->xbc_overlap;
+
+    uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type));
+    uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type));
+    size_t xbc_bind_offset = 0;
+    size_t xbc_bind_size = 0;
+    if (xbc_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 });
+        xbc_bind_offset = merged_range.offset;
+        xbc_bind_size = merged_range.size;
+        offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
+        offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range);
+        offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range);
+    }
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_x,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
+        offset_B,
+        offset_C,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
     };
+    if (xbc_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst));
+    } else {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst));
+    }
 
     const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
     const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
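When x, B, and C alias (`xbc_overlap`), the eight-slot layout above collapses to six: slot 1 becomes the merged x/B/C window and the remaining tensors shift down, which is why the WGSL side needs the `XBC_OVERLAP` define to declare a matching set of bindings. A compact sketch of the layout selection, where the simplified `Entry` type stands in for `wgpu::BindGroupEntry` and the tensor roles (s, x, dt, A, B, C, ids) follow the usual ggml SSM_SCAN operand order, which is an assumption here:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Illustrative stand-in for a bind group entry: slot number plus a label for
// which tensor (or merged window) the slot exposes.
struct Entry {
    uint32_t    binding;
    std::string resource;
};

// Mirrors the encoder logic above: one merged window for x/B/C when they
// alias, otherwise one binding per tensor.
static std::vector<Entry> ssm_scan_layout(bool xbc_overlap) {
    std::vector<Entry> entries = { { 0, "s (src0)" } };
    if (xbc_overlap) {
        entries.push_back({ 1, "merged x/B/C window" });
        entries.push_back({ 2, "dt (src2)" });
        entries.push_back({ 3, "A (src3)" });
        entries.push_back({ 4, "ids (src6)" });  // renumbered from slot 6
        entries.push_back({ 5, "dst" });         // renumbered from slot 7
    } else {
        entries.push_back({ 1, "x (src1)" });
        entries.push_back({ 2, "dt (src2)" });
        entries.push_back({ 3, "A (src3)" });
        entries.push_back({ 4, "B (src4)" });
        entries.push_back({ 5, "C (src5)" });
        entries.push_back({ 6, "ids (src6)" });
        entries.push_back({ 7, "dst" });
    }
    return entries;
}

int main() {
    // 6 bindings in the merged layout vs 8 in the regular one.
    return ssm_scan_layout(true).size() == 6 && ssm_scan_layout(false).size() == 8 ? 0 : 1;
}
```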
@@ -1653,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0 = Q;
+    shader_lib_ctx.src1 = K;
+    shader_lib_ctx.src2 = V;
+    shader_lib_ctx.src3 = mask;
+    shader_lib_ctx.src4 = sinks;
+    shader_lib_ctx.dst = dst;
+    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
+    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+    shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
+    shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
+    shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
+    shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
+        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
+
     const int has_mask = (mask != nullptr);
     const int has_sinks = (sinks != nullptr);
-    const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
+    const bool kv_overlap = decisions->kv_overlap;
 
     uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
     uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
     size_t kv_bind_offset = 0;
     size_t kv_bind_size = 0;
     if (kv_overlap) {
-        const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
-        const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
-        const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
-        const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
-        kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
-        kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
-        offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
-        offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
+        const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
+        kv_bind_offset = merged_range.offset;
+        kv_bind_size = merged_range.size;
+        offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
+        offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
     }
 
     std::vector<uint32_t> params = {
@@ -1720,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     }
     entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.src0 = Q;
-    shader_lib_ctx.src1 = K;
-    shader_lib_ctx.src2 = V;
-    shader_lib_ctx.src3 = mask;
-    shader_lib_ctx.src4 = sinks;
-    shader_lib_ctx.dst = dst;
-    shader_lib_ctx.src_overlap = kv_overlap;
-    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
-    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
-    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-    shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
-    shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
-    shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
-    shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
-    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
-        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
-    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
-
     if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
         uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
         uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
@@ -1921,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
 static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
-    bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src;
     shader_lib_ctx.src1 = nullptr;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
     webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace = decisions->inplace;
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1994,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
                                               ggml_tensor *   src0,
                                               ggml_tensor *   src1,
                                               ggml_tensor *   dst) {
-    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src0;
     shader_lib_ctx.src1 = src1;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = flags.inplace;
-    shader_lib_ctx.overlap = flags.overlap;
-    shader_lib_ctx.src_overlap = flags.src_overlap;
     webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
     size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
     size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
-    uint32_t offset_merged_src0 = 0;
-    uint32_t offset_merged_src1 = 0;
-    if (flags.src_overlap) {
-        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
-        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+    uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
+    uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    size_t merged_offset = 0;
+    size_t merged_size = 0;
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
+        merged_offset = merged_range.offset;
+        merged_size = merged_range.size;
+        offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
+        offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
     }
 
     std::vector<uint32_t> params = {
         ne,
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_src0,
+        offset_src1,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        offset_merged_src0,
-        offset_merged_src1,
         (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
         (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
@@ -2048,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
     std::vector<wgpu::BindGroupEntry> entries;
-    if (flags.src_overlap) {
-        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
-                                     src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset,
-                                                            merged_end - merged_offset));
+    if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     } else {
         entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0),
@@ -2062,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
         entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1),
                                                             src1_webgpu_tensor_align_offset,
                                                             ggml_webgpu_tensor_binding_size(ctx, src1)));
-        if (!flags.inplace && !flags.overlap) {
+        if (!decisions->inplace && !decisions->overlap) {
             entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
         }
     }
@@ -2168,29 +2177,15 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
         GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
     }
 
-    bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
-                   (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
-    bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
-    bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
-
-    uint32_t offset_merged_rn_src = 0;
-    uint32_t offset_merged_mul_src = 0;
-    size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
-    size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
-    if (src_overlap) {
-        size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        offset_merged_rn_src =
-            (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
-        offset_merged_mul_src =
-            (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
-    }
+    uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type));
+    uint32_t offset_mul_src =
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type));
+    size_t merged_offset = 0;
+    size_t merged_size = 0;
 
     std::vector<uint32_t> params = {
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
-        offset_merged_rn_src,
-        offset_merged_mul_src,
+        offset_rn_src,
+        offset_mul_src,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
         (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
@@ -2214,16 +2209,32 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
     std::vector<wgpu::BindGroupEntry> entries;
-    if (inplace || overlap) {
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0 = rn_src;
+    shader_lib_ctx.src1 = mul_src;
+    shader_lib_ctx.dst = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get());
+
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src });
+        merged_offset = merged_range.offset;
+        merged_size = merged_range.size;
+        offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range);
+        offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range);
+        params[0] = offset_rn_src;
+        params[1] = offset_mul_src;
+    }
+
+    if (decisions->inplace || decisions->overlap) {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
-    } else if (src_overlap) {
-        size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        size_t merged_end =
-            std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
-                     mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
-                                                            merged_end - merged_offset));
+    } else if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size));
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     } else {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
@@ -2231,20 +2242,10 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
     }
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
-    shader_lib_ctx.overlap = overlap;
-    shader_lib_ctx.src_overlap = src_overlap;
-    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
-
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
 }
 
 static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);
-
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -2261,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor
         ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0))  // epsilon, treated as f32 in the shader
     };
 
-    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
-    if (!inplace) {
-        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
-    }
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
     webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
+    if (!decisions->inplace) {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
+    }
 
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src));
 }
@@ -2287,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx,
     shader_lib_ctx.src2 = src2;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
     webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
-    const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    const bool inplace = decisions->inplace;
     const int has_freq_factor = (src2 != nullptr);
 
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
@@ -2421,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
 }
 
 static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);
-
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
     shader_lib_ctx.src0 = src;
     shader_lib_ctx.src1 = nullptr;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
     webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2454,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
     // bindgroups unchanged
     std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
-    if (!inplace) {
+    if (!decisions->inplace) {
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     }
@@ -2473,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
     shader_lib_ctx.src2 = src2;
     shader_lib_ctx.dst = dst;
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
     webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
-    const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    const bool inplace = decisions->inplace;
     const int has_mask = (src1 != nullptr);
     const int has_sink = (src2 != nullptr);
 
     float max_bias = ggml_get_op_params_f32(dst, 1);
     float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
     float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -3079,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
                                                  size_t size) {
     GGML_UNUSED(backend);
     auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

     // Write aligned portion
     buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

@@ -3161,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
                      << ", " << offset << ", " << size << ")");
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

     // This is a trick to set all bytes of a u32 to the same 1 byte value.
     uint32_t val32 = (uint32_t) value * 0x01010101;

@@ -3180,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                      << ", " << offset << ", " << size << ")");
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

     buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

@@ -3212,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                      << ", " << offset << ", " << size << ")");
     wgpu::Device device = buf_ctx->global_ctx->device;
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

     size_t final_size = size;
     if (size % 4 != 0) {
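All four host-side hunks above replace webgpu_tensor_offset(tensor) + tensor->view_offs with a single call. The new helper's definition sits outside this excerpt; given the call sites, it presumably just folds the view offset in, along these lines (a reconstruction, not verbatim from the commit):

    // Assumed shape of the renamed helper (reconstruction only): the
    // buffer-relative base of the tensor's data with the view offset already
    // applied, so call sites add nothing but their own byte offset.
    inline size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
        return webgpu_tensor_offset(tensor) + tensor->view_offs;
    }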
@@ -7,8 +7,6 @@ struct Params {
     offset_src0: u32,
     offset_src1: u32,
     offset_dst: u32,
-    offset_merged_src0: u32,
-    offset_merged_src1: u32,

     stride_src0_0: u32,
     stride_src0_1: u32,

@@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
 @compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     if (gid.x < params.ne) {
-        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
-        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+        let src0_i = params.offset_src0 + src0_index(gid.x);
+        let src1_i = params.offset_src1 + src1_index(gid.x);
         update(params.offset_dst + gid.x, src0_i, src1_i);
     }
 }
@@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
 struct Params {
     offset_rn_src: u32,
     offset_mul_src: u32,
-    offset_merged_rn_src: u32,
-    offset_merged_mul_src: u32,
     offset_dst: u32,

     stride_rn_src1: u32,

@@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     i = i % (params.ne2 * params.ne1);
     let i2 = i / params.ne1;
     let i1 = i % params.ne1;
-    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
-    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
+    let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
+    let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
     let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

     let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
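Both WGSL hunks above drop the offset_merged_* uniforms. With aliasing decided centrally in the shader lib, an overlapping source no longer needs a second additive term: the host can fold that source's displacement from the shared binding base into the ordinary offset_* field. One plausible host-side scheme (a sketch; the commit's actual plumbing is outside this excerpt, and the variable names are illustrative):

    // Hedged sketch: fold an aliasing source's displacement into the regular
    // per-source element offset instead of a separate "merged" uniform.
    // base is whichever overlapping tensor starts first in the buffer;
    // ggml_webgpu_tensor_addr is the address helper this commit introduces.
    const uintptr_t base = std::min(ggml_webgpu_tensor_addr(src0), ggml_webgpu_tensor_addr(src1));
    const uint32_t offset_src0 = (uint32_t) ((ggml_webgpu_tensor_addr(src0) - base) / ggml_type_size(src0->type));
    const uint32_t offset_src1 = (uint32_t) ((ggml_webgpu_tensor_addr(src1) - base) / ggml_type_size(src1->type));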
@@ -45,6 +45,14 @@ struct Params {
 };

 @group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
+#ifdef XBC_OVERLAP
+@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>;
+@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
+@group(0) @binding(3) var<storage, read_write> A: array<f32>;
+@group(0) @binding(4) var<storage, read_write> ids: array<i32>;
+@group(0) @binding(5) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(6) var<uniform> params: Params;
+#else
 @group(0) @binding(1) var<storage, read_write> x: array<f32>;
 @group(0) @binding(2) var<storage, read_write> dt: array<f32>;
 @group(0) @binding(3) var<storage, read_write> A: array<f32>;

@@ -53,6 +61,7 @@ struct Params {
 @group(0) @binding(6) var<storage, read_write> ids: array<i32>;
 @group(0) @binding(7) var<storage, read_write> dst: array<f32>;
 @group(0) @binding(8) var<uniform> params: Params;
+#endif

 var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
 var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;

@@ -98,7 +107,11 @@ fn main(
         let dt0 = dt[dt_idx];
         let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
         shared_dtsp[tid] = dtsp;
+#ifdef XBC_OVERLAP
+        shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp;
+#else
         shared_x_dt[tid] = x[x_idx] * dtsp;
+#endif
     }
 }
@@ -116,16 +129,28 @@ fn main(
     let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
     let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
+#ifdef XBC_OVERLAP
+    let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt;
+#else
     let s = s_prev * dA + B[b_idx] * x_dt;
+#endif
     s_prev = s;

 #ifdef USE_SUBGROUP_REDUCTION
+#ifdef XBC_OVERLAP
+    let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]);
+#else
     let subgroup_partial = subgroupAdd(s * C[c_idx]);
+#endif
     if (subgroup_invocation_id == 0u) {
         shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
     }
 #else
+#ifdef XBC_OVERLAP
+    shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx];
+#else
     shared_reduce[reduce_idx] = s * C[c_idx];
+#endif
 #endif

     workgroupBarrier();
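The reason a dedicated variant exists at all: WebGPU disallows binding the same buffer range through more than one writable storage binding in a dispatch, so when x, B, and C alias they must all be read through the single x_B_C_merged binding. Host-side, the shader lib can detect the case with the overlap helper; a sketch (the srcN mapping follows ggml_ssm_scan's argument order and is an assumption here, as is the context variable c):

    // Hedged sketch: pick the XBC_OVERLAP shader variant when any pair of the
    // x/B/C operands alias. For ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids):
    // x = src1, B = src4, C = src5.
    bool xbc_overlap = ggml_webgpu_tensor_overlap(c.src1, c.src4) ||
                       ggml_webgpu_tensor_overlap(c.src1, c.src5) ||
                       ggml_webgpu_tensor_overlap(c.src4, c.src5);
    // If set, build the pipeline with XBC_OVERLAP defined so the shader uses
    // the merged binding layout above.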
+24 -7
@@ -2984,7 +2984,7 @@ struct test_bin_bcast : public test_case {
     bool run_whole_graph() override { return nf > 1; }

     std::string vars() override {
-        return VARS_TO_STR5(type, ne, nr, nf, perm1);
+        return VARS_TO_STR6(type, ne, nr, nf, perm1, src_overlap);
     }

     size_t op_size(ggml_tensor * t) override {
@@ -3589,9 +3589,10 @@ struct test_ssm_scan : public test_case {
     const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;
+    const bool xbc_overlap;

     std::string vars() override {
-        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
+        return VARS_TO_STR8(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs, xbc_overlap);
     }

     test_ssm_scan(ggml_type type = GGML_TYPE_F32,

@@ -3600,16 +3601,31 @@ struct test_ssm_scan : public test_case {
                   int64_t n_head = 32,
                   int64_t n_group = 1,
                   int64_t n_seq_tokens = 32,
-                  int64_t n_seqs = 32)
-        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+                  int64_t n_seqs = 32,
+                  bool xbc_overlap = false)
+        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), xbc_overlap(xbc_overlap) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
-        ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
         ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
         ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
-        ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-        ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * x;
+        ggml_tensor * B;
+        ggml_tensor * C;
+        if (xbc_overlap) {
+            ggml_tensor * xbc = ggml_new_tensor_4d(ctx, type, d_state, n_head, n_seq_tokens, 2 * n_seqs);
+            x = ggml_view_4d(ctx, xbc, head_dim, n_head, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], xbc->nb[3]);
+            B = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 0);
+            C = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 2 * xbc->nb[3]);
+        } else {
+            x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+            B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+            C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        }
         ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
         ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
         return out;
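The overlap construction above is deliberate: xbc is allocated with 2 * n_seqs slots along dim 3, and the views are placed at byte offsets 0 (B), nb[3] (x), and 2 * nb[3] (C), so x aliases both of the others. Worked numbers for the case registered below:

    // For the registered test (d_state=128, head_dim=128, n_head=4, n_group=4,
    // n_seq_tokens=16, n_seqs=2, f32):
    //   xbc->nb[3] = 128 * 4 * 16 * 4 bytes = 32 KiB per slot, 4 slots total
    //   B view at offset 0       spans bytes [ 0 KiB,  64 KiB)
    //   x view at offset nb[3]   spans bytes [32 KiB,  96 KiB)
    //   C view at offset 2*nb[3] spans bytes [64 KiB, 128 KiB)
    // x overlaps both B and C -- exactly the aliasing the backend must detect.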
@@ -7964,6 +7980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 128, 4, 4, 16, 2, true)); // x/B/C overlap
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
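The new case runs with the rest of the eval suite; to target it alone, the harness's op filter should work, e.g. something like test-backend-ops test -o SSM_SCAN with the WebGPU backend selected (exact flag spelling per the tool's help output).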