ggml-webgpu: Add fused RMS_NORM + MUL (#21983)
* fused rms_norm_mul + mul * Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion. * Decouple num_fused_ops from webgpu_context; misc cleanup * Fix eps handling and remove disable_fusion. * Fix not to use c++20 initializers.
This commit is contained in:
committed by
GitHub
parent
8bccdbbff9
commit
6da7168312
@@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** RMS_NORM + MUL **/
|
||||||
|
|
||||||
|
// Cache key for the fused RMS_NORM + MUL pipelines: one entry per combination
// of the two aliasing flags below. Kept as a plain aggregate so it can be
// value-initialised with `= {}`.
struct ggml_webgpu_rms_norm_mul_pipeline_key {
    bool inplace;      // selects the shader's INPLACE variant
    bool src_overlap;  // selects the shader's SRC_OVERLAP variant

    // Two keys match only when both flags match.
    bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
        if (inplace != other.inplace) {
            return false;
        }
        return src_overlap == other.src_overlap;
    }
};
|
||||||
|
|
||||||
|
struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
||||||
|
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
|
||||||
|
size_t seed = 0;
|
||||||
|
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** Pad **/
|
/** Pad **/
|
||||||
struct ggml_webgpu_pad_pipeline_key {
|
struct ggml_webgpu_pad_pipeline_key {
|
||||||
bool circular;
|
bool circular;
|
||||||
@@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
|||||||
const size_t q_tile = context.sg_mat_m;
|
const size_t q_tile = context.sg_mat_m;
|
||||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||||
size_t bytes_per_kv = 0;
|
size_t bytes_per_kv = 0;
|
||||||
if (!key.kv_direct) {
|
if (!key.kv_direct) {
|
||||||
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
||||||
}
|
}
|
||||||
@@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib {
|
|||||||
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
||||||
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
||||||
row_norm_pipelines; // op/inplace
|
row_norm_pipelines; // op/inplace
|
||||||
|
|
||||||
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
||||||
get_rows_pipelines; // src_type, vectorized
|
get_rows_pipelines; // src_type, vectorized
|
||||||
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
||||||
unary_pipelines; // type/op/inplace
|
unary_pipelines; // type/op/inplace
|
||||||
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
||||||
scale_pipelines; // inplace
|
scale_pipelines; // inplace
|
||||||
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
||||||
solve_tri_pipelines; // type
|
solve_tri_pipelines; // type
|
||||||
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
||||||
ssm_conv_pipelines; // type/vectorized
|
ssm_conv_pipelines; // type/vectorized
|
||||||
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
||||||
webgpu_pipeline,
|
webgpu_pipeline,
|
||||||
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
||||||
@@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib {
|
|||||||
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
||||||
conv2d_pipelines;
|
conv2d_pipelines;
|
||||||
|
|
||||||
|
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
|
||||||
|
webgpu_pipeline,
|
||||||
|
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
|
||||||
|
rms_norm_mul_pipelines;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||||
|
|
||||||
@@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib {
|
|||||||
return unary_pipelines[key];
|
return unary_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the compute pipeline for the fused RMS_NORM + MUL shader, building
// and caching it on first use.
//
// Reads context.inplace, context.src_overlap and context.max_wg_size.
// The shader is specialised with at most one of INPLACE / SRC_OVERLAP
// (INPLACE wins when both flags are set) plus WG_SIZE=<max_wg_size>.
// The returned pipeline carries a ggml_webgpu_generic_shader_decisions with
// the chosen workgroup size in its `context` field.
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_rms_norm_mul_pipeline_key key = {};
    key.inplace                               = context.inplace;
    // INPLACE takes precedence below, so src_overlap has no effect on the
    // generated shader once inplace is set. Normalise it away so that
    // {inplace, overlap} and {inplace, !overlap} share one cache entry
    // instead of compiling the identical "_inplace" variant twice.
    key.src_overlap                           = !context.inplace && context.src_overlap;

    auto it = rms_norm_mul_pipelines.find(key);
    if (it != rms_norm_mul_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              op_name = "RMS_NORM_MUL";
    std::string              variant = op_name;

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    } else if (key.src_overlap) {
        defines.push_back("SRC_OVERLAP");
        variant += "_src_overlap";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed     = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline    = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context            = decisions;
    rms_norm_mul_pipelines[key] = pipeline;
    return rms_norm_mul_pipelines[key];
}
|
||||||
|
|
||||||
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||||
ggml_webgpu_binary_pipeline_key key = {};
|
ggml_webgpu_binary_pipeline_key key = {};
|
||||||
key.type = context.dst->type;
|
key.type = context.dst->type;
|
||||||
|
|||||||
@@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor *
|
|||||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Encodes one dispatch that computes dst = rms_norm(rn_src) * mul_src, where
// mul_src is whichever MUL operand is NOT the RMS_NORM output.
//
//   rn_src    input of the RMS_NORM node
//   rn_dst    output of the RMS_NORM node (never bound; only used to identify
//             the MUL operand it feeds and to read epsilon from op params)
//   mul_src0  first operand of the MUL node
//   mul_src1  second operand of the MUL node
//   dst       output of the MUL node
//
// Aborts if rn_dst is neither MUL operand. One workgroup is dispatched per
// row of dst.
//
// Binding layout, which must match the shader variant selected by the flags:
//   inplace      : [0] rn_src      [1] mul_src (also written as the output)
//   src_overlap  : [0] one range covering rn_src and mul_src   [1] dst
//   otherwise    : [0] rn_src      [1] mul_src                 [2] dst
//
// NOTE(review): no variant covers rn_src aliasing dst; confirm that the graph
// allocator cannot produce that layout for a fused pair.
static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context & ctx,
                                                                 ggml_tensor *    rn_src,
                                                                 ggml_tensor *    rn_dst,
                                                                 ggml_tensor *    mul_src0,
                                                                 ggml_tensor *    mul_src1,
                                                                 ggml_tensor *    dst) {
    ggml_tensor * mul_src = nullptr;

    if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) {
        mul_src = mul_src1;
    } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) {
        mul_src = mul_src0;
    } else {
        GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
    }

    bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
                   (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
    // The inplace path binds rn_src and mul_src separately, so the
    // merged-binding offsets below must stay zero for it. Only treat the
    // sources as overlapping when the inplace variant is not selected; this
    // also keeps the flags in step with the shader's INPLACE-first choice.
    bool src_overlap = !inplace && ggml_webgpu_tensor_overlap(rn_src, mul_src);

    uint32_t offset_merged_rn_src               = 0;
    uint32_t offset_merged_mul_src              = 0;
    size_t   rn_src_webgpu_tensor_align_offset  = ggml_webgpu_tensor_align_offset(ctx, rn_src);
    size_t   mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
    size_t   merged_offset                      = 0;

    if (src_overlap) {
        // Both sources share one binding that starts at the lower aligned
        // offset; each tensor is then addressed relative to that start,
        // in elements.
        merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
        offset_merged_rn_src =
            (uint32_t) ((rn_src_webgpu_tensor_align_offset - merged_offset) / ggml_type_size(rn_src->type));
        offset_merged_mul_src =
            (uint32_t) ((mul_src_webgpu_tensor_align_offset - merged_offset) / ggml_type_size(mul_src->type));
    }

    // Order must match `struct Params` in the rms_norm_mul shader.
    // Offsets and strides are expressed in elements, not bytes.
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
        offset_merged_rn_src,
        offset_merged_mul_src,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
        (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
        (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)),
        (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)),
        (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)),
        (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) mul_src->ne[0],
        (uint32_t) mul_src->ne[1],
        (uint32_t) mul_src->ne[2],
        (uint32_t) mul_src->ne[3],
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        (uint32_t) dst->ne[3],
        ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0))  // epsilon, treated as f32 in the shader
    };

    std::vector<wgpu::BindGroupEntry> entries;

    if (inplace) {
        // mul_src doubles as the output, so no separate dst binding.
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
    } else if (src_overlap) {
        // A single binding spanning from the lower start to the higher end
        // of the two source ranges.
        size_t merged_end =
            std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
                     mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
                                                            merged_end - merged_offset));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
    } else {
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
    }

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
    shader_lib_ctx.inplace     = inplace;
    shader_lib_ctx.src_overlap = src_overlap;

    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);

    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
}
|
||||||
|
|
||||||
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||||
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||||
|
|
||||||
@@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor
|
|||||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) {
|
||||||
|
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// additional constraints specific to this fusion
|
||||||
|
const ggml_tensor * rms_norm = cgraph->nodes[node_idx];
|
||||||
|
const ggml_tensor * mul = cgraph->nodes[node_idx + 1];
|
||||||
|
|
||||||
|
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
||||||
|
// rms_norm only supports f32
|
||||||
|
if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// if rms_norm is the B operand, then we don't handle broadcast
|
||||||
|
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// rms_norm shader assumes contiguous rows
|
||||||
|
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the encoded command, or std::nullopt if the operation is a no-op
|
// Returns the encoded command, or std::nullopt if the operation is a no-op
|
||||||
static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||||
|
ggml_cgraph * cgraph,
|
||||||
|
int node_idx,
|
||||||
|
int & num_encoded_ops) {
|
||||||
|
ggml_tensor ** nodes = cgraph->nodes;
|
||||||
|
ggml_tensor * node = nodes[node_idx];
|
||||||
|
|
||||||
if (ggml_is_empty(node)) {
|
if (ggml_is_empty(node)) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
|
WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")");
|
||||||
|
|
||||||
ggml_tensor * src0 = node->src[0];
|
ggml_tensor * src0 = node->src[0];
|
||||||
ggml_tensor * src1 = node->src[1];
|
ggml_tensor * src1 = node->src[1];
|
||||||
@@ -2519,6 +2640,13 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
|
|||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
return ggml_webgpu_repeat(ctx, src0, node);
|
return ggml_webgpu_repeat(ctx, src0, node);
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) {
|
||||||
|
num_encoded_ops = 2;
|
||||||
|
ggml_tensor * mul_node = nodes[node_idx + 1];
|
||||||
|
return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node);
|
||||||
|
} else {
|
||||||
|
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||||
|
}
|
||||||
case GGML_OP_L2_NORM:
|
case GGML_OP_L2_NORM:
|
||||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
@@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
|||||||
uint32_t num_inflight_batches = 0;
|
uint32_t num_inflight_batches = 0;
|
||||||
bool contains_set_rows = false;
|
bool contains_set_rows = false;
|
||||||
bool batch_compute_passes = true;
|
bool batch_compute_passes = true;
|
||||||
|
int num_encoded_ops = 1;
|
||||||
|
int node_idx = 0;
|
||||||
|
|
||||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||||
ctx->profile_timestamp_query_count = 0;
|
ctx->profile_timestamp_query_count = 0;
|
||||||
@@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
|||||||
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
|
ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
while (node_idx < cgraph->n_nodes) {
|
||||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) {
|
||||||
contains_set_rows = true;
|
contains_set_rows = true;
|
||||||
}
|
}
|
||||||
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
|
if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) {
|
||||||
commands.push_back(*cmd);
|
commands.push_back(*cmd);
|
||||||
num_batched_kernels += cmd.value().num_kernels;
|
num_batched_kernels += cmd.value().num_kernels;
|
||||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||||
@@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
|||||||
ctx->param_arena.reset();
|
ctx->param_arena.reset();
|
||||||
commands.clear();
|
commands.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node_idx += num_encoded_ops;
|
||||||
|
num_encoded_ops = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx->active_compute_pass) {
|
if (ctx->active_compute_pass) {
|
||||||
@@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||||||
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
||||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||||
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
||||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||||
webgpu_ctx->param_arena.init(
|
webgpu_ctx->param_arena.init(
|
||||||
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||||
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
|
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
|
||||||
@@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
const bool has_mask = op->src[3] != nullptr;
|
const bool has_mask = op->src[3] != nullptr;
|
||||||
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
||||||
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
||||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
||||||
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
||||||
if (min_bytes > limit_bytes) {
|
if (min_bytes > limit_bytes) {
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
#ifdef INPLACE

// INPLACE: the product is written back into mul_src, so there is no dst binding.
@group(0) @binding(0) var<storage, read_write> rn_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> mul_src: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * rn_src[rn_src_offset];
    mul_src[dst_offset] = normed * mul_src[mul_src_offset];
}

#elif SRC_OVERLAP

// SRC_OVERLAP: both inputs are read through one shared binding; the caller
// supplies offsets that are relative to the start of that binding.
@group(0) @binding(0) var<storage, read_write> merged_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * merged_src[rn_src_offset];
    dst[dst_offset] = normed * merged_src[mul_src_offset];
}

#else

// Default: three separate bindings for the two inputs and the output.
@group(0) @binding(0) var<storage, read_write> rn_src: array<f32>;
@group(0) @binding(1) var<storage, read_write> mul_src: array<f32>;
@group(0) @binding(2) var<storage, read_write> dst: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
    let normed = scale * rn_src[rn_src_offset];
    dst[dst_offset] = normed * mul_src[mul_src_offset];
}

#endif
|
||||||
|
|
||||||
|
// Uniform parameters for the fused RMS_NORM + MUL kernel.
// Field order must match the u32 vector assembled by the host code.
struct Params {
    // Start offsets of each tensor inside its binding.
    offset_rn_src: u32,
    offset_mul_src: u32,
    // Extra offsets added on top of the two above; main() always adds them.
    offset_merged_rn_src: u32,
    offset_merged_mul_src: u32,
    offset_dst: u32,

    // Strides of the RMS_NORM input for dimensions 1..3.
    stride_rn_src1: u32,
    stride_rn_src2: u32,
    stride_rn_src3: u32,

    // Strides of the multiplier for dimensions 1..3.
    stride_mul_src1: u32,
    stride_mul_src2: u32,
    stride_mul_src3: u32,

    // Strides of the output for dimensions 1..3.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Extents of the multiplier; indices are taken modulo these in main().
    mul_src_ne0: u32,
    mul_src_ne1: u32,
    mul_src_ne2: u32,
    mul_src_ne3: u32,

    // Extents used to decompose the workgroup id and to bound the column loops.
    ne0: u32,
    ne1: u32,
    ne2: u32,
    ne3: u32,

    // Added to the mean of squares before the reciprocal square root.
    eps: f32
};
|
||||||
|
|
||||||
|
// Per-invocation partial sums of squares, reduced to scratch[0] in main().
var<workgroup> scratch: array<f32, WG_SIZE>;

// Fused RMS_NORM + MUL. Each workgroup handles one row:
//   1. every invocation accumulates x*x over its strided share of the row,
//   2. the partial sums are tree-reduced through `scratch`,
//   3. every invocation writes scale * x * mul for its share of the row.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    // one workgroup per row: split the flat row index into (i3, i2, i1)
    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
    // the multiplier may be broadcast, so wrap each index by its own extent
    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    // number of columns each invocation visits (ceil(ne0 / WG_SIZE))
    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;

    var sum = 0.0f;
    var col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        // Square by multiplication rather than pow(x, 2.0): pow is not
        // guaranteed to be well defined for negative bases.
#ifdef SRC_OVERLAP
        let v = merged_src[i_rn_src_row + col];
#else
        let v = rn_src[i_rn_src_row + col];
#endif
        sum += v * v;
        col += WG_SIZE;
    }

    scratch[lid.x] = sum;

    workgroupBarrier();

    // Tree reduction into scratch[0].
    // NOTE(review): halving from WG_SIZE / 2 only visits every slot when
    // WG_SIZE is a power of two; confirm for the workgroup sizes used.
    var offset: u32 = WG_SIZE / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    sum = scratch[0];

    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);

    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0);
        col += WG_SIZE;
    }
}
|
||||||
Reference in New Issue
Block a user