ggml-webgpu: fix compiler warnings and refactor FlashAttention encoding (#21052)
* Update workflows to remove dependence on llvmpipe
* Try setting Dawn_DIR
* remove c++20 initializers
* Move to proper guid
* Try avoiding segfaults on vulkan backend process exit
* Remove compiler warnings on parameter casting
* Fix soft_max and update reg_tile accumulation to f32 for better precision
* Refactor flash_attn a bit
* remove c++20 initializers and format
* Increase div precision for NVIDIA
* revert div precision and comment out ggml-ci node for now
* Formatting
* Try debugging on a failing CI node
* Revert "Try debugging on a failing CI node"
  This reverts commit 1971e33cba919915e12bcfd5828abfbd54ca942e.
This commit is contained in:
@@ -97,6 +97,36 @@ jobs:
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# TODO: investigate slight precision issues in some operations for test-backend-ops on the WebGPU backend.
|
||||
#ggml-ci-nvidia-webgpu:
|
||||
# runs-on: [self-hosted, Linux, NVIDIA]
|
||||
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Dawn Dependency
|
||||
# id: dawn-depends
|
||||
# run: |
|
||||
# DAWN_VERSION="v20260317.182325"
|
||||
# DAWN_OWNER="google"
|
||||
# DAWN_REPO="dawn"
|
||||
# DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release"
|
||||
# echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
# curl -L -o artifact.tar.gz \
|
||||
# "https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
# mkdir dawn
|
||||
# tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# GG_BUILD_WEBGPU=1 \
|
||||
# GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
|
||||
# GG_BUILD_WEBGPU_DAWN_DIR="$GITHUB_WORKSPACE/dawn/lib64/cmake/Dawn" \
|
||||
# bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# TODO: provision AMX-compatible machine
|
||||
#ggml-ci-cpu-amx:
|
||||
# runs-on: [self-hosted, Linux, CPU, AMX]
|
||||
|
||||
@@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
bool has_mask;
|
||||
bool has_sinks;
|
||||
bool uses_logit_softcap;
|
||||
bool use_vec;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
||||
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
||||
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
||||
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
|
||||
uses_logit_softcap == other.uses_logit_softcap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
ggml_webgpu_hash_combine(seed, key.use_vec);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_shader_lib_context {
|
||||
ggml_webgpu_flash_attn_pipeline_key key;
|
||||
uint32_t sg_mat_m;
|
||||
uint32_t sg_mat_n;
|
||||
uint32_t sg_mat_k;
|
||||
size_t wg_mem_limit_bytes;
|
||||
uint32_t max_subgroup_size;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_shader_decisions {
|
||||
struct ggml_webgpu_flash_attn_decisions {
|
||||
uint32_t q_tile = 0;
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
// Keep conservative defaults unless this is the f16 vec-split shape family.
|
||||
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
|
||||
return 1u;
|
||||
}
|
||||
struct ggml_webgpu_flash_attn_vec_decisions {
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
// Head-dim specializations used by the tuned vec f16 path.
|
||||
switch (key.head_dim_qk) {
|
||||
case 64:
|
||||
return 2u;
|
||||
case 96:
|
||||
return 4u;
|
||||
case 128:
|
||||
return 1u;
|
||||
case 192:
|
||||
return 2u;
|
||||
case 576:
|
||||
return 2u;
|
||||
default:
|
||||
return 1u;
|
||||
}
|
||||
// Derives the flash-attention pipeline key from the tensors attached to `context`.
//
// - kv_type / head dims are read from src1 (type), src0->ne[0] and src2->ne[0].
// - has_mask / has_sinks reflect whether src3 / src4 are non-null.
// - uses_logit_softcap is set when float op-param #2 of dst is non-zero.
// - kv_direct requires an F16 src1, src0->ne[0] divisible by sg_mat_k, and
//   src1->ne[1] divisible by GGML_WEBGPU_KV_SEQ_PAD.
// Any key field not assigned here keeps its value-initialized (zero) state.
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
    const ggml_webgpu_shader_lib_context & context) {
    // Checks are nested in the same order as the original short-circuit chain so the
    // modulo operations are only evaluated for F16 K/V data.
    bool direct = false;
    if (context.src1->type == GGML_TYPE_F16) {
        if (context.src0->ne[0] % context.sg_mat_k == 0) {
            direct = (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
        }
    }

    ggml_webgpu_flash_attn_pipeline_key result = {};

    result.kv_type            = context.src1->type;
    result.head_dim_qk        = static_cast<uint32_t>(context.src0->ne[0]);
    result.head_dim_v         = static_cast<uint32_t>(context.src2->ne[0]);
    result.kv_direct          = direct;
    result.has_mask           = (context.src3 != nullptr);
    result.has_sinks          = (context.src4 != nullptr);
    result.uses_logit_softcap = (ggml_get_op_params_f32(context.dst, 2) != 0.0f);

    return result;
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
||||
@@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh
|
||||
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
|
||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
// Preprocesses the flash-attention vec-reduce shader source.
//
// Emits the defines HEAD_DIM_V (from context.key.head_dim_v) and WG_SIZE (from
// context.max_wg_size), and builds the variant name
// "flash_attn_vec_reduce_hsv<head_dim_v>_wg<max_wg_size>".
// Returns the preprocessed WGSL together with that variant name.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
    pre_wgsl::Preprocessor &                                     preprocessor,
    const char *                                                 shader_src,
    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
    const std::string head_dim_v_str = std::to_string(context.key.head_dim_v);
    const std::string wg_size_str    = std::to_string(context.max_wg_size);

    std::vector<std::string> defines;
    defines.push_back("HEAD_DIM_V=" + head_dim_v_str);
    defines.push_back("WG_SIZE=" + wg_size_str);

    std::string variant = "flash_attn_vec_reduce";
    variant += "_hsv" + head_dim_v_str;
    variant += "_wg" + wg_size_str;

    ggml_webgpu_processed_shader processed;
    processed.wgsl    = preprocessor.preprocess(shader_src, defines);
    processed.variant = variant;
    return processed;
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_pipeline_key {
|
||||
uint32_t q_tile;
|
||||
uint32_t kv_tile;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
|
||||
return q_tile == other.q_tile && kv_tile == other.kv_tile;
|
||||
}
|
||||
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_tile);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_tile);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_blk_shader_lib_context {
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
// Preprocesses the flash-attention "vec_blk" shader source.
//
// Emits Q_TILE and KV_TILE from context.key, plus WG_SIZE, which is the largest power
// of two that does not exceed context.max_wg_size (1 when max_wg_size is 0 or 1).
// The variant name is "flash_attn_vec_blk_qt<q_tile>_kvt<kv_tile>_wg<wg_size>".
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
    pre_wgsl::Preprocessor &                              preprocessor,
    const char *                                          shader_src,
    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
    // Keep doubling while the doubled size still fits within the limit.
    uint32_t pow2_wg = 1;
    while ((pow2_wg << 1) <= context.max_wg_size) {
        pow2_wg <<= 1;
    }

    const std::string q_tile_str  = std::to_string(context.key.q_tile);
    const std::string kv_tile_str = std::to_string(context.key.kv_tile);
    const std::string wg_size_str = std::to_string(pow2_wg);

    std::vector<std::string> defines;
    defines.push_back("Q_TILE=" + q_tile_str);
    defines.push_back("KV_TILE=" + kv_tile_str);
    defines.push_back("WG_SIZE=" + wg_size_str);

    std::string variant = "flash_attn_vec_blk";
    variant += "_qt" + q_tile_str;
    variant += "_kvt" + kv_tile_str;
    variant += "_wg" + wg_size_str;

    ggml_webgpu_processed_shader processed;
    processed.wgsl    = preprocessor.preprocess(shader_src, defines);
    processed.variant = variant;
    return processed;
}
|
||||
|
||||
// This is exposed because it's necessary in supports_op
|
||||
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
uint32_t kv_tile,
|
||||
@@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
}
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
|
||||
const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||
const size_t q_tile = context.sg_mat_m;
|
||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!key.kv_direct) {
|
||||
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
||||
}
|
||||
if (key.has_mask) {
|
||||
bytes_per_kv += q_tile;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
}
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
|
||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
|
||||
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||
uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
|
||||
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
|
||||
if (key.kv_direct) {
|
||||
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= context.sg_mat_n;
|
||||
}
|
||||
}
|
||||
|
||||
return kv_tile;
|
||||
}
|
||||
|
||||
/** Matrix Multiplication **/
|
||||
|
||||
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
|
||||
@@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib {
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_vec_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
||||
@@ -849,10 +816,9 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
};
|
||||
ggml_webgpu_row_norm_pipeline_key key = {};
|
||||
key.op = context.dst->op;
|
||||
key.inplace = context.inplace;
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
if (it != row_norm_pipelines.end()) {
|
||||
@@ -908,9 +874,10 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
|
||||
.vec4 = context.src0->ne[0] % 4 == 0,
|
||||
.i64_idx = context.src1->type == GGML_TYPE_I64 };
|
||||
ggml_webgpu_set_rows_pipeline_key key = {};
|
||||
key.dst_type = context.dst->type;
|
||||
key.vec4 = context.src0->ne[0] % 4 == 0;
|
||||
key.i64_idx = context.src1->type == GGML_TYPE_I64;
|
||||
|
||||
auto it = set_rows_pipelines.find(key);
|
||||
if (it != set_rows_pipelines.end()) {
|
||||
@@ -955,7 +922,9 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
|
||||
ggml_webgpu_set_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.inplace = context.inplace;
|
||||
|
||||
auto it = set_pipelines.find(key);
|
||||
if (it != set_pipelines.end()) {
|
||||
@@ -1062,10 +1031,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
|
||||
ggml_webgpu_get_rows_pipeline_key key = {
|
||||
.src_type = context.src0->type,
|
||||
.vectorized = (int) vectorized,
|
||||
};
|
||||
ggml_webgpu_get_rows_pipeline_key key = {};
|
||||
key.src_type = context.src0->type;
|
||||
key.vectorized = (int) vectorized;
|
||||
|
||||
auto it = get_rows_pipelines.find(key);
|
||||
if (it != get_rows_pipelines.end()) {
|
||||
@@ -1115,8 +1083,7 @@ class ggml_webgpu_shader_lib {
|
||||
std::string type_upper = type_str;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type)
|
||||
{
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
@@ -1136,9 +1103,9 @@ class ggml_webgpu_shader_lib {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
}
|
||||
{
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
@@ -1181,7 +1148,8 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
|
||||
ggml_webgpu_scale_pipeline_key key = {};
|
||||
key.inplace = context.inplace;
|
||||
|
||||
auto it = scale_pipelines.find(key);
|
||||
if (it != scale_pipelines.end()) {
|
||||
@@ -1208,11 +1176,10 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_solve_tri_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.n = (int) context.src0->ne[0],
|
||||
.k = (int) context.src1->ne[0],
|
||||
};
|
||||
ggml_webgpu_solve_tri_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.n = (int) context.src0->ne[0];
|
||||
key.k = (int) context.src1->ne[0];
|
||||
|
||||
auto it = solve_tri_pipelines.find(key);
|
||||
if (it != solve_tri_pipelines.end()) {
|
||||
@@ -1250,10 +1217,9 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_ssm_conv_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.vectorized = context.src1->ne[0] == 4,
|
||||
};
|
||||
ggml_webgpu_ssm_conv_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.vectorized = context.src1->ne[0] == 4;
|
||||
|
||||
auto it = ssm_conv_pipelines.find(key);
|
||||
if (it != ssm_conv_pipelines.end()) {
|
||||
@@ -1293,11 +1259,10 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_gated_delta_net_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.s_v = (int) context.src2->ne[0],
|
||||
.kda = context.src3->ne[0] == context.src2->ne[0],
|
||||
};
|
||||
ggml_webgpu_gated_delta_net_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.s_v = (int) context.src2->ne[0];
|
||||
key.kda = context.src3->ne[0] == context.src2->ne[0];
|
||||
|
||||
auto it = gated_delta_net_pipelines.find(key);
|
||||
if (it != gated_delta_net_pipelines.end()) {
|
||||
@@ -1330,7 +1295,8 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
|
||||
ggml_webgpu_pad_pipeline_key key = {};
|
||||
key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
|
||||
|
||||
auto it = pad_pipelines.find(key);
|
||||
if (it != pad_pipelines.end()) {
|
||||
@@ -1357,15 +1323,13 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {
|
||||
.src0_type = context.src0->type,
|
||||
.src1_type = context.src1->type,
|
||||
// Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
|
||||
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0,
|
||||
};
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_vec_pipelines.find(key);
|
||||
if (it != mul_mat_vec_pipelines.end()) {
|
||||
@@ -1451,15 +1415,14 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {
|
||||
.src0_type = context.src0->type,
|
||||
.src1_type = context.src1->type,
|
||||
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0,
|
||||
.use_subgroup_matrix = context.supports_subgroup_matrix
|
||||
};
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
|
||||
auto it = mul_mat_fast_pipelines.find(key);
|
||||
if (it != mul_mat_fast_pipelines.end()) {
|
||||
@@ -1578,8 +1541,9 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
|
||||
.src1_type = context.src1->type };
|
||||
ggml_webgpu_legacy_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
|
||||
auto it = mul_mat_legacy_pipelines.find(key);
|
||||
if (it != mul_mat_legacy_pipelines.end()) {
|
||||
@@ -1621,8 +1585,7 @@ class ggml_webgpu_shader_lib {
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (context.src0->type)
|
||||
{
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
@@ -1642,9 +1605,9 @@ class ggml_webgpu_shader_lib {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
@@ -1689,10 +1652,9 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_id_pipeline_key key = {
|
||||
.src0_type = context.src0->type,
|
||||
.src1_type = context.src1->type,
|
||||
};
|
||||
ggml_webgpu_mul_mat_id_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
|
||||
auto it = mul_mat_id_pipelines.find(key);
|
||||
if (it != mul_mat_id_pipelines.end()) {
|
||||
@@ -1782,13 +1744,12 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool is_unary = context.dst->op == GGML_OP_UNARY;
|
||||
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
|
||||
ggml_webgpu_unary_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.op = op,
|
||||
.is_unary = is_unary,
|
||||
.inplace = context.inplace,
|
||||
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
|
||||
};
|
||||
ggml_webgpu_unary_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.op = op;
|
||||
key.is_unary = is_unary;
|
||||
key.inplace = context.inplace;
|
||||
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
|
||||
|
||||
auto it = unary_pipelines.find(key);
|
||||
if (it != unary_pipelines.end()) {
|
||||
@@ -1853,13 +1814,12 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_binary_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.src_overlap = context.src_overlap,
|
||||
};
|
||||
ggml_webgpu_binary_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.op = context.dst->op;
|
||||
key.inplace = context.inplace;
|
||||
key.overlap = context.overlap;
|
||||
key.src_overlap = context.src_overlap;
|
||||
|
||||
auto it = binary_pipelines.find(key);
|
||||
if (it != binary_pipelines.end()) {
|
||||
@@ -1908,9 +1868,8 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_concat_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
};
|
||||
ggml_webgpu_concat_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
|
||||
auto it = concat_pipelines.find(key);
|
||||
if (it != concat_pipelines.end()) {
|
||||
@@ -1945,9 +1904,8 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_repeat_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
};
|
||||
ggml_webgpu_repeat_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
|
||||
auto it = repeat_pipelines.find(key);
|
||||
if (it != repeat_pipelines.end()) {
|
||||
@@ -1985,16 +1943,16 @@ class ggml_webgpu_shader_lib {
|
||||
return repeat_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
||||
auto it = flash_attn_pipelines.find(context.key);
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
|
||||
auto it = flash_attn_pipelines.find(key);
|
||||
if (it != flash_attn_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "flash_attn";
|
||||
|
||||
switch (context.key.kv_type) {
|
||||
switch (key.kv_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("KV_F32");
|
||||
break;
|
||||
@@ -2010,111 +1968,206 @@ class ggml_webgpu_shader_lib {
|
||||
default:
|
||||
GGML_ABORT("Unsupported KV type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_") + ggml_type_name(context.key.kv_type);
|
||||
variant += std::string("_") + ggml_type_name(key.kv_type);
|
||||
|
||||
if (context.key.has_mask) {
|
||||
if (key.has_mask) {
|
||||
defines.push_back("MASK");
|
||||
variant += "_mask";
|
||||
}
|
||||
if (context.key.has_sinks) {
|
||||
if (key.has_sinks) {
|
||||
defines.push_back("SINKS");
|
||||
variant += "_sinks";
|
||||
}
|
||||
if (context.key.uses_logit_softcap) {
|
||||
if (key.uses_logit_softcap) {
|
||||
defines.push_back("LOGIT_SOFTCAP");
|
||||
variant += "_lgsc";
|
||||
}
|
||||
if (context.key.kv_direct) {
|
||||
if (key.kv_direct) {
|
||||
defines.push_back("KV_DIRECT");
|
||||
variant += "_kvdirect";
|
||||
}
|
||||
if (context.key.has_mask && context.key.use_vec) {
|
||||
defines.push_back("BLK");
|
||||
variant += "_blk";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||
|
||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
||||
|
||||
uint32_t q_tile = context.sg_mat_m;
|
||||
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
|
||||
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||
if (context.key.use_vec) {
|
||||
q_tile = 1;
|
||||
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
|
||||
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
|
||||
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
|
||||
}
|
||||
if (context.key.kv_direct) {
|
||||
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
|
||||
decisions->q_tile = context.sg_mat_m;
|
||||
|
||||
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||
uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||
|
||||
if (key.kv_direct) {
|
||||
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= context.sg_mat_n;
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
||||
decisions->kv_tile = kv_tile;
|
||||
decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
|
||||
uint32_t wg_size = 0;
|
||||
if (context.key.use_vec) {
|
||||
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
||||
} else {
|
||||
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
}
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
|
||||
|
||||
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
|
||||
webgpu_pipeline pipeline =
|
||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
|
||||
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
|
||||
decisions->q_tile = q_tile;
|
||||
decisions->kv_tile = kv_tile;
|
||||
decisions->wg_size = wg_size;
|
||||
pipeline.context = decisions;
|
||||
flash_attn_pipelines[context.key] = pipeline;
|
||||
return flash_attn_pipelines[context.key];
|
||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
|
||||
pipeline.context = decisions;
|
||||
flash_attn_pipelines[key] = pipeline;
|
||||
return flash_attn_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
|
||||
auto it = flash_attn_blk_pipelines.find(context.key);
|
||||
// Returns the pipeline for the vec-split flash-attention shader, creating and caching it
// in flash_attn_vec_pipelines on first use for a given pipeline key.
//
// The shader is specialized through preprocessor defines for the KV type, optional
// mask/sinks/logit-softcap/kv-direct features, head dimensions, subgroup-matrix sizes,
// Q_TILE (always 1 here), KV_TILE, WG_SIZE and VEC_NE. The chosen kv_tile / wg_size are
// stored on the pipeline as a ggml_webgpu_flash_attn_vec_decisions context object.
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);

    const auto cached = flash_attn_vec_pipelines.find(key);
    if (cached != flash_attn_vec_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "flash_attn_vec";

    // KV element type.
    if (key.kv_type == GGML_TYPE_F32) {
        defines.push_back("KV_F32");
    } else if (key.kv_type == GGML_TYPE_F16) {
        defines.push_back("KV_F16");
    } else if (key.kv_type == GGML_TYPE_Q4_0) {
        defines.push_back("KV_Q4_0");
    } else if (key.kv_type == GGML_TYPE_Q8_0) {
        defines.push_back("KV_Q8_0");
    } else {
        GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += "_";
    variant += ggml_type_name(key.kv_type);

    // Optional features; a mask always enables the BLK define on this path.
    if (key.has_mask) {
        defines.push_back("MASK");
        defines.push_back("BLK");
        variant += "_mask_blk";
    }
    if (key.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (key.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }
    if (key.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    // Head dimensions appear both as defines and in the variant name.
    const std::string head_dim_qk_str = std::to_string(key.head_dim_qk);
    defines.push_back("HEAD_DIM_QK=" + head_dim_qk_str);
    variant += "_hsqk" + head_dim_qk_str;

    const std::string head_dim_v_str = std::to_string(key.head_dim_v);
    defines.push_back("HEAD_DIM_V=" + head_dim_v_str);
    variant += "_hsv" + head_dim_v_str;

    defines.push_back("SG_MAT_M=" + std::to_string(context.sg_mat_m));
    defines.push_back("SG_MAT_N=" + std::to_string(context.sg_mat_n));
    defines.push_back("SG_MAT_K=" + std::to_string(context.sg_mat_k));
    defines.push_back("Q_TILE=1");

    auto decisions     = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
    decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
    decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));

    // Keep conservative defaults unless this is the f16 vec-split shape family.
    uint32_t vec_ne = 1u;
    if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
        const uint32_t head_dim = key.head_dim_qk;
        if (head_dim == 96) {
            vec_ne = 4u;
        } else if (head_dim == 64 || head_dim == 192 || head_dim == 576) {
            vec_ne = 2u;
        }
    }

    defines.push_back("KV_TILE=" + std::to_string(decisions->kv_tile));
    defines.push_back("WG_SIZE=" + std::to_string(decisions->wg_size));
    defines.push_back("VEC_NE=" + std::to_string(vec_ne) + "u");

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
    pipeline.context = decisions;

    flash_attn_vec_pipelines[key] = pipeline;
    return flash_attn_vec_pipelines[key];
}
|
||||
|
||||
// Returns the compute pipeline built from the `wgsl_flash_attn_vec_blk` shader,
// compiling it on first use and caching it in `flash_attn_blk_pipelines`.
//
// The cache key holds only the KV tile size reported by
// ggml_webgpu_flash_attn_vec_get_kv_tile(context). The workgroup size is derived
// from context.max_wg_size but is not part of the key.
// NOTE(review): this presumably relies on max_wg_size being constant for the
// lifetime of this shader library - confirm against the caller.
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_flash_attn_blk_pipeline_key key = {};
    key.kv_tile                                 = ggml_webgpu_flash_attn_vec_get_kv_tile(context);

    // Fast path: reuse an already compiled pipeline for this key.
    auto it = flash_attn_blk_pipelines.find(key);
    if (it != flash_attn_blk_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "flash_attn_vec_blk";

    defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
    variant += std::string("_kvt") + std::to_string(key.kv_tile);

    // Largest power of two that does not exceed max_wg_size (stays 1 when max_wg_size is 0 or 1).
    uint32_t wg_size = 1;
    while ((wg_size << 1) <= context.max_wg_size) {
        wg_size <<= 1;
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    variant += std::string("_wg") + std::to_string(wg_size);

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);

    // Store under the same key that was used for the lookup above.
    flash_attn_blk_pipelines[key] = pipeline;
    return flash_attn_blk_pipelines[key];
}
|
||||
|
||||
// Returns the compute pipeline built from the `wgsl_flash_attn_vec_reduce` shader,
// compiling it on first use and caching it in `flash_attn_vec_reduce_pipelines`.
//
// The cache key is (head_dim_v, wg_size): head_dim_v is read from the first
// dimension of context.src2 and wg_size is context.max_wg_size. Both values are
// also baked into the shader as HEAD_DIM_V / WG_SIZE defines and appended to
// the variant name.
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
    key.head_dim_v                                     = (uint32_t) context.src2->ne[0];
    key.wg_size                                        = context.max_wg_size;

    // Fast path: reuse an already compiled pipeline for this key.
    auto it = flash_attn_vec_reduce_pipelines.find(key);
    if (it != flash_attn_vec_reduce_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "flash_attn_vec_reduce";

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(key.head_dim_v);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    variant += std::string("_wg") + std::to_string(context.max_wg_size);

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);

    // Store under the same key that was used for the lookup above.
    flash_attn_vec_reduce_pipelines[key] = pipeline;
    return flash_attn_vec_reduce_pipelines[key];
}
|
||||
|
||||
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_cpy_pipeline_key key = {
|
||||
.src_type = context.src0->type,
|
||||
.dst_type = context.dst->type,
|
||||
};
|
||||
ggml_webgpu_cpy_pipeline_key key = {};
|
||||
key.src_type = context.src0->type;
|
||||
key.dst_type = context.dst->type;
|
||||
|
||||
auto it = cpy_pipelines.find(key);
|
||||
if (it != cpy_pipelines.end()) {
|
||||
@@ -2166,11 +2219,10 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_glu_pipeline_key key = {
|
||||
.glu_op = ggml_get_glu_op(context.dst),
|
||||
.type = context.dst->type,
|
||||
.split = (context.src1 != nullptr),
|
||||
};
|
||||
ggml_webgpu_glu_pipeline_key key = {};
|
||||
key.glu_op = ggml_get_glu_op(context.dst);
|
||||
key.type = context.dst->type;
|
||||
key.split = (context.src1 != nullptr);
|
||||
|
||||
auto it = glu_pipelines.find(key);
|
||||
if (it != glu_pipelines.end()) {
|
||||
@@ -2239,11 +2291,10 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_rope_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.inplace = context.inplace,
|
||||
.has_ff = (context.src2 != nullptr),
|
||||
};
|
||||
ggml_webgpu_rope_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.inplace = context.inplace;
|
||||
key.has_ff = (context.src2 != nullptr);
|
||||
|
||||
auto it = rope_pipelines.find(key);
|
||||
if (it != rope_pipelines.end()) {
|
||||
@@ -2288,12 +2339,11 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_soft_max_pipeline_key key = {
|
||||
.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
|
||||
.has_mask = (context.src1 != nullptr),
|
||||
.has_sink = (context.src2 != nullptr),
|
||||
.inplace = context.inplace,
|
||||
};
|
||||
ggml_webgpu_soft_max_pipeline_key key = {};
|
||||
key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
|
||||
key.has_mask = (context.src1 != nullptr);
|
||||
key.has_sink = (context.src2 != nullptr);
|
||||
key.inplace = context.inplace;
|
||||
|
||||
auto it = soft_max_pipelines.find(key);
|
||||
if (it != soft_max_pipelines.end()) {
|
||||
@@ -2359,25 +2409,6 @@ class ggml_webgpu_shader_lib {
|
||||
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
||||
return { device.CreateComputePipeline(&pipeline_desc), label };
|
||||
}
|
||||
|
||||
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||
const size_t q_tile = context.sg_mat_m;
|
||||
const size_t base_q_bytes =
|
||||
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!context.key.kv_direct) {
|
||||
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
|
||||
}
|
||||
if (context.key.has_mask) {
|
||||
bytes_per_kv += q_tile;
|
||||
}
|
||||
bytes_per_kv += q_tile;
|
||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,6 @@
|
||||
diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
|
||||
#define Q_TILE 1
|
||||
#define KV_TILE 32
|
||||
#define WG_SIZE 32
|
||||
|
||||
@@ -11,7 +10,7 @@ struct Params {
|
||||
seq_len_kv: u32,
|
||||
stride_mask3: u32,
|
||||
// Number of KV blocks and Q blocks per batch.
|
||||
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
|
||||
// nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q.
|
||||
nblk0: u32,
|
||||
nblk1: u32,
|
||||
};
|
||||
@@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
return;
|
||||
}
|
||||
|
||||
let q_start = q_blk * Q_TILE;
|
||||
let q_start = q_blk;
|
||||
let k_start = kv_blk * KV_TILE;
|
||||
|
||||
let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||
@@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
var local_max = -MASK_MAX;
|
||||
var local_any = 0u;
|
||||
|
||||
for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
|
||||
let q_row = q_start + q_rel;
|
||||
if (q_row >= params.seq_len_q) {
|
||||
continue;
|
||||
}
|
||||
let q_row = q_start;
|
||||
if (q_row < params.seq_len_q) {
|
||||
let row_base = mask_batch_base + q_row * params.seq_len_kv;
|
||||
for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
|
||||
let k_col = k_start + k_rel;
|
||||
|
||||
Reference in New Issue
Block a user