ggml-webgpu: add support for im2col (#22259)
* shader(im2col): implement the im2col shader
* shader(im2col): clean the formatting issues
* shader(im2col): clean the editorconfig checker warning
* fix(shader): address the workgroup issues of im2col and conv2d
This commit is contained in:
@@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** Im2Col **/
|
||||||
|
struct ggml_webgpu_im2col_pipeline_key {
|
||||||
|
ggml_type input_type;
|
||||||
|
ggml_type output_type;
|
||||||
|
|
||||||
|
bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
|
||||||
|
return input_type == other.input_type && output_type == other.output_type;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_webgpu_im2col_pipeline_key_hash {
|
||||||
|
size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
|
||||||
|
size_t seed = 0;
|
||||||
|
ggml_webgpu_hash_combine(seed, key.input_type);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.output_type);
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** Gated Delta Net **/
|
/** Gated Delta Net **/
|
||||||
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
||||||
int type;
|
int type;
|
||||||
@@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib {
|
|||||||
soft_max_pipelines;
|
soft_max_pipelines;
|
||||||
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
||||||
conv2d_pipelines;
|
conv2d_pipelines;
|
||||||
|
std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
|
||||||
|
im2col_pipelines;
|
||||||
|
|
||||||
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
|
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
|
||||||
webgpu_pipeline,
|
webgpu_pipeline,
|
||||||
@@ -2504,6 +2525,44 @@ class ggml_webgpu_shader_lib {
|
|||||||
return conv2d_pipelines[key];
|
return conv2d_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||||
|
ggml_webgpu_im2col_pipeline_key key = {};
|
||||||
|
key.input_type = context.src1->type;
|
||||||
|
key.output_type = context.dst->type;
|
||||||
|
|
||||||
|
auto it = im2col_pipelines.find(key);
|
||||||
|
if (it != im2col_pipelines.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> defines;
|
||||||
|
std::string variant = "im2col";
|
||||||
|
|
||||||
|
auto push_type_defines = [&](const char * prefix, ggml_type type) {
|
||||||
|
std::string s_prefix = prefix;
|
||||||
|
if (type == GGML_TYPE_F32) {
|
||||||
|
defines.push_back(s_prefix + "_F32");
|
||||||
|
} else if (type == GGML_TYPE_F16) {
|
||||||
|
defines.push_back(s_prefix + "_F16");
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("Unsupported type for IM2COL shader");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
push_type_defines("INPUT", key.input_type);
|
||||||
|
push_type_defines("OUTPUT", key.output_type);
|
||||||
|
|
||||||
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||||
|
|
||||||
|
auto processed = preprocessor.preprocess(wgsl_im2col, defines);
|
||||||
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||||
|
decisions->wg_size = context.max_wg_size;
|
||||||
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||||
|
pipeline.context = decisions;
|
||||||
|
im2col_pipelines[key] = pipeline;
|
||||||
|
return im2col_pipelines[key];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||||
std::string shader_code,
|
std::string shader_code,
|
||||||
|
|||||||
@@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
|
|||||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||||
};
|
};
|
||||||
|
|
||||||
uint32_t max_wg_size =
|
|
||||||
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
|
|
||||||
uint32_t wg_size =
|
|
||||||
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
|
|
||||||
|
|
||||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||||
shader_lib_ctx.src0 = src0;
|
shader_lib_ctx.src0 = src0;
|
||||||
shader_lib_ctx.src1 = src1;
|
shader_lib_ctx.src1 = src1;
|
||||||
shader_lib_ctx.dst = dst;
|
shader_lib_ctx.dst = dst;
|
||||||
shader_lib_ctx.max_wg_size = wg_size;
|
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||||
|
|
||||||
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
|
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
|
||||||
|
|
||||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||||
|
|
||||||
uint32_t n_out = ggml_nelements(dst);
|
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||||
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
|
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||||
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||||
uint32_t wg_x = std::min(total_wg, max_wg);
|
|
||||||
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
|
||||||
|
ggml_tensor * src0,
|
||||||
|
ggml_tensor * src1,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||||
|
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||||
|
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
|
||||||
|
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
|
||||||
|
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
|
||||||
|
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
|
||||||
|
const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1;
|
||||||
|
|
||||||
|
const uint32_t KW = src0->ne[0];
|
||||||
|
const uint32_t KH = is_2D ? src0->ne[1] : 1;
|
||||||
|
const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1];
|
||||||
|
|
||||||
|
const uint32_t IW = src1->ne[0];
|
||||||
|
const uint32_t IH = is_2D ? src1->ne[1] : 1;
|
||||||
|
const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2];
|
||||||
|
|
||||||
|
const uint32_t OW = dst->ne[1];
|
||||||
|
const uint32_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
|
|
||||||
|
const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type));
|
||||||
|
const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0;
|
||||||
|
const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
|
||||||
|
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type));
|
||||||
|
const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
|
||||||
|
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type));
|
||||||
|
|
||||||
|
const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type));
|
||||||
|
const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type));
|
||||||
|
const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0;
|
||||||
|
const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) :
|
||||||
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type));
|
||||||
|
|
||||||
|
std::vector<uint32_t> params = {
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||||
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
|
|
||||||
|
si0,
|
||||||
|
si1,
|
||||||
|
si2,
|
||||||
|
si3,
|
||||||
|
so0,
|
||||||
|
so1,
|
||||||
|
so2,
|
||||||
|
so3,
|
||||||
|
|
||||||
|
KW,
|
||||||
|
KH,
|
||||||
|
IC,
|
||||||
|
|
||||||
|
IW,
|
||||||
|
IH,
|
||||||
|
N,
|
||||||
|
|
||||||
|
OW,
|
||||||
|
OH,
|
||||||
|
|
||||||
|
(uint32_t) s0,
|
||||||
|
(uint32_t) s1,
|
||||||
|
(uint32_t) p0,
|
||||||
|
(uint32_t) p1,
|
||||||
|
(uint32_t) d0,
|
||||||
|
(uint32_t) d1,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
|
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1),
|
||||||
|
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||||
|
shader_lib_ctx.src0 = src0;
|
||||||
|
shader_lib_ctx.src1 = src1;
|
||||||
|
shader_lib_ctx.dst = dst;
|
||||||
|
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||||
|
|
||||||
|
webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx);
|
||||||
|
|
||||||
|
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||||
|
|
||||||
|
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||||
|
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||||
|
|
||||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||||
@@ -1988,8 +2071,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
|||||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||||
|
|
||||||
uint32_t offset_merged_rn_src = 0;
|
uint32_t offset_merged_rn_src = 0;
|
||||||
@@ -2689,6 +2772,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
|||||||
return ggml_webgpu_sum_rows(ctx, src0, node);
|
return ggml_webgpu_sum_rows(ctx, src0, node);
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
return ggml_webgpu_im2col(ctx, src0, src1, node);
|
||||||
default:
|
default:
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
@@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||||||
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
||||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||||
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
||||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||||
webgpu_ctx->param_arena.init(
|
webgpu_ctx->param_arena.init(
|
||||||
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||||
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
|
webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN,
|
||||||
@@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
const bool has_mask = op->src[3] != nullptr;
|
const bool has_mask = op->src[3] != nullptr;
|
||||||
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
||||||
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
||||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
||||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
||||||
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
||||||
if (min_bytes > limit_bytes) {
|
if (min_bytes > limit_bytes) {
|
||||||
@@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
|
||||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
|
break;
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
supports_op = op->type == GGML_TYPE_F32;
|
supports_op = op->type == GGML_TYPE_F32;
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
#include "common_decls.tmpl"
|
||||||
|
enable f16;
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
#if defined(INPUT_F32)
|
||||||
|
var<storage, read_write> input: array<f32>;
|
||||||
|
#elif defined(INPUT_F16)
|
||||||
|
var<storage, read_write> input: array<f16>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
#if defined(OUTPUT_F32)
|
||||||
|
var<storage, read_write> output: array<f32>;
|
||||||
|
#elif defined(OUTPUT_F16)
|
||||||
|
var<storage, read_write> output: array<f16>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_i: u32,
|
||||||
|
offset_o: u32,
|
||||||
|
|
||||||
|
// element strides
|
||||||
|
si0: u32, si1: u32, si2: u32, si3: u32,
|
||||||
|
so0: u32, so1: u32, so2: u32, so3: u32,
|
||||||
|
|
||||||
|
KW: u32, KH: u32, IC: u32,
|
||||||
|
IW: u32, IH: u32, N: u32,
|
||||||
|
OW: u32, OH: u32,
|
||||||
|
|
||||||
|
// stride
|
||||||
|
s0: u32, s1: u32,
|
||||||
|
// padding
|
||||||
|
p0: u32, p1: u32,
|
||||||
|
// dilation
|
||||||
|
d0: u32, d1: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<uniform> params: Params;
|
||||||
|
|
||||||
|
fn load_input(idx: u32) -> f32 {
|
||||||
|
#if defined(INPUT_F32)
|
||||||
|
return input[idx];
|
||||||
|
#elif defined(INPUT_F16)
|
||||||
|
return f32(input[idx]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store_output(idx: u32, val: f32) {
|
||||||
|
#if defined(OUTPUT_F32)
|
||||||
|
output[idx] = val;
|
||||||
|
#elif defined(OUTPUT_F16)
|
||||||
|
output[idx] = f16(val);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
@compute @workgroup_size(WG_SIZE)
|
||||||
|
fn main(
|
||||||
|
@builtin(global_invocation_id) gid: vec3<u32>,
|
||||||
|
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||||
|
) {
|
||||||
|
|
||||||
|
let threads_per_group = u32(WG_SIZE);
|
||||||
|
let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
|
||||||
|
let K = params.KW * params.KH * params.IC;
|
||||||
|
let M = params.OW * params.OH;
|
||||||
|
let total = K * M * params.N;
|
||||||
|
|
||||||
|
if (i_out >= total) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode (k, m, n)
|
||||||
|
var i = i_out;
|
||||||
|
let n = i / (K * M);
|
||||||
|
i = i % (K * M);
|
||||||
|
let m = i / K;
|
||||||
|
let k = i % K;
|
||||||
|
|
||||||
|
// decode (oh, ow)
|
||||||
|
let oh = m / params.OW;
|
||||||
|
let ow = m % params.OW;
|
||||||
|
|
||||||
|
// decode (kw, kh, ic)
|
||||||
|
let kw = k % params.KW;
|
||||||
|
let tmp = k / params.KW;
|
||||||
|
let kh = tmp % params.KH;
|
||||||
|
let ic = tmp / params.KH;
|
||||||
|
|
||||||
|
let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0);
|
||||||
|
let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1);
|
||||||
|
|
||||||
|
if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) {
|
||||||
|
let iw = u32(iw_i32);
|
||||||
|
let ih = u32(ih_i32);
|
||||||
|
let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3;
|
||||||
|
store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx));
|
||||||
|
} else {
|
||||||
|
store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user