ggml-webgpu: Add supports for GGML_OP_REPEAT (#20230)
* Add GGML_OP_REPEAT to webgpu backend. * Add i16 support for GGML_OP_REPEAT.
This commit is contained in:
committed by
GitHub
parent
d28961d81e
commit
f2ab047f27
+1
-1
@@ -80,7 +80,7 @@ Legend:
|
|||||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||||
|
|||||||
+14
-14
@@ -5023,20 +5023,20 @@
|
|||||||
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
|
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
|
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
|
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
|
||||||
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
|
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
|
||||||
|
|||||||
|
Can't render this file because it is too large.
|
@@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** Repeat **/
|
||||||
|
|
||||||
|
struct ggml_webgpu_repeat_pipeline_key {
|
||||||
|
int type;
|
||||||
|
|
||||||
|
bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_webgpu_repeat_pipeline_key_hash {
|
||||||
|
size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
|
||||||
|
size_t seed = 0;
|
||||||
|
ggml_webgpu_hash_combine(seed, key.type);
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** Binary **/
|
/** Binary **/
|
||||||
|
|
||||||
struct ggml_webgpu_binary_pipeline_key {
|
struct ggml_webgpu_binary_pipeline_key {
|
||||||
@@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib {
|
|||||||
binary_pipelines; // type/op/inplace/overlap
|
binary_pipelines; // type/op/inplace/overlap
|
||||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||||
concat_pipelines; // type
|
concat_pipelines; // type
|
||||||
|
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||||
|
repeat_pipelines; // type
|
||||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||||
flash_attn_pipelines;
|
flash_attn_pipelines;
|
||||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||||
@@ -1147,7 +1165,7 @@ class ggml_webgpu_shader_lib {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> defines;
|
std::vector<std::string> defines;
|
||||||
std::string variant = "concat";
|
std::string variant = "concat";
|
||||||
|
|
||||||
switch (key.type) {
|
switch (key.type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
@@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib {
|
|||||||
|
|
||||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||||
|
|
||||||
auto processed = preprocessor.preprocess(wgsl_concat, defines);
|
auto processed = preprocessor.preprocess(wgsl_concat, defines);
|
||||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||||
decisions->wg_size = context.max_wg_size;
|
decisions->wg_size = context.max_wg_size;
|
||||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||||
pipeline.context = decisions;
|
pipeline.context = decisions;
|
||||||
concat_pipelines[key] = pipeline;
|
concat_pipelines[key] = pipeline;
|
||||||
return concat_pipelines[key];
|
return concat_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||||
|
ggml_webgpu_repeat_pipeline_key key = {
|
||||||
|
.type = context.dst->type,
|
||||||
|
};
|
||||||
|
|
||||||
|
auto it = repeat_pipelines.find(key);
|
||||||
|
if (it != repeat_pipelines.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> defines;
|
||||||
|
std::string variant = "repeat";
|
||||||
|
|
||||||
|
switch (key.type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
defines.push_back("TYPE_F32");
|
||||||
|
variant += "_f32";
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_I32:
|
||||||
|
defines.push_back("TYPE_I32");
|
||||||
|
variant += "_i32";
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_I16:
|
||||||
|
defines.push_back("TYPE_I16");
|
||||||
|
variant += "_i16";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("Unsupported type for repeat shader");
|
||||||
|
}
|
||||||
|
|
||||||
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||||
|
|
||||||
|
auto processed = preprocessor.preprocess(wgsl_repeat, defines);
|
||||||
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||||
|
decisions->wg_size = context.max_wg_size;
|
||||||
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||||
|
pipeline.context = decisions;
|
||||||
|
repeat_pipelines[key] = pipeline;
|
||||||
|
return repeat_pipelines[key];
|
||||||
|
}
|
||||||
|
|
||||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||||
const bool has_mask = context.src3 != nullptr;
|
const bool has_mask = context.src3 != nullptr;
|
||||||
const bool has_sinks = context.src4 != nullptr;
|
const bool has_sinks = context.src4 != nullptr;
|
||||||
|
|||||||
@@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
|||||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = { ne,
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
|
||||||
|
ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||||
|
(uint32_t) (src0->ne[0]),
|
||||||
|
(uint32_t) (src0->ne[1]),
|
||||||
|
(uint32_t) (src0->ne[2]),
|
||||||
|
(uint32_t) (src0->ne[3]),
|
||||||
|
(uint32_t) (dst->ne[0]),
|
||||||
|
(uint32_t) (dst->ne[1]),
|
||||||
|
(uint32_t) (dst->ne[2]) };
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
{ .binding = 0,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||||
|
{ .binding = 1,
|
||||||
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||||
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||||
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||||
|
.src0 = src0,
|
||||||
|
.dst = dst,
|
||||||
|
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||||
|
};
|
||||||
|
|
||||||
|
webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
|
||||||
|
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||||
|
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||||
|
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||||
|
}
|
||||||
|
|
||||||
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||||
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||||
|
|
||||||
@@ -2158,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||||||
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
return ggml_webgpu_concat(ctx, src0, src1, node);
|
return ggml_webgpu_concat(ctx, src0, src1, node);
|
||||||
|
case GGML_OP_REPEAT:
|
||||||
|
return ggml_webgpu_repeat(ctx, src0, node);
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
@@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
|||||||
/* .iface = */ {
|
/* .iface = */ {
|
||||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||||
/* .alloc_buffer = */
|
/* .alloc_buffer = */
|
||||||
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
||||||
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
||||||
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
||||||
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
||||||
},
|
},
|
||||||
/* .device = */
|
/* .device = */
|
||||||
dev,
|
dev,
|
||||||
@@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_REPEAT:
|
||||||
|
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
|
||||||
|
break;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
enable f16;
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
ne: u32,
|
||||||
|
|
||||||
|
offset_src0: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
|
||||||
|
stride_src0_0: u32,
|
||||||
|
stride_src0_1: u32,
|
||||||
|
stride_src0_2: u32,
|
||||||
|
stride_src0_3: u32,
|
||||||
|
|
||||||
|
a_ne0: u32,
|
||||||
|
a_ne1: u32,
|
||||||
|
a_ne2: u32,
|
||||||
|
a_ne3: u32,
|
||||||
|
|
||||||
|
ne0: u32,
|
||||||
|
ne1: u32,
|
||||||
|
ne2: u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef TYPE_F32
|
||||||
|
#define DataType f32
|
||||||
|
#endif
|
||||||
|
#ifdef TYPE_I32
|
||||||
|
#define DataType i32
|
||||||
|
#endif
|
||||||
|
#ifdef TYPE_I16
|
||||||
|
// same size (16-bit) is sufficient for repeat
|
||||||
|
#define DataType f16
|
||||||
|
#endif
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read_write> src0: array<DataType>;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read_write> dst: array<DataType>;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
@compute @workgroup_size(WG_SIZE)
|
||||||
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
|
if (gid.x < params.ne) {
|
||||||
|
var i = gid.x;
|
||||||
|
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||||
|
let i2 = i / (params.ne1 * params.ne0);
|
||||||
|
i = i % (params.ne1 * params.ne0);
|
||||||
|
let i1 = i / params.ne0;
|
||||||
|
let i0 = i % params.ne0;
|
||||||
|
|
||||||
|
let a_i0 = i0 % params.a_ne0;
|
||||||
|
let a_i1 = i1 % params.a_ne1;
|
||||||
|
let a_i2 = i2 % params.a_ne2;
|
||||||
|
let a_i3 = i3 % params.a_ne3;
|
||||||
|
|
||||||
|
let a_index = a_i0 * params.stride_src0_0 +
|
||||||
|
a_i1 * params.stride_src0_1 +
|
||||||
|
a_i2 * params.stride_src0_2 +
|
||||||
|
a_i3 * params.stride_src0_3;
|
||||||
|
|
||||||
|
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user