|
|
|
@@ -944,6 +944,7 @@ struct vk_mat_mat_push_constants {
|
|
|
|
|
uint32_t M; uint32_t N; uint32_t K;
|
|
|
|
|
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
|
|
|
|
|
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
|
|
|
|
|
uint32_t base_work_group_z; uint32_t num_batches;
|
|
|
|
|
uint32_t k_split;
|
|
|
|
|
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
|
|
|
|
|
uint32_t padded_N;
|
|
|
|
@@ -963,6 +964,7 @@ struct vk_mat_vec_push_constants {
|
|
|
|
|
uint32_t batch_stride_b;
|
|
|
|
|
uint32_t batch_stride_d;
|
|
|
|
|
uint32_t fusion_flags;
|
|
|
|
|
uint32_t base_work_group_y;
|
|
|
|
|
uint32_t ne02;
|
|
|
|
|
uint32_t ne12;
|
|
|
|
|
uint32_t broadcast2;
|
|
|
|
@@ -6773,8 +6775,16 @@ static void ggml_vk_matmul(
|
|
|
|
|
uint32_t padded_n) {
|
|
|
|
|
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
|
|
|
|
if (split_k == 1) {
|
|
|
|
|
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
|
|
|
|
|
|
|
|
|
uint32_t base_work_group_z = 0;
|
|
|
|
|
while (base_work_group_z < batch) {
|
|
|
|
|
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
|
|
|
|
|
|
|
|
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
|
|
|
|
|
base_work_group_z += groups_z;
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -6788,9 +6798,17 @@ static void ggml_vk_matmul(
|
|
|
|
|
uint32_t k_split = CEIL_DIV(k, split_k);
|
|
|
|
|
k_split = ROUNDUP_POW2(k_split, 256);
|
|
|
|
|
|
|
|
|
|
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
|
|
|
// Make sure enough workgroups get assigned for split k to work
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
|
|
|
|
|
|
|
|
|
uint32_t base_work_group_z = 0;
|
|
|
|
|
while (base_work_group_z < batch) {
|
|
|
|
|
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
|
|
|
|
|
|
|
|
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
|
|
|
// Make sure enough workgroups get assigned for split k to work
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
|
|
|
|
|
base_work_group_z += groups_z;
|
|
|
|
|
}
|
|
|
|
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
|
|
|
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
|
|
|
|
@@ -7186,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Request descriptor sets
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
|
|
|
if (qx_needs_dequant) {
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
|
|
|
|
|
}
|
|
|
|
@@ -7484,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
|
|
if (quantize_y) {
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
|
|
|
|
}
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
|
|
|
@@ -7579,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
|
|
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute
|
|
|
|
|
const vk_mat_vec_push_constants pc = {
|
|
|
|
|
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
|
|
|
|
stride_batch_x, stride_batch_y, stride_batch_d,
|
|
|
|
|
fusion_flags,
|
|
|
|
|
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
|
|
|
|
};
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
|
|
|
|
{
|
|
|
|
|
d_X,
|
|
|
|
|
d_Y,
|
|
|
|
|
d_D,
|
|
|
|
|
d_F0,
|
|
|
|
|
d_F1,
|
|
|
|
|
},
|
|
|
|
|
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
|
|
|
|
|
|
|
|
|
|
uint32_t base_work_group_y = 0;
|
|
|
|
|
while (base_work_group_y < ne12 * ne13) {
|
|
|
|
|
|
|
|
|
|
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
|
|
|
const vk_mat_vec_push_constants pc = {
|
|
|
|
|
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
|
|
|
|
stride_batch_x, stride_batch_y, stride_batch_d,
|
|
|
|
|
fusion_flags, base_work_group_y,
|
|
|
|
|
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
|
|
|
|
};
|
|
|
|
|
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
|
|
|
|
{
|
|
|
|
|
d_X,
|
|
|
|
|
d_Y,
|
|
|
|
|
d_D,
|
|
|
|
|
d_F0,
|
|
|
|
|
d_F1,
|
|
|
|
|
},
|
|
|
|
|
pc, { groups_x, groups_y, groups_z });
|
|
|
|
|
base_work_group_y += groups_y;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_non_contig) {
|
|
|
|
|
ctx->prealloc_x_need_sync = true;
|
|
|
|
@@ -7832,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
|
|
src1->nb[2] <= src1->nb[1] &&
|
|
|
|
|
src1->nb[1] <= src1->nb[3] &&
|
|
|
|
|
src0->ne[3] == 1 &&
|
|
|
|
|
src1->ne[3] == 1) {
|
|
|
|
|
src1->ne[3] == 1 &&
|
|
|
|
|
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
|
|
|
|
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
|
|
|
|
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
|
|
|
|
|
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
|
|
|
|
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
|
|
|
|
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
|
|
|
|
|
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
|
|
|
|
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
|
|
|
|
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
|
|
|
|
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
|
|
|
|
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
|
|
|
|
// when ne12 and ne13 are one.
|
|
|
|
@@ -11560,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
|
|
|
|
if (split_k > 1) {
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
|
|
|
|
|
|
|
|
@@ -12069,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
|
|
// y[i] = i % k;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
|
|
|
|
if (split_k > 1) {
|
|
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
|
|
|
|
|
|
|
|
|