ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (#22199)
* ggml-webgpu: add tile flash attention fallback * ggml-webgpu: add new fields and discard usage of mnk for tile version * ggml-webgpu: modify the vec path to discard the mnk parameter * ggml-webgpu: enable flash attention vec and tile version for broswer * ggml-webgpu: stagging KV for flash attention tile version * formatting * turn on subgroup uniformity check * remove Q_TILE as it is always 1 for vec path * make row_max and exp_sum to local register * make different bindings with same underlying buffer to have the same usage flags * move path selection into the shader library and have the host consume a single flash-attn decision object. * turn off skip_validation and address buffer overlapping when nwg==1 * formatting * merge binding when kv overlap
This commit is contained in:
@@ -436,19 +436,27 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
|||||||
|
|
||||||
/** FlashAttention */
|
/** FlashAttention */
|
||||||
|
|
||||||
|
enum ggml_webgpu_flash_attn_path : uint32_t {
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u,
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u,
|
||||||
|
};
|
||||||
|
|
||||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||||
ggml_type kv_type;
|
ggml_type kv_type;
|
||||||
uint32_t head_dim_qk;
|
uint32_t head_dim_qk;
|
||||||
uint32_t head_dim_v;
|
uint32_t head_dim_v;
|
||||||
bool kv_direct;
|
bool kv_direct;
|
||||||
|
bool kv_overlap;
|
||||||
bool has_mask;
|
bool has_mask;
|
||||||
bool has_sinks;
|
bool has_sinks;
|
||||||
bool uses_logit_softcap;
|
bool uses_logit_softcap;
|
||||||
|
uint32_t path;
|
||||||
|
|
||||||
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
||||||
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
||||||
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
|
||||||
uses_logit_softcap == other.uses_logit_softcap;
|
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -459,39 +467,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|||||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.kv_overlap);
|
||||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||||
|
ggml_webgpu_hash_combine(seed, key.path);
|
||||||
return seed;
|
return seed;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_webgpu_flash_attn_decisions {
|
struct ggml_webgpu_flash_attn_decisions {
|
||||||
uint32_t q_tile = 0;
|
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
||||||
uint32_t kv_tile = 0;
|
uint32_t q_tile = 0;
|
||||||
uint32_t wg_size = 0;
|
uint32_t kv_tile = 0;
|
||||||
|
uint32_t wg_size = 0;
|
||||||
|
bool kv_direct = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_webgpu_flash_attn_vec_decisions {
|
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
|
||||||
uint32_t kv_tile = 0;
|
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;
|
||||||
uint32_t wg_size = 0;
|
|
||||||
};
|
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||||
|
if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
|
||||||
|
key.head_dim_qk != key.head_dim_v) {
|
||||||
|
return 1u;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (key.head_dim_qk) {
|
||||||
|
case 64:
|
||||||
|
case 192:
|
||||||
|
case 576:
|
||||||
|
return 2u;
|
||||||
|
case 96:
|
||||||
|
return 4u;
|
||||||
|
default:
|
||||||
|
return 1u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
|
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
|
||||||
const ggml_webgpu_shader_lib_context & context) {
|
const ggml_webgpu_shader_lib_context & context,
|
||||||
|
uint32_t path) {
|
||||||
const bool has_mask = context.src3 != nullptr;
|
const bool has_mask = context.src3 != nullptr;
|
||||||
const bool has_sinks = context.src4 != nullptr;
|
const bool has_sinks = context.src4 != nullptr;
|
||||||
const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
|
bool kv_direct = false;
|
||||||
(context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||||
|
uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
|
||||||
|
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||||
|
kv_direct_align = context.sg_mat_k;
|
||||||
|
}
|
||||||
|
kv_direct = (context.src1->type == GGML_TYPE_F16) &&
|
||||||
|
(context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
|
||||||
|
(context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_webgpu_flash_attn_pipeline_key key = {};
|
ggml_webgpu_flash_attn_pipeline_key key = {};
|
||||||
key.kv_type = context.src1->type;
|
key.kv_type = context.src1->type;
|
||||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||||
key.kv_direct = kv_direct;
|
key.kv_direct = kv_direct;
|
||||||
|
key.kv_overlap = context.src_overlap;
|
||||||
key.has_mask = has_mask;
|
key.has_mask = has_mask;
|
||||||
key.has_sinks = has_sinks;
|
key.has_sinks = has_sinks;
|
||||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||||
|
key.path = path;
|
||||||
return key;
|
return key;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -554,8 +593,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|||||||
|
|
||||||
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
|
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
|
||||||
const ggml_webgpu_flash_attn_pipeline_key & key) {
|
const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||||
const size_t q_tile = context.sg_mat_m;
|
uint32_t q_tile = context.sg_mat_m;
|
||||||
|
uint32_t kv_granularity = context.sg_mat_n;
|
||||||
|
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||||
|
q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||||
|
kv_granularity = std::max(1u, context.max_subgroup_size);
|
||||||
|
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
|
q_tile = 1u;
|
||||||
|
kv_granularity = 8u;
|
||||||
|
}
|
||||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||||
size_t bytes_per_kv = 0;
|
size_t bytes_per_kv = 0;
|
||||||
@@ -568,23 +615,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
|||||||
bytes_per_kv += q_tile;
|
bytes_per_kv += q_tile;
|
||||||
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
||||||
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
||||||
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
return (max_kv_tile / kv_granularity) * kv_granularity;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) {
|
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
||||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
|
const ggml_webgpu_shader_lib_context & context,
|
||||||
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
size_t storage_offset_alignment) {
|
||||||
uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile));
|
ggml_webgpu_flash_attn_decisions decisions = {};
|
||||||
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
|
||||||
|
const auto * K = context.src1;
|
||||||
|
const auto * V = context.src2;
|
||||||
|
GGML_ASSERT(K != nullptr);
|
||||||
|
GGML_ASSERT(V != nullptr);
|
||||||
|
|
||||||
if (key.kv_direct) {
|
const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
|
||||||
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
constexpr uintptr_t ptr_base_addr = 0x1000u;
|
||||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
|
||||||
kv_tile -= context.sg_mat_n;
|
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
|
||||||
|
};
|
||||||
|
|
||||||
|
const uint32_t k_offset_elems =
|
||||||
|
(uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
|
||||||
|
const uint32_t v_offset_elems =
|
||||||
|
(uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
|
||||||
|
const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
|
||||||
|
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
|
||||||
|
const bool kv_vec_type_supported =
|
||||||
|
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||||
|
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
|
||||||
|
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||||
|
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||||
|
(context.src2->type == K->type);
|
||||||
|
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||||
|
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
|
||||||
|
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||||
|
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec;
|
||||||
|
|
||||||
|
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||||
|
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
|
||||||
|
|
||||||
|
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||||
|
decisions.kv_direct = key.kv_direct;
|
||||||
|
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
|
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||||
|
decisions.q_tile = 1u;
|
||||||
|
decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile));
|
||||||
|
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
|
||||||
|
decisions.wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
||||||
|
if (decisions.kv_direct) {
|
||||||
|
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||||
|
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||||
|
decisions.kv_tile -= 8u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return decisions;
|
||||||
|
}
|
||||||
|
|
||||||
|
decisions.q_tile =
|
||||||
|
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
|
||||||
|
decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||||
|
std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) :
|
||||||
|
std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
|
||||||
|
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||||
|
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||||
|
GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE :
|
||||||
|
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||||
|
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||||
|
const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size);
|
||||||
|
decisions.kv_tile =
|
||||||
|
std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decisions.kv_direct) {
|
||||||
|
GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||||
|
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||||
|
decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||||
|
std::max(1u, context.max_subgroup_size) :
|
||||||
|
context.sg_mat_n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv_tile;
|
return decisions;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Matrix Multiplication **/
|
/** Matrix Multiplication **/
|
||||||
@@ -821,8 +935,6 @@ class ggml_webgpu_shader_lib {
|
|||||||
repeat_pipelines; // type
|
repeat_pipelines; // type
|
||||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||||
flash_attn_pipelines;
|
flash_attn_pipelines;
|
||||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
|
||||||
flash_attn_vec_pipelines;
|
|
||||||
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
||||||
webgpu_pipeline,
|
webgpu_pipeline,
|
||||||
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
||||||
@@ -2044,14 +2156,19 @@ class ggml_webgpu_shader_lib {
|
|||||||
return repeat_pipelines[key];
|
return repeat_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context,
|
||||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
|
size_t storage_offset_alignment) {
|
||||||
auto it = flash_attn_pipelines.find(key);
|
const ggml_webgpu_flash_attn_decisions decisions =
|
||||||
|
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
|
||||||
|
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path);
|
||||||
|
auto it = flash_attn_pipelines.find(key);
|
||||||
if (it != flash_attn_pipelines.end()) {
|
if (it != flash_attn_pipelines.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
std::vector<std::string> defines;
|
std::vector<std::string> defines;
|
||||||
std::string variant = "flash_attn";
|
std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" :
|
||||||
|
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" :
|
||||||
|
"flash_attn";
|
||||||
|
|
||||||
switch (key.kv_type) {
|
switch (key.kv_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
@@ -2073,7 +2190,12 @@ class ggml_webgpu_shader_lib {
|
|||||||
|
|
||||||
if (key.has_mask) {
|
if (key.has_mask) {
|
||||||
defines.push_back("MASK");
|
defines.push_back("MASK");
|
||||||
variant += "_mask";
|
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
|
defines.push_back("BLK");
|
||||||
|
variant += "_mask_blk";
|
||||||
|
} else {
|
||||||
|
variant += "_mask";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (key.has_sinks) {
|
if (key.has_sinks) {
|
||||||
defines.push_back("SINKS");
|
defines.push_back("SINKS");
|
||||||
@@ -2087,6 +2209,10 @@ class ggml_webgpu_shader_lib {
|
|||||||
defines.push_back("KV_DIRECT");
|
defines.push_back("KV_DIRECT");
|
||||||
variant += "_kvdirect";
|
variant += "_kvdirect";
|
||||||
}
|
}
|
||||||
|
if (key.kv_overlap) {
|
||||||
|
defines.push_back("KV_OVERLAP");
|
||||||
|
variant += "_kv_overlap";
|
||||||
|
}
|
||||||
|
|
||||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
||||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
||||||
@@ -2094,129 +2220,37 @@ class ggml_webgpu_shader_lib {
|
|||||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||||
|
|
||||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
const char * shader_src = wgsl_flash_attn;
|
||||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
defines.push_back("KV_GRANULARITY=8");
|
||||||
|
defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u");
|
||||||
auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
|
shader_src = wgsl_flash_attn_vec_split;
|
||||||
decisions->q_tile = context.sg_mat_m;
|
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||||
|
shader_src = wgsl_flash_attn_tile;
|
||||||
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
|
||||||
uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
|
||||||
|
variant += "_tile";
|
||||||
if (key.kv_direct) {
|
} else {
|
||||||
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
||||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
||||||
kv_tile -= context.sg_mat_n;
|
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
decisions->kv_tile = kv_tile;
|
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
||||||
decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
|
||||||
|
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
|
||||||
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
|
||||||
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
|
|
||||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
|
|
||||||
|
|
||||||
webgpu_pipeline pipeline =
|
webgpu_pipeline pipeline =
|
||||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
|
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
|
||||||
pipeline.context = decisions;
|
pipeline.context = pipeline_decisions;
|
||||||
flash_attn_pipelines[key] = pipeline;
|
flash_attn_pipelines[key] = pipeline;
|
||||||
return flash_attn_pipelines[key];
|
return flash_attn_pipelines[key];
|
||||||
}
|
}
|
||||||
|
|
||||||
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
|
||||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
|
|
||||||
auto it = flash_attn_vec_pipelines.find(key);
|
|
||||||
if (it != flash_attn_vec_pipelines.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> defines;
|
|
||||||
std::string variant = "flash_attn_vec";
|
|
||||||
|
|
||||||
switch (key.kv_type) {
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
defines.push_back("KV_F32");
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
defines.push_back("KV_F16");
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_Q4_0:
|
|
||||||
defines.push_back("KV_Q4_0");
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_Q8_0:
|
|
||||||
defines.push_back("KV_Q8_0");
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ABORT("Unsupported KV type for flash attention shader");
|
|
||||||
}
|
|
||||||
variant += std::string("_") + ggml_type_name(key.kv_type);
|
|
||||||
|
|
||||||
if (key.has_mask) {
|
|
||||||
defines.push_back("MASK");
|
|
||||||
defines.push_back("BLK");
|
|
||||||
variant += "_mask_blk";
|
|
||||||
}
|
|
||||||
if (key.has_sinks) {
|
|
||||||
defines.push_back("SINKS");
|
|
||||||
variant += "_sinks";
|
|
||||||
}
|
|
||||||
if (key.uses_logit_softcap) {
|
|
||||||
defines.push_back("LOGIT_SOFTCAP");
|
|
||||||
variant += "_lgsc";
|
|
||||||
}
|
|
||||||
if (key.kv_direct) {
|
|
||||||
defines.push_back("KV_DIRECT");
|
|
||||||
variant += "_kvdirect";
|
|
||||||
}
|
|
||||||
|
|
||||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
|
||||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
|
||||||
|
|
||||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
|
||||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
|
||||||
|
|
||||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
|
||||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
|
||||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
|
||||||
defines.push_back("Q_TILE=1");
|
|
||||||
|
|
||||||
auto decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
|
|
||||||
decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
|
|
||||||
decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
|
|
||||||
uint32_t vec_ne = 1u;
|
|
||||||
|
|
||||||
// Keep conservative defaults unless this is the f16 vec-split shape family.
|
|
||||||
if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
|
|
||||||
switch (key.head_dim_qk) {
|
|
||||||
case 64:
|
|
||||||
case 192:
|
|
||||||
case 576:
|
|
||||||
vec_ne = 2u;
|
|
||||||
break;
|
|
||||||
case 96:
|
|
||||||
vec_ne = 4u;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
|
|
||||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
|
|
||||||
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
|
|
||||||
|
|
||||||
webgpu_pipeline pipeline =
|
|
||||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
|
|
||||||
pipeline.context = decisions;
|
|
||||||
flash_attn_vec_pipelines[key] = pipeline;
|
|
||||||
return flash_attn_vec_pipelines[key];
|
|
||||||
}
|
|
||||||
|
|
||||||
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
||||||
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
|
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
|
||||||
key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
|
key.kv_tile = kv_tile;
|
||||||
auto it = flash_attn_blk_pipelines.find(key);
|
auto it = flash_attn_blk_pipelines.find(key);
|
||||||
if (it != flash_attn_blk_pipelines.end()) {
|
if (it != flash_attn_blk_pipelines.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
|
|||||||
@@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t
|
|||||||
return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
|
|
||||||
const ggml_tensor * Q,
|
|
||||||
const ggml_tensor * K,
|
|
||||||
const ggml_tensor * V) {
|
|
||||||
const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
|
||||||
const uint32_t k_offset_elems =
|
|
||||||
(uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
|
|
||||||
const uint32_t v_offset_elems =
|
|
||||||
(uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
|
|
||||||
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
|
|
||||||
const bool kv_vec_type_supported =
|
|
||||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
|
||||||
|
|
||||||
return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
|
|
||||||
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
|
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
|
||||||
size_t offset = ggml_webgpu_tensor_offset(t);
|
size_t offset = ggml_webgpu_tensor_offset(t);
|
||||||
return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
@@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
|||||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef __EMSCRIPTEN__
|
|
||||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||||
ggml_tensor * Q,
|
ggml_tensor * Q,
|
||||||
ggml_tensor * K,
|
ggml_tensor * K,
|
||||||
@@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
const int has_mask = (mask != nullptr);
|
const int has_mask = (mask != nullptr);
|
||||||
const int has_sinks = (sinks != nullptr);
|
const int has_sinks = (sinks != nullptr);
|
||||||
|
const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
|
||||||
|
|
||||||
|
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||||
|
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||||
|
size_t kv_bind_offset = 0;
|
||||||
|
size_t kv_bind_size = 0;
|
||||||
|
if (kv_overlap) {
|
||||||
|
const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
|
||||||
|
const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
|
||||||
|
const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
|
||||||
|
const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
|
||||||
|
kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
|
||||||
|
kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
|
||||||
|
offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
|
||||||
|
offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> params = {
|
std::vector<uint32_t> params = {
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
|
offset_k,
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
|
offset_v,
|
||||||
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
||||||
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||||
@@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
};
|
};
|
||||||
std::vector<wgpu::BindGroupEntry> entries = {
|
std::vector<wgpu::BindGroupEntry> entries = {
|
||||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
|
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
|
||||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K),
|
|
||||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V),
|
|
||||||
};
|
};
|
||||||
uint32_t binding_index = 3;
|
if (kv_overlap) {
|
||||||
|
entries.push_back(
|
||||||
|
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||||
|
} else {
|
||||||
|
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
|
||||||
|
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
|
||||||
|
}
|
||||||
|
uint32_t binding_index = kv_overlap ? 2u : 3u;
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
|
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
|
||||||
}
|
}
|
||||||
@@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
shader_lib_ctx.src3 = mask;
|
shader_lib_ctx.src3 = mask;
|
||||||
shader_lib_ctx.src4 = sinks;
|
shader_lib_ctx.src4 = sinks;
|
||||||
shader_lib_ctx.dst = dst;
|
shader_lib_ctx.dst = dst;
|
||||||
|
shader_lib_ctx.src_overlap = kv_overlap;
|
||||||
|
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||||
|
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||||
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||||
const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
|
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||||
webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) :
|
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||||
ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
|
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||||
|
|
||||||
if (!use_vec) {
|
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
|
||||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
|
|
||||||
|
|
||||||
wgpu::Buffer blk_buf = {};
|
wgpu::Buffer blk_buf = {};
|
||||||
uint64_t blk_size_bytes = 0;
|
uint64_t blk_size_bytes = 0;
|
||||||
uint32_t blk_nblk0 = 0;
|
uint32_t blk_nblk0 = 0;
|
||||||
@@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
tmp_bind_size = tmp_size_bytes;
|
tmp_bind_size = tmp_size_bytes;
|
||||||
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||||
} else {
|
} else {
|
||||||
// nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
|
// nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region.
|
||||||
|
tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT;
|
||||||
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
tmp_buf = ggml_webgpu_tensor_buf(dst);
|
||||||
tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
|
tmp_bind_offset = scratch_offset;
|
||||||
tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
|
tmp_bind_size = tmp_size_bytes;
|
||||||
|
scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
webgpu_pipeline blk_pipeline;
|
webgpu_pipeline blk_pipeline;
|
||||||
@@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||||
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||||
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
|
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
|
||||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
|
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);
|
||||||
|
|
||||||
blk_params = {
|
blk_params = {
|
||||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
|
||||||
@@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
std::vector<wgpu::BindGroupEntry> split_entries = {
|
std::vector<wgpu::BindGroupEntry> split_entries = {
|
||||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
|
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||||
ggml_webgpu_tensor_binding_size(ctx, Q)),
|
ggml_webgpu_tensor_binding_size(ctx, Q)),
|
||||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K),
|
|
||||||
ggml_webgpu_tensor_binding_size(ctx, K)),
|
|
||||||
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V),
|
|
||||||
ggml_webgpu_tensor_binding_size(ctx, V)),
|
|
||||||
};
|
};
|
||||||
uint32_t split_binding_index = 3;
|
if (kv_overlap) {
|
||||||
|
split_entries.push_back(
|
||||||
|
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||||
|
} else {
|
||||||
|
split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
|
||||||
|
ggml_webgpu_tensor_align_offset(ctx, K),
|
||||||
|
ggml_webgpu_tensor_binding_size(ctx, K)));
|
||||||
|
split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V),
|
||||||
|
ggml_webgpu_tensor_align_offset(ctx, V),
|
||||||
|
ggml_webgpu_tensor_binding_size(ctx, V)));
|
||||||
|
}
|
||||||
|
uint32_t split_binding_index = kv_overlap ? 2u : 3u;
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
|
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
|
||||||
ggml_webgpu_tensor_align_offset(ctx, mask),
|
ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||||
@@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|||||||
|
|
||||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||||
}
|
}
|
||||||
#endif // __EMSCRIPTEN__
|
|
||||||
|
|
||||||
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||||
@@ -2710,11 +2721,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
|||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
|
return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node);
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
#ifndef __EMSCRIPTEN__
|
|
||||||
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
||||||
#else
|
|
||||||
return std::nullopt;
|
|
||||||
#endif
|
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
@@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||||
shader_lib_ctx.wg_mem_limit_bytes =
|
shader_lib_ctx.wg_mem_limit_bytes =
|
||||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
|
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||||
|
shader_lib_ctx.supports_subgroup_matrix =
|
||||||
|
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||||
|
|
||||||
if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) {
|
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||||
const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx);
|
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||||
|
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
|
const uint32_t kv_tile = decisions.kv_tile;
|
||||||
|
|
||||||
const uint32_t vec_nwg_cap = std::max(
|
const uint32_t vec_nwg_cap = std::max(
|
||||||
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
|
||||||
@@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||||||
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
||||||
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||||
res += tmp_size_bytes + align;
|
res += tmp_size_bytes + align;
|
||||||
|
} else {
|
||||||
|
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
|
||||||
}
|
}
|
||||||
if (mask != nullptr) {
|
if (mask != nullptr) {
|
||||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||||
@@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||||
|
|
||||||
|
bool valid_subgroup_matrix_config = false;
|
||||||
#ifndef __EMSCRIPTEN__
|
#ifndef __EMSCRIPTEN__
|
||||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||||
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||||
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||||
// The shaders are already parameterized to handle any M/N/K dimensions.
|
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||||
bool valid_subgroup_matrix_config = false;
|
|
||||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||||
@@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
|
||||||
#endif
|
#endif
|
||||||
|
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||||
|
|
||||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||||
@@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||||||
// Enable Dawn-specific toggles to increase native performance
|
// Enable Dawn-specific toggles to increase native performance
|
||||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||||
// only for native performance?
|
// only for native performance?
|
||||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
|
||||||
"disable_polyfills_on_integer_div_and_mod" };
|
"disable_polyfills_on_integer_div_and_mod" };
|
||||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||||
deviceTogglesDesc.enabledToggleCount = 4;
|
deviceTogglesDesc.enabledToggleCount = 3;
|
||||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||||
deviceTogglesDesc.disabledToggleCount = 1;
|
deviceTogglesDesc.disabledToggleCount = 1;
|
||||||
|
|
||||||
@@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
break;
|
break;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
{
|
||||||
#ifndef __EMSCRIPTEN__
|
|
||||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// Head dimensions must be divisible by subgroup matrix dimensions
|
|
||||||
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
|
|
||||||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
|
||||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
|
||||||
const bool has_mask = op->src[3] != nullptr;
|
|
||||||
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
|
||||||
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
|
||||||
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
|
||||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
|
||||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
|
||||||
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
|
||||||
if (min_bytes > limit_bytes) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
supports_op = src0->type == GGML_TYPE_F32 &&
|
supports_op = src0->type == GGML_TYPE_F32 &&
|
||||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
||||||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
||||||
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
||||||
#endif
|
if (!supports_op) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||||
|
shader_lib_ctx.src0 = src0;
|
||||||
|
shader_lib_ctx.src1 = src1;
|
||||||
|
shader_lib_ctx.src2 = src2;
|
||||||
|
shader_lib_ctx.src3 = op->src[3];
|
||||||
|
shader_lib_ctx.src4 = op->src[4];
|
||||||
|
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
|
||||||
|
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||||
|
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||||
|
shader_lib_ctx.wg_mem_limit_bytes =
|
||||||
|
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
|
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||||
|
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||||
|
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||||
|
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||||
|
|
||||||
|
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||||
|
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||||
|
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
|
const bool has_mask = op->src[3] != nullptr;
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||||
|
const size_t min_bytes =
|
||||||
|
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||||
|
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||||
|
if (min_bytes > limit_bytes) {
|
||||||
|
supports_op = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||||
|
const size_t min_bytes =
|
||||||
|
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||||
|
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||||
|
if (min_bytes > limit_bytes) {
|
||||||
|
supports_op = false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||||
|
supports_op = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const size_t min_bytes =
|
||||||
|
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
|
||||||
|
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
|
||||||
|
if (min_bytes > limit_bytes) {
|
||||||
|
supports_op = false;
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
|||||||
@@ -138,25 +138,54 @@ struct Params {
|
|||||||
};
|
};
|
||||||
|
|
||||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||||
|
#define V K
|
||||||
|
#else
|
||||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(MASK) && defined(SINKS)
|
#if defined(MASK) && defined(SINKS)
|
||||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
#ifdef KV_OVERLAP
|
||||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
#define DST_BINDING 5
|
|
||||||
#define PARAMS_BINDING 6
|
|
||||||
#elif defined(MASK)
|
|
||||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
|
||||||
#define DST_BINDING 4
|
|
||||||
#define PARAMS_BINDING 5
|
|
||||||
#elif defined(SINKS)
|
|
||||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
#define DST_BINDING 4
|
#define DST_BINDING 4
|
||||||
#define PARAMS_BINDING 5
|
#define PARAMS_BINDING 5
|
||||||
#else
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
|
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 5
|
||||||
|
#define PARAMS_BINDING 6
|
||||||
|
#endif
|
||||||
|
#elif defined(MASK)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
#define DST_BINDING 3
|
#define DST_BINDING 3
|
||||||
#define PARAMS_BINDING 4
|
#define PARAMS_BINDING 4
|
||||||
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#endif
|
||||||
|
#elif defined(SINKS)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
#define DST_BINDING 2
|
||||||
|
#define PARAMS_BINDING 3
|
||||||
|
#else
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||||
|
|||||||
@@ -0,0 +1,330 @@
|
|||||||
|
enable f16;
|
||||||
|
enable subgroups;
|
||||||
|
|
||||||
|
#define HEAD_DIM_QK 64
|
||||||
|
#define HEAD_DIM_V 64
|
||||||
|
#define KV_STAGE_STRIDE 64
|
||||||
|
#define Q_TILE 4
|
||||||
|
#define KV_TILE 64
|
||||||
|
#define WG_SIZE 128
|
||||||
|
|
||||||
|
struct Params {
|
||||||
|
offset_q: u32,
|
||||||
|
offset_k: u32,
|
||||||
|
offset_v: u32,
|
||||||
|
offset_mask: u32,
|
||||||
|
offset_sinks: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
|
||||||
|
n_heads: u32,
|
||||||
|
seq_len_q: u32,
|
||||||
|
seq_len_kv: u32,
|
||||||
|
|
||||||
|
stride_q1: u32,
|
||||||
|
stride_q2: u32,
|
||||||
|
stride_q3: u32,
|
||||||
|
stride_k1: u32,
|
||||||
|
stride_k2: u32,
|
||||||
|
stride_k3: u32,
|
||||||
|
stride_v1: u32,
|
||||||
|
stride_v2: u32,
|
||||||
|
stride_v3: u32,
|
||||||
|
stride_mask3: u32,
|
||||||
|
|
||||||
|
q_per_kv: u32,
|
||||||
|
|
||||||
|
scale: f32,
|
||||||
|
max_bias: f32,
|
||||||
|
logit_softcap: f32,
|
||||||
|
n_head_log2: f32,
|
||||||
|
m0: f32,
|
||||||
|
m1: f32,
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||||
|
#define V K
|
||||||
|
#else
|
||||||
|
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
|
||||||
|
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(MASK) && defined(SINKS)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
|
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 5
|
||||||
|
#define PARAMS_BINDING 6
|
||||||
|
#endif
|
||||||
|
#elif defined(MASK)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#endif
|
||||||
|
#elif defined(SINKS)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
|
#else
|
||||||
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
#define DST_BINDING 2
|
||||||
|
#define PARAMS_BINDING 3
|
||||||
|
#else
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
|
||||||
|
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
|
||||||
|
|
||||||
|
const FLOAT_MIN: f32 = -1.0e9;
|
||||||
|
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
|
||||||
|
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
||||||
|
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||||
|
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
|
||||||
|
|
||||||
|
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
||||||
|
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
|
||||||
|
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
|
||||||
|
|
||||||
|
@compute @workgroup_size(WG_SIZE)
|
||||||
|
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
|
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||||
|
@builtin(subgroup_id) subgroup_id: u32,
|
||||||
|
@builtin(subgroup_size) subgroup_size: u32,
|
||||||
|
@builtin(num_subgroups) num_subgroups: u32,
|
||||||
|
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||||
|
if (subgroup_size == 0u || num_subgroups < Q_TILE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
||||||
|
let wg_per_batch = wg_per_head * params.n_heads;
|
||||||
|
|
||||||
|
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||||
|
let dst3_stride = dst2_stride * params.seq_len_q;
|
||||||
|
|
||||||
|
let batch_idx = wg_id.x / wg_per_batch;
|
||||||
|
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
|
||||||
|
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
|
||||||
|
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
|
||||||
|
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
|
||||||
|
let wg_in_batch = wg_id.x % wg_per_batch;
|
||||||
|
|
||||||
|
let head_idx = wg_in_batch / wg_per_head;
|
||||||
|
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
|
||||||
|
let k_head_idx = head_idx / params.q_per_kv;
|
||||||
|
let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
|
||||||
|
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||||
|
|
||||||
|
let wg_in_head = wg_in_batch % wg_per_head;
|
||||||
|
let q_row_start = wg_in_head * Q_TILE;
|
||||||
|
let global_q_row = q_row_start + subgroup_id;
|
||||||
|
let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
|
||||||
|
|
||||||
|
#ifdef MASK
|
||||||
|
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
|
||||||
|
|
||||||
|
let head = f32(head_idx);
|
||||||
|
let slope = select(1.0,
|
||||||
|
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
|
||||||
|
pow(params.m0, head + 1.0),
|
||||||
|
head < params.n_head_log2),
|
||||||
|
params.max_bias > 0.0);
|
||||||
|
|
||||||
|
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||||
|
let q_tile_row = elem_idx / HEAD_DIM_QK;
|
||||||
|
let q_col = elem_idx % HEAD_DIM_QK;
|
||||||
|
let head_q_row = q_row_start + q_tile_row;
|
||||||
|
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||||
|
q_shmem[elem_idx] = f16(select(
|
||||||
|
0.0,
|
||||||
|
Q[global_q_row_offset + q_col] * params.scale,
|
||||||
|
head_q_row < params.seq_len_q));
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
|
var row_max = FLOAT_MIN;
|
||||||
|
var exp_sum = 0.0;
|
||||||
|
var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
|
||||||
|
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||||
|
out_regs[reg_idx] = vec4<f32>(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let q_base = subgroup_id * HEAD_DIM_QK;
|
||||||
|
let subgroup_p_offset = subgroup_id * KV_TILE;
|
||||||
|
|
||||||
|
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
|
||||||
|
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
|
||||||
|
let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
|
||||||
|
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
|
||||||
|
var local_scores: array<f32, SCORE_REGS_PER_LANE>;
|
||||||
|
for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
|
||||||
|
local_scores[slot] = FLOAT_MIN;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||||
|
let kv_local = vec_idx_local / Q_CHUNKS;
|
||||||
|
let chunk = vec_idx_local % Q_CHUNKS;
|
||||||
|
let global_k_row = kv_tile + kv_local;
|
||||||
|
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||||
|
let k4 = K[k_vec_index];
|
||||||
|
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||||
|
kv_shmem[kv_off + 0u] = k4.x;
|
||||||
|
kv_shmem[kv_off + 1u] = k4.y;
|
||||||
|
kv_shmem[kv_off + 2u] = k4.z;
|
||||||
|
kv_shmem[kv_off + 3u] = k4.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
|
var local_max = FLOAT_MIN;
|
||||||
|
if (row_active) {
|
||||||
|
for (var slot = 0u; slot < score_slots; slot += 1u) {
|
||||||
|
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||||
|
if (kv_local >= kv_count) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let global_k_row = kv_tile + kv_local;
|
||||||
|
var dot_val = 0.0;
|
||||||
|
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
|
||||||
|
let q_off = q_base + chunk * 4u;
|
||||||
|
let qv = vec4<f32>(
|
||||||
|
f32(q_shmem[q_off + 0u]),
|
||||||
|
f32(q_shmem[q_off + 1u]),
|
||||||
|
f32(q_shmem[q_off + 2u]),
|
||||||
|
f32(q_shmem[q_off + 3u]));
|
||||||
|
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||||
|
let kv = vec4<f32>(
|
||||||
|
f32(kv_shmem[kv_off + 0u]),
|
||||||
|
f32(kv_shmem[kv_off + 1u]),
|
||||||
|
f32(kv_shmem[kv_off + 2u]),
|
||||||
|
f32(kv_shmem[kv_off + 3u]));
|
||||||
|
dot_val += dot(qv, kv);
|
||||||
|
}
|
||||||
|
#ifdef LOGIT_SOFTCAP
|
||||||
|
dot_val = params.logit_softcap * tanh(dot_val);
|
||||||
|
#endif
|
||||||
|
#ifdef MASK
|
||||||
|
let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
|
||||||
|
dot_val += slope * f32(mask[mask_idx]);
|
||||||
|
#endif
|
||||||
|
local_scores[slot] = dot_val;
|
||||||
|
local_max = max(local_max, dot_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let tile_max = subgroupMax(local_max);
|
||||||
|
let new_max = max(row_max, tile_max);
|
||||||
|
let cur_exp = exp(row_max - new_max);
|
||||||
|
exp_sum *= cur_exp;
|
||||||
|
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||||
|
out_regs[reg_idx] *= cur_exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
var local_sum = 0.0;
|
||||||
|
for (var slot = 0u; slot < score_slots; slot += 1u) {
|
||||||
|
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||||
|
if (row_active && kv_local < kv_count) {
|
||||||
|
let p = exp(local_scores[slot] - new_max);
|
||||||
|
p_shmem[subgroup_p_offset + kv_local] = p;
|
||||||
|
local_sum += p;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
|
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||||
|
let kv_local = vec_idx_local / V_CHUNKS;
|
||||||
|
let chunk = vec_idx_local % V_CHUNKS;
|
||||||
|
let global_v_row = kv_tile + kv_local;
|
||||||
|
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||||
|
let v4 = V[v_vec_index];
|
||||||
|
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||||
|
kv_shmem[kv_off + 0u] = v4.x;
|
||||||
|
kv_shmem[kv_off + 1u] = v4.y;
|
||||||
|
kv_shmem[kv_off + 2u] = v4.z;
|
||||||
|
kv_shmem[kv_off + 3u] = v4.w;
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
|
let tile_sum = subgroupAdd(local_sum);
|
||||||
|
exp_sum += tile_sum;
|
||||||
|
row_max = new_max;
|
||||||
|
|
||||||
|
if (row_active) {
|
||||||
|
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
|
||||||
|
let chunk = sg_inv_id + reg_idx * subgroup_size;
|
||||||
|
if (chunk >= V_CHUNKS) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
var acc = out_regs[reg_idx];
|
||||||
|
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
|
||||||
|
let p = p_shmem[subgroup_p_offset + kv_local];
|
||||||
|
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||||
|
let v4 = vec4<f32>(
|
||||||
|
f32(kv_shmem[kv_off + 0u]),
|
||||||
|
f32(kv_shmem[kv_off + 1u]),
|
||||||
|
f32(kv_shmem[kv_off + 2u]),
|
||||||
|
f32(kv_shmem[kv_off + 3u]));
|
||||||
|
acc += p * v4;
|
||||||
|
}
|
||||||
|
out_regs[reg_idx] = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef SINKS
|
||||||
|
if (row_active) {
|
||||||
|
let sink_score = sinks[params.offset_sinks + head_idx];
|
||||||
|
let sink_max = max(row_max, sink_score);
|
||||||
|
let sink_scale = exp(row_max - sink_max);
|
||||||
|
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
|
||||||
|
out_regs[reg_idx] *= sink_scale;
|
||||||
|
}
|
||||||
|
exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
|
||||||
|
row_max = sink_max;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (row_active) {
|
||||||
|
let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||||
|
let row_base = dst_global_offset + subgroup_id * dst2_stride;
|
||||||
|
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
|
||||||
|
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
|
||||||
|
let chunk = sg_inv_id + reg_idx * subgroup_size;
|
||||||
|
if (chunk >= V_CHUNKS) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
|
||||||
|
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,7 @@ struct Params {
|
|||||||
nblk1: u32,
|
nblk1: u32,
|
||||||
};
|
};
|
||||||
|
|
||||||
@group(0) @binding(0) var<storage, read> mask: array<f16>;
|
@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
|
||||||
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
|
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
|
||||||
@group(0) @binding(2) var<uniform> params: Params;
|
@group(0) @binding(2) var<uniform> params: Params;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
|
||||||
diagnostic(off, subgroup_uniformity);
|
diagnostic(off, subgroup_uniformity);
|
||||||
enable f16;
|
enable f16;
|
||||||
enable subgroups;
|
enable subgroups;
|
||||||
enable chromium_experimental_subgroup_matrix;
|
|
||||||
|
|
||||||
#ifdef KV_F32
|
#ifdef KV_F32
|
||||||
#define KV_TYPE f32
|
#define KV_TYPE f32
|
||||||
@@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix;
|
|||||||
#define HEAD_DIM_QK 64
|
#define HEAD_DIM_QK 64
|
||||||
#define HEAD_DIM_V 64
|
#define HEAD_DIM_V 64
|
||||||
|
|
||||||
|
#define KV_GRANULARITY 8
|
||||||
#define SG_MAT_M 8
|
|
||||||
#define SG_MAT_N 8
|
|
||||||
#define SG_MAT_K 8
|
|
||||||
|
|
||||||
#define Q_TILE SG_MAT_M
|
|
||||||
#define KV_TILE 16
|
#define KV_TILE 16
|
||||||
#define WG_SIZE 64
|
#define WG_SIZE 64
|
||||||
#ifndef VEC_NE
|
#ifndef VEC_NE
|
||||||
#define VEC_NE 4u
|
#define VEC_NE 4u
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)
|
||||||
|
|
||||||
#define BLOCK_SIZE 32
|
#define BLOCK_SIZE 32
|
||||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||||
@@ -97,6 +90,14 @@ struct Params {
|
|||||||
};
|
};
|
||||||
|
|
||||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||||
|
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||||
|
#else
|
||||||
|
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||||
|
#endif
|
||||||
|
#define V K
|
||||||
|
#else
|
||||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||||
#else
|
#else
|
||||||
@@ -107,7 +108,22 @@ struct Params {
|
|||||||
#else
|
#else
|
||||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
#if defined(MASK) && defined(SINKS)
|
#if defined(MASK) && defined(SINKS)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#ifdef BLK
|
||||||
|
#define BLK_BINDING 4
|
||||||
|
#define TMP_BINDING 5
|
||||||
|
#define DST_BINDING 6
|
||||||
|
#define PARAMS_BINDING 7
|
||||||
|
#else
|
||||||
|
#define TMP_BINDING 4
|
||||||
|
#define DST_BINDING 5
|
||||||
|
#define PARAMS_BINDING 6
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
|
||||||
#ifdef BLK
|
#ifdef BLK
|
||||||
@@ -120,7 +136,21 @@ struct Params {
|
|||||||
#define DST_BINDING 6
|
#define DST_BINDING 6
|
||||||
#define PARAMS_BINDING 7
|
#define PARAMS_BINDING 7
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
#elif defined(MASK)
|
#elif defined(MASK)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
|
||||||
|
#ifdef BLK
|
||||||
|
#define BLK_BINDING 3
|
||||||
|
#define TMP_BINDING 4
|
||||||
|
#define DST_BINDING 5
|
||||||
|
#define PARAMS_BINDING 6
|
||||||
|
#else
|
||||||
|
#define TMP_BINDING 3
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
|
||||||
#ifdef BLK
|
#ifdef BLK
|
||||||
#define BLK_BINDING 4
|
#define BLK_BINDING 4
|
||||||
@@ -132,16 +162,30 @@ struct Params {
|
|||||||
#define DST_BINDING 5
|
#define DST_BINDING 5
|
||||||
#define PARAMS_BINDING 6
|
#define PARAMS_BINDING 6
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
#elif defined(SINKS)
|
#elif defined(SINKS)
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
|
||||||
|
#define TMP_BINDING 3
|
||||||
|
#define DST_BINDING 4
|
||||||
|
#define PARAMS_BINDING 5
|
||||||
|
#else
|
||||||
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
|
||||||
#define TMP_BINDING 4
|
#define TMP_BINDING 4
|
||||||
#define DST_BINDING 5
|
#define DST_BINDING 5
|
||||||
#define PARAMS_BINDING 6
|
#define PARAMS_BINDING 6
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#ifdef KV_OVERLAP
|
||||||
|
#define TMP_BINDING 2
|
||||||
|
#define DST_BINDING 3
|
||||||
|
#define PARAMS_BINDING 4
|
||||||
#else
|
#else
|
||||||
#define TMP_BINDING 3
|
#define TMP_BINDING 3
|
||||||
#define DST_BINDING 4
|
#define DST_BINDING 4
|
||||||
#define PARAMS_BINDING 5
|
#define PARAMS_BINDING 5
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef BLK
|
#ifdef BLK
|
||||||
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
|
||||||
@@ -153,7 +197,7 @@ struct Params {
|
|||||||
// Just a very small float value.
|
// Just a very small float value.
|
||||||
const FLOAT_MIN: f32 = -1.0e9;
|
const FLOAT_MIN: f32 = -1.0e9;
|
||||||
|
|
||||||
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
|
var<workgroup> q_shmem: array<f16, HEAD_DIM_QK>;
|
||||||
|
|
||||||
#ifndef KV_DIRECT
|
#ifndef KV_DIRECT
|
||||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||||
@@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
|||||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
|
var<workgroup> o_shmem: array<f16, HEAD_DIM_V>;
|
||||||
|
|
||||||
#ifdef MASK
|
#ifdef MASK
|
||||||
// storage for mask values
|
// storage for mask values
|
||||||
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
|
var<workgroup> mask_shmem: array<f16, KV_TILE>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// note that we reuse the same storage for both since we only need one at a time
|
// note that we reuse the same storage for both since we only need one at a time
|
||||||
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
|
var<workgroup> inter_shmem: array<f16, KV_TILE>;
|
||||||
|
|
||||||
// Storage for row max and exp sum during online softmax
|
// Storage for row max and exp sum during online softmax
|
||||||
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
|
fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
||||||
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
|
|
||||||
var<workgroup> blk_state_wg: u32;
|
|
||||||
|
|
||||||
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
|
|
||||||
var v = select(FLOAT_MIN,
|
var v = select(FLOAT_MIN,
|
||||||
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
|
f32(inter_shmem[kv_idx]) * params.scale,
|
||||||
kv_idx < KV_TILE);
|
kv_idx < KV_TILE);
|
||||||
#ifdef LOGIT_SOFTCAP
|
#ifdef LOGIT_SOFTCAP
|
||||||
v = params.logit_softcap * tanh(v);
|
v = params.logit_softcap * tanh(v);
|
||||||
#endif
|
#endif
|
||||||
#ifdef MASK
|
#ifdef MASK
|
||||||
if (apply_mask) {
|
if (apply_mask) {
|
||||||
var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
|
var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE);
|
||||||
v += select(mask_val, slope * mask_val, has_bias);
|
v += select(mask_val, slope * mask_val, has_bias);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
@builtin(subgroup_size) subgroup_size: u32,
|
@builtin(subgroup_size) subgroup_size: u32,
|
||||||
@builtin(num_subgroups) num_subgroups: u32,
|
@builtin(num_subgroups) num_subgroups: u32,
|
||||||
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
|
||||||
|
// Vec path processes exactly one query row per workgroup, so subgroup 0 can
|
||||||
|
// keep the running softmax state in private storage.
|
||||||
|
var row_max = FLOAT_MIN;
|
||||||
|
var exp_sum = 0.0;
|
||||||
|
|
||||||
// initialize row max for online softmax
|
for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) {
|
||||||
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
|
|
||||||
row_max_shmem[i] = FLOAT_MIN;
|
|
||||||
exp_sum_shmem[i] = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
|
|
||||||
o_shmem[i] = 0.0;
|
o_shmem[i] = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// workgroups per head/batch
|
// workgroups per head/batch
|
||||||
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
|
let wg_per_head = params.seq_len_q;
|
||||||
let wg_per_batch = wg_per_head * params.n_heads;
|
let wg_per_batch = wg_per_head * params.n_heads;
|
||||||
|
|
||||||
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
let dst2_stride = HEAD_DIM_V * params.n_heads;
|
||||||
@@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
|
||||||
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
|
||||||
|
|
||||||
// starting Q row for this workgroup
|
// Vec path handles one Q row per workgroup.
|
||||||
let wg_in_head = wg_in_batch % wg_per_head;
|
let wg_in_head = wg_in_batch % wg_per_head;
|
||||||
let q_row_start = wg_in_head * Q_TILE;
|
let q_row_start = wg_in_head;
|
||||||
|
|
||||||
#ifdef MASK
|
#ifdef MASK
|
||||||
// mask offset
|
// mask offset
|
||||||
@@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let has_bias = params.max_bias > 0.0;
|
let has_bias = params.max_bias > 0.0;
|
||||||
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
|
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
|
||||||
|
|
||||||
// load q tile into shared memory
|
// load the single Q row into shared memory
|
||||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||||
let q_row = elem_idx / HEAD_DIM_QK;
|
let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1;
|
||||||
let q_col = elem_idx % HEAD_DIM_QK;
|
|
||||||
let head_q_row = q_row_start + q_row;
|
|
||||||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
|
||||||
q_shmem[elem_idx] = f16(select(
|
q_shmem[elem_idx] = f16(select(
|
||||||
0.0,
|
0.0,
|
||||||
Q[global_q_row_offset + q_col],
|
Q[global_q_row_offset + elem_idx],
|
||||||
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
|
q_row_start < params.seq_len_q));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||||
#ifdef BLK
|
#ifdef BLK
|
||||||
let q_blk = q_row_start / Q_TILE;
|
let q_blk = q_row_start;
|
||||||
let kv_blk = kv_tile / KV_TILE;
|
let kv_blk = kv_tile / KV_TILE;
|
||||||
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
|
||||||
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
|
let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
|
||||||
@@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
#else
|
#else
|
||||||
let blk_state_local = 1u;
|
let blk_state_local = 1u;
|
||||||
#endif
|
#endif
|
||||||
if (local_id.x == 0u) {
|
let blk_state = blk_state_local;
|
||||||
blk_state_wg = blk_state_local;
|
|
||||||
}
|
|
||||||
workgroupBarrier();
|
|
||||||
let blk_state = blk_state_wg;
|
|
||||||
let skip_tile = blk_state == 0u;
|
let skip_tile = blk_state == 0u;
|
||||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
|
||||||
inter_shmem[elem_idx] = f16(0.0);
|
inter_shmem[elem_idx] = f16(0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let num_of_threads = subgroup_size / VEC_NE;
|
let num_of_threads = subgroup_size / VEC_NE;
|
||||||
let tx = sg_inv_id % num_of_threads;
|
let tx = sg_inv_id % num_of_threads;
|
||||||
let ty = sg_inv_id / num_of_threads;
|
let ty = sg_inv_id / num_of_threads;
|
||||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||||
let global_q_row = q_row_start + q_tile_row;
|
|
||||||
if (global_q_row >= params.seq_len_q) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
|
|
||||||
|
|
||||||
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
|
for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
|
||||||
let kv_idx = kv_base + ty;
|
let kv_idx = kv_base + ty;
|
||||||
var partial_sum: f32 = 0.0;
|
var partial_sum: f32 = 0.0;
|
||||||
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
|
let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
|
||||||
if (kv_valid) {
|
if (kv_valid) {
|
||||||
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
|
for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
|
||||||
let q_off = local_q_row_offset + i * 4u;
|
let q_off = i * 4u;
|
||||||
|
|
||||||
let qv = vec4<f32>(
|
let qv = vec4<f32>(
|
||||||
f32(q_shmem[q_off + 0u]),
|
f32(q_shmem[q_off + 0u]),
|
||||||
@@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
|
|
||||||
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
|
||||||
if (tx == 0u && kv_valid) {
|
if (tx == 0u && kv_valid) {
|
||||||
let dst_idx = q_tile_row * KV_TILE + kv_idx;
|
inter_shmem[kv_idx] = f16(sum_bcast);
|
||||||
inter_shmem[dst_idx] = f16(sum_bcast);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
let apply_mask = !skip_tile && (blk_state != 2u);
|
let apply_mask = !skip_tile && (blk_state != 2u);
|
||||||
if (apply_mask) {
|
if (apply_mask) {
|
||||||
// load mask tile into shared memory for this KV block
|
// load mask tile into shared memory for this KV block
|
||||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) {
|
||||||
let mask_row = elem_idx / KV_TILE;
|
let global_k_col = kv_tile + elem_idx;
|
||||||
let mask_col = elem_idx % KV_TILE;
|
let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv;
|
||||||
let global_q_row = q_row_start + mask_row;
|
let mask_idx = mask_global_offset + global_k_col;
|
||||||
let global_k_col = kv_tile + mask_col;
|
|
||||||
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
|
|
||||||
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
|
|
||||||
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
workgroupBarrier();
|
workgroupBarrier();
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
if (!skip_tile) {
|
if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||||
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
|
var prev_max = row_max;
|
||||||
let global_q_row = q_row_start + q_tile_row;
|
var final_max = prev_max;
|
||||||
if (global_q_row >= params.seq_len_q) {
|
// pass 1: compute final max across the full KV tile in chunks
|
||||||
break;
|
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||||
}
|
let kv_idx = kv_offset + sg_inv_id;
|
||||||
|
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
||||||
|
let softmax_term = select(FLOAT_MIN,
|
||||||
|
calc_softmax_term(kv_idx, slope, has_bias, apply_mask),
|
||||||
|
kv_valid);
|
||||||
|
final_max = subgroupMax(max(final_max, softmax_term));
|
||||||
|
}
|
||||||
|
|
||||||
var prev_max = row_max_shmem[q_tile_row];
|
var total_exp_term: f32 = 0.0;
|
||||||
var final_max = prev_max;
|
// pass 2: compute exp sum and write P using final_max
|
||||||
// pass 1: compute final max across the full KV tile in chunks
|
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
||||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
let kv_idx = kv_offset + sg_inv_id;
|
||||||
let kv_idx = kv_offset + sg_inv_id;
|
let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask);
|
||||||
let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
|
let cur_p = select(0.0,
|
||||||
let softmax_term = select(FLOAT_MIN,
|
exp(softmax_term - final_max),
|
||||||
calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
|
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
||||||
kv_valid);
|
total_exp_term += subgroupAdd(cur_p);
|
||||||
final_max = subgroupMax(max(final_max, softmax_term));
|
if (kv_idx < KV_TILE) {
|
||||||
|
inter_shmem[kv_idx] = f16(cur_p);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var total_exp_term: f32 = 0.0;
|
let cur_exp = exp(prev_max - final_max);
|
||||||
// pass 2: compute exp sum and write P using final_max
|
|
||||||
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
|
|
||||||
let kv_idx = kv_offset + sg_inv_id;
|
|
||||||
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
|
|
||||||
let cur_p = select(0.0,
|
|
||||||
exp(softmax_term - final_max),
|
|
||||||
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
|
|
||||||
total_exp_term += subgroupAdd(cur_p);
|
|
||||||
if (kv_idx < KV_TILE) {
|
|
||||||
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let cur_exp = exp(prev_max - final_max);
|
row_max = final_max;
|
||||||
|
exp_sum = exp_sum * cur_exp + total_exp_term;
|
||||||
|
|
||||||
if (sg_inv_id == 0) {
|
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||||
row_max_shmem[q_tile_row] = final_max;
|
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp);
|
||||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
|
||||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
|
||||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
workgroupBarrier();
|
workgroupBarrier();
|
||||||
|
|
||||||
if (!skip_tile) {
|
if (!skip_tile) {
|
||||||
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
// we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
|
||||||
// we want to compute O += P * V across the full KV tile
|
// we want to compute O += P * V across the full KV tile
|
||||||
let ne_threads : u32 = VEC_NE;
|
let ne_threads : u32 = VEC_NE;
|
||||||
let nl_threads = max(1u, subgroup_size / ne_threads);
|
let nl_threads = max(1u, subgroup_size / ne_threads);
|
||||||
let tx_pv = sg_inv_id % nl_threads;
|
let tx_pv = sg_inv_id % nl_threads;
|
||||||
let ty_pv = sg_inv_id / nl_threads;
|
let ty_pv = sg_inv_id / nl_threads;
|
||||||
for (var q_tile_row = subgroup_id;
|
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||||
q_tile_row < Q_TILE;
|
|
||||||
q_tile_row += num_subgroups) {
|
|
||||||
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
|
for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
|
||||||
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||||
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
|
for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
|
||||||
@@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
|
let p = f32(inter_shmem[kv_idx]);
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
|
||||||
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
let v4 = vec4<f32>(V[v_idx >> 2u]);
|
||||||
@@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
|
|
||||||
if (ty_pv == 0u) {
|
if (ty_pv == 0u) {
|
||||||
let elem_base = vec_col * 4u;
|
let elem_base = vec_col * 4u;
|
||||||
let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
|
o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x);
|
||||||
o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
|
o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y);
|
||||||
o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
|
o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z);
|
||||||
o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
|
o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w);
|
||||||
o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
|
|
||||||
#ifdef SINKS
|
#ifdef SINKS
|
||||||
// Sinks are global terms and must be applied exactly once across split workgroups.
|
// Sinks are global terms and must be applied exactly once across split workgroups.
|
||||||
if (iwg == 0u) {
|
if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||||
for (var q_tile_row = subgroup_id;
|
var prev_max = row_max;
|
||||||
q_tile_row < Q_TILE;
|
|
||||||
q_tile_row += num_subgroups) {
|
|
||||||
let global_q_row = q_row_start + q_tile_row;
|
|
||||||
if (global_q_row >= params.seq_len_q) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
var prev_max = row_max_shmem[q_tile_row];
|
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
||||||
|
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u);
|
||||||
|
let new_max = subgroupMax(max(prev_max, sink_val));
|
||||||
|
let max_exp = exp(prev_max - new_max);
|
||||||
|
let sink_exp = exp(sink_val - new_max);
|
||||||
|
|
||||||
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
|
let sink_exp_sum = subgroupAdd(sink_exp);
|
||||||
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
|
|
||||||
let new_max = subgroupMax(max(prev_max, sink_val));
|
|
||||||
let max_exp = exp(prev_max - new_max);
|
|
||||||
let sink_exp = exp(sink_val - new_max);
|
|
||||||
|
|
||||||
let sink_exp_sum = subgroupAdd(sink_exp);
|
row_max = new_max;
|
||||||
|
exp_sum = exp_sum * max_exp + sink_exp_sum;
|
||||||
|
|
||||||
if (sg_inv_id == 0) {
|
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
||||||
row_max_shmem[q_tile_row] = new_max;
|
o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp);
|
||||||
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
|
|
||||||
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
|
|
||||||
o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
workgroupBarrier();
|
|
||||||
}
|
}
|
||||||
|
workgroupBarrier();
|
||||||
#endif
|
#endif
|
||||||
let rows_per_batch = params.n_heads * params.seq_len_q;
|
let rows_per_batch = params.n_heads * params.seq_len_q;
|
||||||
for (var q_tile_row = subgroup_id;
|
if (subgroup_id == 0u && q_row_start < params.seq_len_q) {
|
||||||
q_tile_row < Q_TILE;
|
|
||||||
q_tile_row += num_subgroups) {
|
|
||||||
|
|
||||||
let global_q_row = q_row_start + q_tile_row;
|
|
||||||
if (global_q_row >= params.seq_len_q) { break; }
|
|
||||||
|
|
||||||
if (params.nwg == 1u) {
|
if (params.nwg == 1u) {
|
||||||
let exp_sum = exp_sum_shmem[q_tile_row];
|
|
||||||
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
|
||||||
let row_base: u32 =
|
let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride +
|
||||||
params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
|
head_idx * HEAD_DIM_V;
|
||||||
|
|
||||||
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
|
for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
|
||||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
|
||||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
|
||||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
|
||||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
|
||||||
|
|
||||||
let v = vec4<f32>(
|
let v = vec4<f32>(
|
||||||
f32(o_shmem[i0]) * scale,
|
f32(o_shmem[elem_base + 0u]) * scale,
|
||||||
f32(o_shmem[i1]) * scale,
|
f32(o_shmem[elem_base + 1u]) * scale,
|
||||||
f32(o_shmem[i2]) * scale,
|
f32(o_shmem[elem_base + 2u]) * scale,
|
||||||
f32(o_shmem[i3]) * scale
|
f32(o_shmem[elem_base + 3u]) * scale
|
||||||
);
|
);
|
||||||
|
|
||||||
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
|
||||||
dst[dst_vec_index] = v;
|
dst[dst_vec_index] = v;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
|
let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start;
|
||||||
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
|
let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
|
||||||
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
|
let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
|
||||||
|
|
||||||
@@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||||||
elem_base < HEAD_DIM_V;
|
elem_base < HEAD_DIM_V;
|
||||||
elem_base += subgroup_size * 4u) {
|
elem_base += subgroup_size * 4u) {
|
||||||
|
|
||||||
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
|
|
||||||
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
|
|
||||||
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
|
|
||||||
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
|
|
||||||
|
|
||||||
let tbase = tmp_row_data_base + elem_base;
|
let tbase = tmp_row_data_base + elem_base;
|
||||||
tmp[tbase + 0u] = f32(o_shmem[i0]);
|
tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]);
|
||||||
tmp[tbase + 1u] = f32(o_shmem[i1]);
|
tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]);
|
||||||
tmp[tbase + 2u] = f32(o_shmem[i2]);
|
tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]);
|
||||||
tmp[tbase + 3u] = f32(o_shmem[i3]);
|
tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sg_inv_id == 0u) {
|
if (sg_inv_id == 0u) {
|
||||||
tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
|
tmp[tmp_row_stats_base + 0u] = exp_sum;
|
||||||
tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
|
tmp[tmp_row_stats_base + 1u] = row_max;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user