ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (#22456)
* Refactor buffer aliasing to be part of shader lib decisions
* cleanup
* formatting
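Background for the change: WebGPU validation rejects a bind group in which two writable storage bindings alias the same region of one buffer, so when tensors feeding a single dispatch overlap (for example the x/B/C views that fused Mamba-2-style graphs hand to ssm_scan), the backend must bind one merged range and address each tensor relative to it. This patch moves those aliasing decisions out of the per-op encoders and into the shader library, keyed into the cached pipelines. A self-contained sketch of the range merging, with invented offsets (not code from the patch):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // two f32 tensors sharing one buffer: a at [4096, 8192), b at [6144, 10240)
    const size_t a_off = 4096, a_size = 4096;
    const size_t b_off = 6144, b_size = 4096;
    // one merged binding covering both, as ggml_webgpu_tensor_merged_binding_range does
    const size_t off  = std::min(a_off, b_off);
    const size_t size = std::max(a_off + a_size, b_off + b_size) - off;
    // per-tensor element offsets relative to the merged base (f32 = 4 bytes),
    // the job of ggml_webgpu_tensor_merged_element_offset
    printf("bind [%zu, %zu); a at element +%zu, b at element +%zu\n",
           off, off + size, (a_off - off) / 4, (b_off - off) / 4);
    return 0;
}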
@@ -59,19 +59,32 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

+// Calculates base address of a tensor ignoring the fake base pointer
+inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (uintptr_t) base_tensor->data + tensor->view_offs;
+}
+
+inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
+}
+
+inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
+    return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
+           ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
+}

struct ggml_webgpu_shader_lib_context {
    ggml_tensor * src0;
    ggml_tensor * src1;
    ggml_tensor * src2;
    ggml_tensor * src3;
    ggml_tensor * src4;
    ggml_tensor * src5;
    ggml_tensor * dst;

    uint32_t max_wg_size;
    size_t   wg_mem_limit_bytes = 0;
-    bool inplace = false;
-    bool overlap = false;
-    bool src_overlap = false;
    bool supports_subgroups = false;
    bool supports_subgroup_matrix = false;
    uint32_t sg_mat_m = 0;
@@ -88,6 +101,14 @@ struct webgpu_pipeline {

struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size = 0;
+    bool inplace = false;
};

+struct ggml_webgpu_binary_shader_decisions {
+    uint32_t wg_size = 0;
+    bool inplace = false;
+    bool overlap = false;
+    bool src_overlap = false;
+};
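These decision structs are the hand-off between the shader library and the op encoders: the library attaches them to the cached pipeline, and dispatch code reads them back rather than recomputing aliasing itself. A minimal sketch of that pattern with stand-in types (the real code uses webgpu_pipeline and the structs above; the shared_ptr<void> member is inferred from the casts later in the diff):

#include <memory>

struct decisions_sketch {          // stand-in for ggml_webgpu_binary_shader_decisions
    bool inplace = false;
};

struct pipeline_sketch {           // assumed shape of webgpu_pipeline
    std::shared_ptr<void> context; // type-erased decisions attached by the shader lib
};

int main() {
    pipeline_sketch pipeline;
    auto decisions     = std::make_shared<decisions_sketch>();
    decisions->inplace = true;
    pipeline.context   = decisions;  // shader lib records its choice
    // op encoder retrieves the choice at dispatch time:
    auto * d = static_cast<decisions_sketch *>(pipeline.context.get());
    return d->inplace ? 0 : 1;
}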

struct ggml_webgpu_processed_shader {
@@ -104,9 +125,10 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
struct ggml_webgpu_ssm_scan_pipeline_key {
    int  type;
    int  d_state;
+    bool xbc_overlap;

    bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
-        return type == other.type && d_state == other.d_state;
+        return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
    }
};

@@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.d_state);
+        ggml_webgpu_hash_combine(seed, key.xbc_overlap);
        return seed;
    }
};
@@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
struct ggml_webgpu_ssm_scan_shader_decisions {
    uint32_t wg_size;
    uint32_t tokens_per_tile;
+    bool     xbc_overlap = false;
};
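Folding xbc_overlap into the key and hash means each aliasing pattern gets its own compiled shader variant, looked up from a cache. A self-contained sketch of that caching pattern with stand-in types (the real map is presumably the ssm_scan_pipelines seen in the find() calls later in the diff):

#include <cstdio>
#include <functional>
#include <unordered_map>

struct key_sketch {
    int  type;
    int  d_state;
    bool xbc_overlap;
    bool operator==(const key_sketch & o) const {
        return type == o.type && d_state == o.d_state && xbc_overlap == o.xbc_overlap;
    }
};

struct key_hash_sketch {
    size_t operator()(const key_sketch & k) const {
        size_t seed = 0;
        // same boost-style combine as ggml_webgpu_hash_combine above
        auto combine = [&](size_t v) {
            seed ^= std::hash<size_t>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        };
        combine((size_t) k.type);
        combine((size_t) k.d_state);
        combine((size_t) k.xbc_overlap);
        return seed;
    }
};

int main() {
    std::unordered_map<key_sketch, int, key_hash_sketch> pipelines;  // int stands in for webgpu_pipeline
    pipelines[{0, 128, true}]  = 1;  // XBC_OVERLAP variant
    pipelines[{0, 128, false}] = 2;  // separate-binding variant
    printf("%zu cached variants\n", pipelines.size());
    return 0;
}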

/** Argsort **/
@@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
    }
};

+struct ggml_webgpu_rms_norm_mul_shader_decisions {
+    uint32_t wg_size = 0;
+    bool inplace = false;
+    bool overlap = false;
+    bool src_overlap = false;
+};
+
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
    bool circular;
@@ -508,6 +539,7 @@ struct ggml_webgpu_flash_attn_decisions {
    uint32_t kv_tile = 0;
    uint32_t wg_size = 0;
    bool kv_direct = false;
+    bool kv_overlap = false;
};

inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
@@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
    key.head_dim_qk = (uint32_t) context.src0->ne[0];
    key.head_dim_v = (uint32_t) context.src2->ne[0];
    key.kv_direct = kv_direct;
-    key.kv_overlap = context.src_overlap;
+    key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
    key.has_mask = has_mask;
    key.has_sinks = has_sinks;
    key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
@@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib {
    webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_row_norm_pipeline_key key = {};
        key.op = context.dst->op;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

        auto it = row_norm_pipelines.find(key);
        if (it != row_norm_pipelines.end()) {
@@ -1052,7 +1084,11 @@ class ggml_webgpu_shader_lib {
        uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
        auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = wg_size;
+        decisions->inplace = key.inplace;
        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        row_norm_pipelines[key].context = decisions;
        return row_norm_pipelines[key];
    }

@@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib {
    webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_set_pipeline_key key = {};
        key.type = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

        auto it = set_pipelines.find(key);
        if (it != set_pipelines.end()) {
@@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib {
        auto processed = preprocessor.preprocess(wgsl_set, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        set_pipelines[key] = pipeline;
@@ -1355,7 +1392,7 @@ class ggml_webgpu_shader_lib {

    webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_scale_pipeline_key key = {};
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

        auto it = scale_pipelines.find(key);
        if (it != scale_pipelines.end()) {
@@ -1375,6 +1412,7 @@ class ggml_webgpu_shader_lib {
        auto processed = preprocessor.preprocess(wgsl_scale, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        scale_pipelines[key] = pipeline;
@@ -1468,6 +1506,8 @@ class ggml_webgpu_shader_lib {
        ggml_webgpu_ssm_scan_pipeline_key key = {};
        key.type = context.dst->type;
        key.d_state = (int) context.src0->ne[0];
+        key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
+                          ggml_webgpu_tensor_overlap(context.src1, context.src5);

        auto it = ssm_scan_pipelines.find(key);
        if (it != ssm_scan_pipelines.end()) {
@@ -1499,12 +1539,17 @@ class ggml_webgpu_shader_lib {
            variant += "_wg_reduce";
        }

+        if (key.xbc_overlap) {
+            defines.push_back("XBC_OVERLAP");
+        }
+
        variant += "_d" + std::to_string(key.d_state);

        auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
        auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
        decisions->wg_size = wg_size;
        decisions->tokens_per_tile = tokens_per_tile;
+        decisions->xbc_overlap = key.xbc_overlap;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        ssm_scan_pipelines[key] = pipeline;
@@ -1764,11 +1809,9 @@ class ggml_webgpu_shader_lib {

        uint32_t tile_k;
        if (key.use_subgroup_matrix) {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
        } else {
-            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
-                              : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
+            tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
        }

        // Tiles
@@ -2001,9 +2044,8 @@ class ggml_webgpu_shader_lib {
        defines.push_back("SCALAR");

        // mul_mat_id is register-tile only.
-        const uint32_t tile_k = ggml_is_quantized(context.src0->type)
-                                    ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
-                                    : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
+        const uint32_t tile_k =
+            ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;

        // Tiles
        defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
@@ -2039,7 +2081,7 @@ class ggml_webgpu_shader_lib {
        key.type = context.dst->type;
        key.op = op;
        key.is_unary = is_unary;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
        key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);

        auto it = unary_pipelines.find(key);
@@ -2098,6 +2140,7 @@ class ggml_webgpu_shader_lib {
        auto processed = preprocessor.preprocess(wgsl_unary, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        unary_pipelines[key] = pipeline;
@@ -2106,9 +2149,9 @@ class ggml_webgpu_shader_lib {

    webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_rms_norm_mul_pipeline_key key = {};
-        key.inplace = context.inplace;
-        key.overlap = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);

        auto it = rms_norm_mul_pipelines.find(key);
        if (it != rms_norm_mul_pipelines.end()) {
@@ -2133,10 +2176,13 @@ class ggml_webgpu_shader_lib {
        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

        auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
+        auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
+        pipeline_decisions->wg_size = context.max_wg_size;
+        pipeline_decisions->inplace = key.inplace;
+        pipeline_decisions->overlap = key.overlap;
+        pipeline_decisions->src_overlap = key.src_overlap;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context = decisions;
+        pipeline.context = pipeline_decisions;
        rms_norm_mul_pipelines[key] = pipeline;
        return rms_norm_mul_pipelines[key];
    }
@@ -2145,9 +2191,9 @@ class ggml_webgpu_shader_lib {
        ggml_webgpu_binary_pipeline_key key = {};
        key.type = context.dst->type;
        key.op = context.dst->op;
-        key.inplace = context.inplace;
-        key.overlap = context.overlap;
-        key.src_overlap = context.src_overlap;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
+        key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
+        key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);

        auto it = binary_pipelines.find(key);
        if (it != binary_pipelines.end()) {
@@ -2187,10 +2233,14 @@ class ggml_webgpu_shader_lib {
        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

        auto processed = preprocessor.preprocess(wgsl_binary, defines);
-        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-        decisions->wg_size = context.max_wg_size;
+        auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
+        pipeline_decisions->wg_size = context.max_wg_size;
+        pipeline_decisions->inplace = key.inplace;
+        pipeline_decisions->overlap = key.overlap;
+        pipeline_decisions->src_overlap = key.src_overlap;

        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context = decisions;
+        pipeline.context = pipeline_decisions;
        binary_pipelines[key] = pipeline;
        return binary_pipelines[key];
    }
@@ -2352,6 +2402,7 @@ class ggml_webgpu_shader_lib {
        }

        auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
+        pipeline_decisions->kv_overlap = key.kv_overlap;
        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
@@ -2543,7 +2594,7 @@ class ggml_webgpu_shader_lib {
    webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
        ggml_webgpu_rope_pipeline_key key = {};
        key.type = context.dst->type;
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
        key.has_ff = (context.src2 != nullptr);

        auto it = rope_pipelines.find(key);
@@ -2582,6 +2633,7 @@ class ggml_webgpu_shader_lib {
        auto processed = preprocessor.preprocess(wgsl_rope, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        rope_pipelines[key] = pipeline;
@@ -2593,7 +2645,7 @@ class ggml_webgpu_shader_lib {
        key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
        key.has_mask = (context.src1 != nullptr);
        key.has_sink = (context.src2 != nullptr);
-        key.inplace = context.inplace;
+        key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

        auto it = soft_max_pipelines.find(key);
        if (it != soft_max_pipelines.end()) {
@@ -2634,6 +2686,7 @@ class ggml_webgpu_shader_lib {
        auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
        decisions->wg_size = context.max_wg_size;
+        decisions->inplace = key.inplace;
        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
        pipeline.context = decisions;
        soft_max_pipelines[key] = pipeline;

@@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT

-// Always returns the base offset of a tensor, regardless of views.
-static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
-    if (tensor->view_src) {
-        return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
-    }
-    return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
+static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
+    const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
+    return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs;
}

/* Struct definitions */
@@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
    buffer = device.CreateBuffer(&buffer_desc);
}

-static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
-    return webgpu_tensor_offset(tensor) + tensor->view_offs;
-}
-
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
    return ctx->buffer;
@@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
}

-// Used to determine if two tensors are the same for in-place operations
-static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
-}
-
-// Used to determine if two tensors share the same buffer and their byte ranges overlap,
-static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
-    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
-           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
-           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
-}
-
-struct binary_overlap_flags {
-    bool inplace;  // src0 == dst
-    bool overlap;  // src1 == dst
-    bool src_overlap;
+struct ggml_webgpu_merged_binding_range {
+    size_t offset;
+    size_t size;
};

-static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
-                                                              ggml_tensor * src1,
-                                                              ggml_tensor * dst) {
-    binary_overlap_flags flags = {};
-    flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
-    flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
+static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range(
+    webgpu_context &                     ctx,
+    std::initializer_list<ggml_tensor *> tensors) {
+    size_t merged_offset = SIZE_MAX;
+    size_t merged_end = 0;

-    return flags;
+    for (ggml_tensor * tensor : tensors) {
+        const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor);
+        const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor);
+
+        merged_offset = std::min(merged_offset, bind_offset);
+        merged_end = std::max(merged_end, bind_end);
+    }
+
+    return { merged_offset, merged_end - merged_offset };
}

+static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor *                      tensor,
+                                                         const ggml_webgpu_merged_binding_range & merged_range) {
+    return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type));
+}
+
static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding,
@@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
                                         ggml_tensor *    src0,
                                         ggml_tensor *    src1,
                                         ggml_tensor *    dst) {
-    const bool inplace = ggml_webgpu_tensor_equal(src0, dst);

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src0;
    shader_lib_ctx.src1 = src1;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;

    webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);

+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace = decisions->inplace;

    const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
    const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
@@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
                                              ggml_tensor *    dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src0;
    shader_lib_ctx.src1 = src1;
+    shader_lib_ctx.src4 = src4;
+    shader_lib_ctx.src5 = src5;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;

    webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
    auto * decisions = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get());
+    const bool xbc_overlap = decisions->xbc_overlap;
+
+    uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type));
+    uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type));
+    size_t xbc_bind_offset = 0;
+    size_t xbc_bind_size = 0;
+    if (xbc_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 });
+        xbc_bind_offset = merged_range.offset;
+        xbc_bind_size = merged_range.size;
+        offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
+        offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range);
+        offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range);
+    }

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_x,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
+        offset_B,
+        offset_C,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),

@@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
    };

    std::vector<wgpu::BindGroupEntry> entries = {
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
-        ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
+        ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
    };
+    if (xbc_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst));
+    } else {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst));
+    }

    const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@@ -1653,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
    float m0 = powf(2.0f, -(max_bias) / n_head_log2);
    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0 = Q;
+    shader_lib_ctx.src1 = K;
+    shader_lib_ctx.src2 = V;
+    shader_lib_ctx.src3 = mask;
+    shader_lib_ctx.src4 = sinks;
+    shader_lib_ctx.dst = dst;
+    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
+    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+    shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
+    shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
+    shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
+    shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
+        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
    const int has_mask = (mask != nullptr);
    const int has_sinks = (sinks != nullptr);
-    const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
+    const bool kv_overlap = decisions->kv_overlap;

    uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
    uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
    size_t kv_bind_offset = 0;
    size_t kv_bind_size = 0;
    if (kv_overlap) {
-        const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
-        const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
-        const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
-        const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
-        kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
-        kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
-        offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
-        offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
+        const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
+        kv_bind_offset = merged_range.offset;
+        kv_bind_size = merged_range.size;
+        offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
+        offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
    }

    std::vector<uint32_t> params = {
@@ -1720,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
    }
    entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));

-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.src0 = Q;
-    shader_lib_ctx.src1 = K;
-    shader_lib_ctx.src2 = V;
-    shader_lib_ctx.src3 = mask;
-    shader_lib_ctx.src4 = sinks;
-    shader_lib_ctx.dst = dst;
-    shader_lib_ctx.src_overlap = kv_overlap;
-    shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
-    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
-    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-    shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
-    shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
-    shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
-    shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
-    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
-        shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
-    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());

    if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
        uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
        uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
@@ -1921,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,

static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    bool is_unary = dst->op == GGML_OP_UNARY;
-    bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src;
    shader_lib_ctx.src1 = nullptr;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;

    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);

+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    const bool inplace = decisions->inplace;

    uint32_t ne = (uint32_t) ggml_nelements(dst);

@@ -1994,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
                                               ggml_tensor *    src0,
                                               ggml_tensor *    src1,
                                               ggml_tensor *    dst) {
-    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src0;
    shader_lib_ctx.src1 = src1;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = flags.inplace;
-    shader_lib_ctx.overlap = flags.overlap;
-    shader_lib_ctx.src_overlap = flags.src_overlap;

    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);

-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());

    uint32_t ne = (uint32_t) ggml_nelements(dst);

-    size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
-    size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
-
-    uint32_t offset_merged_src0 = 0;
-    uint32_t offset_merged_src1 = 0;
-    if (flags.src_overlap) {
-        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
-        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+    uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
+    uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
+    size_t merged_offset = 0;
+    size_t merged_size = 0;
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
+        merged_offset = merged_range.offset;
+        merged_size = merged_range.size;
+        offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
+        offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
    }

    std::vector<uint32_t> params = {
        ne,
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        offset_src0,
+        offset_src1,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        offset_merged_src0,
-        offset_merged_src1,
        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
@@ -2048,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,

    std::vector<wgpu::BindGroupEntry> entries;

-    if (flags.src_overlap) {
-        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
-        size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
-                                     src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset,
-                                                            merged_end - merged_offset));
+    if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
    } else {
        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0),
@@ -2062,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
        entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1),
                                                            src1_webgpu_tensor_align_offset,
                                                            ggml_webgpu_tensor_binding_size(ctx, src1)));
-        if (!flags.inplace && !flags.overlap) {
+        if (!decisions->inplace && !decisions->overlap) {
            entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
        }
    }
@@ -2168,29 +2177,15 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
        GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
    }

-    bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
-                   (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
-    bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
-    bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
-
-    uint32_t offset_merged_rn_src = 0;
-    uint32_t offset_merged_mul_src = 0;
-    size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
-    size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
-
-    if (src_overlap) {
-        size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        offset_merged_rn_src =
-            (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
-        offset_merged_mul_src =
-            (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
-    }
+    uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type));
+    uint32_t offset_mul_src =
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type));
+    size_t merged_offset = 0;
+    size_t merged_size = 0;

    std::vector<uint32_t> params = {
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
-        offset_merged_rn_src,
-        offset_merged_mul_src,
+        offset_rn_src,
+        offset_mul_src,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
        (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
@@ -2214,16 +2209,32 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context

    std::vector<wgpu::BindGroupEntry> entries;

-    if (inplace || overlap) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
+    shader_lib_ctx.src0 = rn_src;
+    shader_lib_ctx.src1 = mul_src;
+    shader_lib_ctx.dst = dst;
+    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get());
+
+    if (decisions->src_overlap) {
+        const ggml_webgpu_merged_binding_range merged_range =
+            ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src });
+        merged_offset = merged_range.offset;
+        merged_size = merged_range.size;
+        offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range);
+        offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range);
+        params[0] = offset_rn_src;
+        params[1] = offset_mul_src;
+    }
+
+    if (decisions->inplace || decisions->overlap) {
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
-    } else if (src_overlap) {
-        size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
-        size_t merged_end =
-            std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
-                     mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
-        entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
-                                                            merged_end - merged_offset));
+    } else if (decisions->src_overlap) {
+        entries.push_back(
+            ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size));
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
    } else {
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
@@ -2231,20 +2242,10 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
    }

-    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;
-    shader_lib_ctx.overlap = overlap;
-    shader_lib_ctx.src_overlap = src_overlap;
-
-    webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
-
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
}

static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -2261,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor
        ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0))  // epsilon, treated as f32 in the shader
    };

-    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
-    if (!inplace) {
-        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
-    }
-
    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;

    webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
+    if (!decisions->inplace) {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
+    }
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src));
}

@@ -2287,13 +2288,12 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx,
    shader_lib_ctx.src2 = src2;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);

    webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);

+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

-    const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    const bool inplace = decisions->inplace;
    const int has_freq_factor = (src2 != nullptr);

    const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -2421,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
}

static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool inplace = ggml_webgpu_tensor_equal(src, dst);

    ggml_webgpu_shader_lib_context shader_lib_ctx = {};
    shader_lib_ctx.src0 = src;
    shader_lib_ctx.src1 = nullptr;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = inplace;

    webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2454,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
    // bindgroups unchanged
    std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };

-    if (!inplace) {
+    if (!decisions->inplace) {
        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
    }

@@ -2473,11 +2470,11 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
    shader_lib_ctx.src2 = src2;
    shader_lib_ctx.dst = dst;
    shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
-    shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);

    webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

-    const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    const bool inplace = decisions->inplace;
    const int has_mask = (src1 != nullptr);
    const int has_sink = (src2 != nullptr);
    float max_bias = ggml_get_op_params_f32(dst, 1);
@@ -3079,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
                                                 size_t size) {
    GGML_UNUSED(backend);
    auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

    // Write aligned portion
    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -3161,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
                                                                 << ", " << offset << ", " << size << ")");

-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

    // This is a trick to set all bytes of a u32 to the same 1 byte value.
    uint32_t val32 = (uint32_t) value * 0x01010101;
@@ -3180,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                              << ", " << offset << ", " << size << ")");

-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

@@ -3212,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                  << ", " << offset << ", " << size << ")");
    wgpu::Device device = buf_ctx->global_ctx->device;

-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;

    size_t final_size = size;
    if (size % 4 != 0) {

@@ -7,8 +7,6 @@ struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
-    offset_merged_src0: u32,
-    offset_merged_src1: u32,

    stride_src0_0: u32,
    stride_src0_1: u32,
@@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
-        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
-        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+        let src0_i = params.offset_src0 + src0_index(gid.x);
+        let src1_i = params.offset_src1 + src1_index(gid.x);
        update(params.offset_dst + gid.x, src0_i, src1_i);
    }
}

@@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
struct Params {
    offset_rn_src: u32,
    offset_mul_src: u32,
-    offset_merged_rn_src: u32,
-    offset_merged_mul_src: u32,
    offset_dst: u32,

    stride_rn_src1: u32,
@@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
-    let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
-    let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
+    let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
+    let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;

@@ -45,6 +45,14 @@ struct Params {
};

@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
+#ifdef XBC_OVERLAP
+@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>;
+@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
+@group(0) @binding(3) var<storage, read_write> A: array<f32>;
+@group(0) @binding(4) var<storage, read_write> ids: array<i32>;
+@group(0) @binding(5) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(6) var<uniform> params: Params;
+#else
@group(0) @binding(1) var<storage, read_write> x: array<f32>;
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
@@ -53,6 +61,7 @@ struct Params {
@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
@group(0) @binding(8) var<uniform> params: Params;
+#endif

var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
@@ -98,7 +107,11 @@ fn main(
        let dt0 = dt[dt_idx];
        let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
        shared_dtsp[tid] = dtsp;
+#ifdef XBC_OVERLAP
+        shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp;
+#else
        shared_x_dt[tid] = x[x_idx] * dtsp;
+#endif
    }
}

@@ -116,16 +129,28 @@ fn main(

    let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
    let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
+#ifdef XBC_OVERLAP
+    let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt;
+#else
    let s = s_prev * dA + B[b_idx] * x_dt;
+#endif
    s_prev = s;

#ifdef USE_SUBGROUP_REDUCTION
+#ifdef XBC_OVERLAP
+    let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]);
+#else
    let subgroup_partial = subgroupAdd(s * C[c_idx]);
+#endif
    if (subgroup_invocation_id == 0u) {
        shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
    }
#else
+#ifdef XBC_OVERLAP
+    shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx];
+#else
    shared_reduce[reduce_idx] = s * C[c_idx];
+#endif
#endif

    workgroupBarrier();

@@ -2984,7 +2984,7 @@ struct test_bin_bcast : public test_case {
    bool run_whole_graph() override { return nf > 1; }

    std::string vars() override {
-        return VARS_TO_STR5(type, ne, nr, nf, perm1);
+        return VARS_TO_STR6(type, ne, nr, nf, perm1, src_overlap);
    }

    size_t op_size(ggml_tensor * t) override {
@@ -3589,9 +3589,10 @@ struct test_ssm_scan : public test_case {
    const int64_t n_group;
    const int64_t n_seq_tokens;
    const int64_t n_seqs;
+    const bool xbc_overlap;

    std::string vars() override {
-        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
+        return VARS_TO_STR8(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs, xbc_overlap);
    }

    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
@@ -3600,16 +3601,31 @@ struct test_ssm_scan : public test_case {
                  int64_t n_head = 32,
                  int64_t n_group = 1,
                  int64_t n_seq_tokens = 32,
-                  int64_t n_seqs = 32)
-        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+                  int64_t n_seqs = 32,
+                  bool xbc_overlap = false)
+        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), xbc_overlap(xbc_overlap) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
-        ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
        ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
        ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
-        ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-        ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * x;
+        ggml_tensor * B;
+        ggml_tensor * C;
+
+        if (xbc_overlap) {
+            ggml_tensor * xbc = ggml_new_tensor_4d(ctx, type, d_state, n_head, n_seq_tokens, 2 * n_seqs);
+            x = ggml_view_4d(ctx, xbc, head_dim, n_head, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], xbc->nb[3]);
+            B = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 0);
+            C = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
+                             xbc->nb[1], xbc->nb[2], xbc->nb[3], 2 * xbc->nb[3]);
+        } else {
+            x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+            B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+            C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        }
        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
        return out;
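The three views in the overlapping branch deliberately alias one backing tensor: with slab = xbc->nb[3], B starts at byte 0, x one slab in, and C two slabs in, so x's byte range intersects both B's and C's, which is exactly the condition the shader library's xbc_overlap key checks. A self-contained sketch of that layout with an invented slab size:

#include <cstdint>
#include <cstdio>

// same half-open interval test as ggml_webgpu_tensor_overlap
static bool ranges_overlap(uint64_t a0, uint64_t a1, uint64_t b0, uint64_t b1) {
    return a0 < b1 && b0 < a1;
}

int main() {
    const uint64_t slab = 4096, n_seqs = 2;  // slab stands in for xbc->nb[3]
    const uint64_t B0 = 0 * slab, B1 = B0 + n_seqs * slab;
    const uint64_t x0 = 1 * slab, x1 = x0 + n_seqs * slab;
    const uint64_t C0 = 2 * slab, C1 = C0 + n_seqs * slab;
    printf("x overlaps B: %d, x overlaps C: %d\n",
           (int) ranges_overlap(x0, x1, B0, B1), (int) ranges_overlap(x0, x1, C0, C1));
    return 0;
}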
@@ -7964,6 +7980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4));    // Mamba-1
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4));    // Mamba-2
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4));     // Falcon-H1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 128, 4, 4, 16, 2, true)); // x/B/C overlap

    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));