support permuted, remove check s0/s10 (#19889)

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
This commit is contained in:
Neo Zhang
2026-02-26 10:27:20 +08:00
committed by GitHub
parent 3769fe6eb7
commit 2943210c1e
+21 -20
View File
@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3, int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13, int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3, /*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03, int s00, int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13, int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2); item_ct1.get_local_id(2);
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
for (int i0 = i0s; i0 < ne0; for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10; const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
} }
} }
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
int ne0, int ne1, int ne2, int ne3, int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13, int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3, /*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03, int s00, int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13, int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
dst_t * dst_row = dst + i_dst; dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10; const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
} }
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
queue_ptr stream) {
int nr0 = ne10 / ne0; int nr0 = ne10 / ne0;
int nr1 = ne11/ne1; int nr1 = ne11/ne1;
int nr2 = ne12/ne2; int nr2 = ne12/ne2;
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
cnb[3] *= cne[3]; cnb[3] *= cne[3];
}; };
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
if (nr[i] != 1) { if (nr[i] != 1) {
break; break;
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
size_t nb12 = cnb1[2]; size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3]; size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t); // size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t);
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128; const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL); int64_t hne0 = std::max(ne0/2LL, 1LL);
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>( k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
s03, s11, s12, s13, item_ct1); s03, s10, s11, s12, s13, item_ct1);
}); });
} }
} else { } else {
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13, ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13, s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,
item_ct1); item_ct1);
}); });
} }
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream); main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else { } else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));