ggml-webgpu: Support non-contiguous src0 and overlapping src0/src1 in binary ops (#19850)
* ggml-webgpu: Add binary op support for overlapping and non-contiguous. * Add newline to binary.wgsl * Append the test of binary op for src overlapping to test_bin_bcast. * Remove unnecessary newline.
This commit is contained in:
committed by
GitHub
parent
feefb92836
commit
36a7a6589c
@@ -788,6 +788,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
|
||||
struct binary_overlap_flags {
|
||||
bool inplace; // src0 == dst
|
||||
bool overlap; // src1 == dst
|
||||
bool src_overlap;
|
||||
};
|
||||
|
||||
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
|
||||
@@ -796,6 +797,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
|
||||
binary_overlap_flags flags = {};
|
||||
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
|
||||
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
|
||||
|
||||
return flags;
|
||||
}
|
||||
@@ -1353,6 +1355,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = flags.inplace,
|
||||
.overlap = flags.overlap,
|
||||
.src_overlap = flags.src_overlap,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
|
||||
@@ -1361,11 +1364,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
|
||||
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
|
||||
|
||||
uint32_t offset_merged_src0 = 0;
|
||||
uint32_t offset_merged_src1 = 0;
|
||||
if (flags.src_overlap) {
|
||||
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
offset_merged_src0,
|
||||
offset_merged_src1,
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
@@ -1381,25 +1401,43 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
});
|
||||
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
||||
});
|
||||
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
entries.push_back({ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
if (flags.src_overlap) {
|
||||
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
||||
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = merged_offset,
|
||||
.size = merged_end - merged_offset,
|
||||
});
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
||||
});
|
||||
} else {
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = src0_webgpu_tensor_align_offset,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
});
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = src1_webgpu_tensor_align_offset,
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
||||
});
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
entries.push_back({
|
||||
.binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
@@ -2816,10 +2854,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
// TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
|
||||
// see https://github.com/ggml-org/llama.cpp/pull/16857
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||
(src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
(src1->type == op->type);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
Reference in New Issue
Block a user