vulkan: support larger argsort (#17313)
* vulkan: support larger argsort This is an extension of the original bitonic sorting shader that puts the temporary values in global memory and when more than 1024 threads are needed it runs multiple workgroups and synchronizes through a pipelinebarrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared memory version of the shader, which is still used when ncols <= 1024. * Reduce the number of shader variants. Use smaller workgroups when doing a single pass, for a modest perf boost * reduce loop overhead * run multiple cols per invocation, to reduce barrier overhead
This commit is contained in:
@@ -406,8 +406,8 @@ enum shader_reduction_mode {
|
||||
SHADER_REDUCTION_MODE_COUNT,
|
||||
};
|
||||
|
||||
// argsort pipelines for up to 1<<10 invocations per workgroup
|
||||
static constexpr uint32_t num_argsort_pipelines = 11;
|
||||
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||
|
||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
@@ -526,6 +526,7 @@ struct vk_device_struct {
|
||||
bool multi_add;
|
||||
bool shader_int64;
|
||||
bool buffer_device_address;
|
||||
bool vulkan_memory_model;
|
||||
|
||||
bool add_rms_fusion;
|
||||
uint32_t partials_binding_alignment;
|
||||
@@ -539,6 +540,9 @@ struct vk_device_struct {
|
||||
uint32_t subgroup_max_size;
|
||||
bool subgroup_require_full_support;
|
||||
|
||||
// floor(log2(maxComputeWorkGroupInvocations))
|
||||
uint32_t max_workgroup_size_log2 {};
|
||||
|
||||
bool coopmat_support;
|
||||
bool coopmat_acc_f32_support {};
|
||||
bool coopmat_acc_f16_support {};
|
||||
@@ -684,6 +688,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
||||
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
||||
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
||||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||
vk_pipeline pipeline_sum_rows_f32;
|
||||
vk_pipeline pipeline_argmax_f32;
|
||||
vk_pipeline pipeline_count_equal_i32;
|
||||
@@ -1174,8 +1179,14 @@ struct vk_op_soft_max_push_constants {
|
||||
|
||||
struct vk_op_argsort_push_constants {
|
||||
uint32_t ncols;
|
||||
uint32_t ncols_padded;
|
||||
uint32_t ncols_padded_log2;
|
||||
uint32_t nrows;
|
||||
int32_t order;
|
||||
uint32_t order;
|
||||
uint32_t outer_start;
|
||||
uint32_t outer_end;
|
||||
uint32_t inner_start;
|
||||
uint32_t inner_end;
|
||||
};
|
||||
|
||||
struct vk_op_im2col_push_constants {
|
||||
@@ -3895,7 +3906,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
|
||||
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
||||
if (i <= device->max_workgroup_size_log2 &&
|
||||
2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||
const uint32_t NCOLS_PADDED_LOG2 = i;
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
|
||||
}
|
||||
const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
|
||||
BLOCK_SIZE /= WG_UNROLL_FACTOR;
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
@@ -4296,6 +4315,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
||||
|
||||
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
|
||||
|
||||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||
|
||||
// Try to find a non-graphics compute queue and transfer-focused queues
|
||||
@@ -4435,6 +4456,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
device->shader_int64 = device_features2.features.shaderInt64;
|
||||
device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
||||
device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
|
||||
|
||||
if (device->subgroup_size_control) {
|
||||
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
||||
@@ -8359,19 +8381,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
case GGML_OP_ARGSORT:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
|
||||
return ctx->device->pipeline_topk_moe[idx][mode];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
return ctx->device->pipeline_argsort_f32[idx];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
@@ -8763,8 +8772,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
@@ -9891,16 +9899,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
}
|
||||
|
||||
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
int32_t * op_params = (int32_t *)dst->op_params;
|
||||
const uint32_t * op_params = (const uint32_t *)dst->op_params;
|
||||
|
||||
uint32_t ncols = src0->ne[0];
|
||||
uint32_t nrows = ggml_nrows(src0);
|
||||
|
||||
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
||||
ncols,
|
||||
nrows,
|
||||
op_params[0],
|
||||
});
|
||||
uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
|
||||
uint32_t ncolsp2 = 1 << ncols_pad_log2;
|
||||
|
||||
vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
|
||||
|
||||
// Pick the largest workgroup size <= ncolsp2
|
||||
uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
|
||||
|
||||
// Use the "small" argsort shader if the whole sort can be done by a single workgroup.
|
||||
bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
|
||||
ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
|
||||
|
||||
vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
|
||||
: ctx->device->pipeline_argsort_large_f32[pipeline_idx];
|
||||
|
||||
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
|
||||
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
||||
vk_subbuffer subbuf1 = dst_buf;
|
||||
|
||||
// Reserve space for ivec2 per element, with rows padded to a power of two
|
||||
if (!use_small) {
|
||||
const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
|
||||
|
||||
if (ctx->prealloc_size_x < x_sz) {
|
||||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (ctx->prealloc_x_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
|
||||
}
|
||||
|
||||
std::array<uint32_t, 3> elements;
|
||||
|
||||
elements[0] = ncolsp2;
|
||||
elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = 1;
|
||||
|
||||
// First dispatch initializes tmp_idx and does the first N passes where
|
||||
// there is only communication between threads in the same workgroup.
|
||||
{
|
||||
vk_op_argsort_push_constants pc2 = pc;
|
||||
pc2.outer_start = 0;
|
||||
pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
|
||||
pc2.inner_start = 0;
|
||||
pc2.inner_end = 100;
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
|
||||
}
|
||||
if (!use_small) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
// Loop over outer/inner passes, synchronizing between each pass.
|
||||
for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
|
||||
for (uint32_t inner = 0; inner < outer + 1; ++inner) {
|
||||
vk_op_argsort_push_constants pc2 = pc;
|
||||
pc2.outer_start = outer;
|
||||
pc2.outer_end = outer + 1;
|
||||
pc2.inner_start = inner;
|
||||
pc2.inner_end = inner + 1;
|
||||
// When the inner idx is large enough, there's only communication
|
||||
// within a workgroup. So the remaining inner iterations can all
|
||||
// run in the same dispatch.
|
||||
if (outer - inner < pipeline_idx) {
|
||||
pc2.inner_end = 100;
|
||||
inner = outer;
|
||||
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
|
||||
} else {
|
||||
// Smaller workgroup empirically seems to perform better
|
||||
pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
}
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -13721,7 +13802,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_LOG:
|
||||
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_ARGSORT:
|
||||
return op->ne[0] <= max_argsort_cols;
|
||||
{
|
||||
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
auto device = ggml_vk_get_device(ctx->device);
|
||||
// pipeline_argsort_large_f32 requires vulkan memory model.
|
||||
if (device->vulkan_memory_model) {
|
||||
return true;
|
||||
} else {
|
||||
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
||||
}
|
||||
}
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CONCAT:
|
||||
|
||||
Reference in New Issue
Block a user