diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 955903418..0d9b5e289 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2693,6 +2693,39 @@ static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X + + if (!src0 || !src1) { + return false; + } + + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (src0->ne[0] != src0->ne[1]) { + return false; + } + + if (src0->ne[1] != src1->ne[1]) { + return false; + } + + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; + } + + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2731,7 +2764,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; - + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3277,6 +3310,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_diag(sess, op); break; + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index b1ae60a9c..8bd528478 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -36,6 +36,7 @@ add_library(${HTP_LIB} SHARED cumsum-ops.c fill-ops.c diag-ops.c + solve-tri-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index f8c89211a..d704fedee 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -103,5 +103,6 @@ int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 56d7b398d..4397245c5 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -82,7 +82,7 @@ enum htp_op_code { HTP_OP_CUMSUM, HTP_OP_FILL, HTP_OP_DIAG, - + HTP_OP_SOLVE_TRI, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ed6026e76..d0926dedd 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + #else static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) @@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_vmpy_VhfVhf(a, b); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + #endif // __HVX_ARCH__ < 79 #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 088434a63..db277a25e 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -573,6 +573,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_DIAG: return op_diag(octx); + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 000000000..ae8e1a504 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. +static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +}