CANN: add new ops, optimize existing ops (#21204)

New operators:
- GGML_OP_SET: implement via aclnnInplaceCopy on target region
- GGML_OP_CUMSUM: implement via aclnnCumsum
- GGML_OP_FILL: implement via aclnnInplaceFillScalar
- GGML_OP_DIAG: implement via aclnnInplaceCopy on diagonal strides (see the
  sketch after this list)
- GGML_OP_TRI (lower/lower_diag/upper_diag/upper): implement via
  aclnnTril(-1/0) and aclnnTriu(0/1) with appropriate diagonal offsets
- GGML_OP_SOLVE_TRI: implement via aclnnTriangularSolve
- GGML_UNARY_OP_SOFTPLUS: implement via aclnnSoftplus
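
As a reading aid for the DIAG entry above, here is a host-side reference
(illustrative C++ only, not the CANN code; the helper name is made up) of what
"aclnnInplaceCopy on diagonal strides" computes: src is viewed as [N, n_batch]
and written onto the main diagonal of each zero-filled N x N slice of dst,
stepping N+1 elements per diagonal entry.

    #include <algorithm>
    #include <cstdint>

    // CPU reference mirroring the strided views used by the CANN DIAG op.
    static void diag_reference(const float * src, float * dst, int64_t N, int64_t n_batch) {
        std::fill(dst, dst + n_batch * N * N, 0.0f);            // InplaceFillScalar(0) step
        for (int64_t b = 0; b < n_batch; ++b) {
            for (int64_t i = 0; i < N; ++i) {
                // diagonal stride: N + 1 elements between consecutive diagonal entries
                dst[b * N * N + i * (N + 1)] = src[b * N + i];  // InplaceCopy step
            }
        }
    }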

Optimizations:
- GLU (SwiGLU/GeGLU/GeGLU_ERF): fuse with aclnnSwiGlu / aclnnGeGluV3 when
  applicable; GeGLU_QUICK keeps the generic gated path but moves into its own
  helper; fallback conditions are now checked inside each function rather
  than at the call site
- CROSS_ENTROPY_LOSS: replace 5-kernel sequence (LogSoftmax→Mul→
  ReduceSum×2→Muls) with single aclnnSoftmaxCrossEntropyWithLogits call
- L2_NORM: fix in-place ClampMin on norm result (was clamping wrong
  tensor); add eps clamping before division to avoid divide-by-zero
- PAD_REFLECT_1D: eliminate per-ne[3] loop; assert contiguity and call
  ReflectionPad1d once on a 3-D view with ne[2]*ne[3] folded into the batch
  dim; remove redundant nb copies
- GET_ROWS: replace IndexSelect with GatherV2 per batch slice; refactor
  helper into gather_batched lambda with batch loop inlined
- SET_ROWS: replace IndexCopy with InplaceIndexCopy per batch slice;
  refactor helper into scatter_batched lambda with batch loop inlined
- OUT_PROD: replace O(ne[3]*ne[2]*ne[1]) Ger+InplaceAdd loop with
  per-slice Matmul loop (src0 @ src1^T), see the sketch after this list;
  handles strided-broadcast batch dims where ne02/ne03 may differ from
  ne2/ne3
- backend memset_tensor: implement via aclrtMemset (was NULL)
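
The OUT_PROD rewrite above relies on the identity dst[i, j] = sum_k
src0[i, k] * src1[j, k] for every (i2, i3) slice, i.e. dst = src0 @ src1^T,
so the ne11 rank-1 Ger updates collapse into one Matmul per slice. A scalar
reference for one slice (illustrative only; plain C++ assuming ggml's
dim-0-contiguous slice layout):

    #include <cstdint>

    static void out_prod_slice_reference(const float * src0,  // [m, K] : src0[i + k * m]
                                         const float * src1,  // [n, K] : src1[j + k * n]
                                         float *       dst,   // [m, n] : dst[i + j * m]
                                         int64_t m, int64_t n, int64_t K) {
        for (int64_t j = 0; j < n; ++j) {
            for (int64_t i = 0; i < m; ++i) {
                float sum = 0.0f;
                for (int64_t k = 0; k < K; ++k) {
                    sum += src0[i + k * m] * src1[j + k * n];
                }
                dst[i + j * m] = sum;  // same value the per-slice Matmul(src1^T, src0) produces
            }
        }
    }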

Bug fixes:
- COUNT_EQUAL: use non-inplace EqTensor into a same-type temporary
  buffer instead of InplaceEqTensor, avoiding corruption of src0
- ACL graph cache (USE_ACL_GRAPH): restore node_type and src_type[]
  fields in ggml_graph_node_properties; has_matching_properties() was
  missing type checks, causing F16 and BF16 tensors (same nb[0]=2) to
  incorrectly share cached graphs and produce wrong results (ERR≈679);
  see the sketch below
- graph cache op_params matching: compare full GGML_MAX_OP_PARAMS
  bytes so that ops differing only in parameters are not incorrectly
  replayed from cache
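
A minimal sketch of the stricter node matching implied by the last two fixes.
The field names node_type, src_type[] and op_params follow the commit message,
but the struct layout and helper name below are assumptions for illustration,
not code copied from the patch:

    #include <cstring>
    #include "ggml.h"

    // Hypothetical, trimmed-down version of the cached node properties.
    struct cann_node_props_sketch {
        enum ggml_op   node_op;
        enum ggml_type node_type;                  // restored: F16 vs BF16 both have nb[0] == 2
        enum ggml_type src_type[GGML_MAX_SRC];     // restored
        char           op_params[GGML_MAX_OP_PARAMS];  // compared in full, not just a prefix
    };

    static bool props_match_sketch(const struct ggml_tensor * node,
                                   const struct cann_node_props_sketch * p) {
        if (node->op != p->node_op || node->type != p->node_type) {
            return false;
        }
        // ops differing only in parameters must not replay a cached graph
        if (memcmp(node->op_params, p->op_params, GGML_MAX_OP_PARAMS) != 0) {
            return false;
        }
        for (int i = 0; i < GGML_MAX_SRC; ++i) {
            if (node->src[i] != NULL && node->src[i]->type != p->src_type[i]) {
                return false;
            }
        }
        return true;  // shape/stride/address checks from the real helper omitted here
    }
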
Author: hipudding
Date: 2026-04-28 14:27:22 +08:00
Committed by: GitHub
Parent: 14e733e36f
Commit: c3e08f4700
3 changed files with 624 additions and 258 deletions
+512 -248
@@ -25,6 +25,7 @@
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml.h" #include "ggml.h"
#include <aclnnop/aclnn_add.h> #include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_add_rms_norm.h> #include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h> #include <aclnnop/aclnn_addcdiv.h>
@@ -45,7 +46,9 @@
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_gather_v2.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_scatter.h>
#include <aclnnop/aclnn_gt_scalar.h>
#include <aclnnop/aclnn_im2col.h>
#include <aclnnop/aclnn_index_copy.h>
@@ -62,6 +65,7 @@
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
#include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_reflection_pad1d.h>
#include <aclnnop/aclnn_repeat.h>
@@ -69,11 +73,15 @@
#include <aclnnop/aclnn_rms_norm.h>
#include <aclnnop/aclnn_roll.h>
#include <aclnnop/aclnn_softmax.h>
#include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h>
#include <aclnnop/aclnn_sub.h>
#include <aclnnop/aclnn_sum.h>
#include <aclnnop/aclnn_threshold.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triangular_solve.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_logical_not.h>
#include <aclnnop/aclnn_masked_fill_scalar.h>
#include <aclnnop/aclnn_upsample_nearest_2d.h>
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
#include <aclnnop/aclnn_zero.h>
@@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());
}
// Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies
// SiLU to left half, multiplies by right half.
//
// Falls back to the generic two-kernel path when src[1] != nullptr (two
// independent halves) or swapped != 0 (reversed activation order), as
// aclnnSwiGlu only handles the single interleaved tensor in standard order.
//
// CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even.
// aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd.
// We use a 3D view (1+3=4, even) to satisfy this constraint while preserving
// correct split semantics along the innermost (ne[0]) dimension.
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst);
};
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
if (dst->src[1] != nullptr || swapped != 0) {
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
return;
}
// aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise.
if (dst->src[0]->ne[0] % 2 != 0) {
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
return;
}
ggml_tensor * src0 = dst->src[0];
size_t elem_size = ggml_element_size(src0);
// src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3]
// CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last).
int64_t ne0_x2 = src0->ne[0];
int64_t ne1 = src0->ne[1];
int64_t ne23 = src0->ne[2] * src0->ne[3];
int64_t src3d_ne[] = { ne0_x2, ne1, ne23 };
size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] };
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
elem_size, src3d_ne, src3d_nb, 3);
// dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3]
int64_t ne0 = dst->ne[0];
int64_t dst3d_ne[] = { ne0, ne1, ne23 };
size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] };
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
elem_size, dst3d_ne, dst3d_nb, 3);
// CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0.
GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get());
}
// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim),
// activates the LEFT half with GELU, multiplies by right half.
// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention.
// outGelu is a required-but-discard output buffer.
//
// Falls back to the generic two-kernel path when src[1] != nullptr (two
// independent halves) or swapped != 0 (reversed activation order), as
// aclnnGeGluV3 only handles the single interleaved tensor in standard order.
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) {
auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst);
};
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
if (dst->src[1] != nullptr || swapped != 0) {
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
return;
}
// aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise.
if (dst->src[0]->ne[0] % 2 != 0) {
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
return;
}
ggml_tensor * src0 = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
// Allocate a temporary buffer for the required outGelu output (same shape as dst).
// Build contiguous strides since the pool allocation is a fresh buffer.
size_t elem_size = ggml_element_size(dst);
int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
size_t nb[GGML_MAX_DIMS];
nb[0] = elem_size;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
nb[i] = nb[i - 1] * ne[i - 1];
}
size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1];
ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size);
acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor(
gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS);
// V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention.
// GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor).
GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true,
acl_dst.get(), acl_gelu_out.get());
}
/**
 * @brief Repeats elements of a tensor along each dimension according to the
 * specified repeat array.
@@ -445,28 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
void * buffer = temp_buffer_allocator.get();
- int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
- size_t div_nb[GGML_MAX_DIMS];
- div_nb[0] = sizeof(float);
+ int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
+ size_t norm_nb[GGML_MAX_DIMS];
+ norm_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
- div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
+ norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1];
}
- acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
+ acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
std::vector<int64_t> norm_dims = { 3 };
acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());
float p_value = 2.0f;
acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);
- GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());
+ GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get());
- // Clamp norm to at least eps: scale = 1/fmaxf(norm, eps)
- acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
- float flt_max = FLT_MAX;
- acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT);
- GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get());
- GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());
+ ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool());
+ acl_tensor_ptr acl_clamped;
+ if (eps > 0.0f) {
+ void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes);
+ acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
+ acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
+ GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get());
+ }
+ aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get();
+ GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get());
}
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -482,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
logits_nb[1] = logits_nb[0] * logits_ne[0];
acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
- size_t log_softmax_type_size = sizeof(float);
- int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
- ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
- void * log_softmax_buffer = log_softmax_allocator.get();
- int64_t log_softmax_ne[] = { nc, nr };
- size_t log_softmax_nb[2];
- log_softmax_nb[0] = log_softmax_type_size;
- log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
- acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,
- log_softmax_ne, log_softmax_nb, 2);
- GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());
int64_t labels_ne[] = { nc, nr };
size_t labels_nb[2];
labels_nb[0] = ggml_type_size(src1->type);
labels_nb[1] = labels_nb[0] * labels_ne[0];
acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
- size_t mul_type_size = sizeof(float);
- int64_t mul_n_bytes = nr * nc * mul_type_size;
- ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
- void * mul_buffer = mul_allocator.get();
- int64_t mul_ne[] = { nc, nr };
- size_t mul_nb[2];
- mul_nb[0] = mul_type_size;
- mul_nb[1] = mul_nb[0] * mul_ne[0];
- acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
- GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());
- size_t sum_per_sample_type_size = sizeof(float);
- int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
- ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
- void * sum_per_sample_buffer = sum_per_sample_allocator.get();
- int64_t sum_per_sample_ne[] = { nr };
- size_t sum_per_sample_nb[1];
- sum_per_sample_nb[0] = sum_per_sample_type_size;
- acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(
- sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
- std::vector<int64_t> sum_dims = { 1 };
- acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());
- bool keep_dims = false;
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,
- acl_sum_per_sample.get());
+ size_t loss_per_sample_type_size = sizeof(float);
+ int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size;
+ ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes);
+ void * loss_per_sample_buffer = loss_per_sample_allocator.get();
+ int64_t loss_per_sample_ne[] = { nr };
+ size_t loss_per_sample_nb[1];
+ loss_per_sample_nb[0] = loss_per_sample_type_size;
+ acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor(
+ loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1);
+ size_t backprop_n_bytes = nr * nc * sizeof(float);
+ ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes);
+ void * backprop_buffer = backprop_allocator.get();
+ acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
+ GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(),
+ acl_loss_per_sample.get(), acl_backprop.get());
size_t total_sum_type_size = sizeof(float);
int64_t total_sum_n_bytes = 1 * total_sum_type_size;
@@ -547,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
std::vector<int64_t> total_sum_dims = { 0 };
acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());
+ bool keep_dims = false;
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
acl_total_sum.get());
- float value = -1.0f / static_cast<float>(nr);
+ float value = 1.0f / static_cast<float>(nr);
acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
acl_tensor_ptr acl_dst =
ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
@@ -589,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
acl_mean_out.get(), acl_rstd_out.get());
}
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
size_t offset = ((int32_t *) dst->op_params)[3];
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };
// Create a view of dst at the target offset with src1's dimensions
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);
if (!inplace) {
// First copy src0 to dst entirely
size_t cpy_size = ggml_nbytes(dst);
ACL_CHECK(
aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
}
// Copy src1 into the target region of dst
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get());
}
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
@@ -652,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
}
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
// GGML cumsum operates along dim 0 (innermost / ne[0]).
// ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0],
// so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor).
GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3,
ggml_cann_type_mapping(dst->type), acl_dst.get());
}
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
// mOut: triangular copy of A (required output), same shape as A.
const size_t a_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
acl_tensor_ptr acl_m = ggml_cann_create_tensor(
m_alloc.get(), ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
// Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
acl_b.get(), acl_a.get(), false, false, false,
acl_x.get(), acl_m.get());
}
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
GGML_ASSERT(src->ne[1] == 1);
const int64_t N = src->ne[0];
const int64_t n_batch = src->ne[2] * src->ne[3];
const size_t nb_f32 = sizeof(float);
// Fill dst with zeros.
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
{
float zero = 0.0f;
acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get());
}
// Copy src vector onto the diagonal of dst via strided views.
// src viewed as [N, n_batch], contiguous strides.
int64_t ne_vec[2] = { N, n_batch };
size_t nb_src_vec[2] = { nb_f32, N * nb_f32 };
// dst diagonal view: stride (N+1)*4 steps along the diagonal.
size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 };
acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2);
acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get());
}
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float c = ggml_get_op_params_f32(dst, 0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
}
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
const int64_t S = src->ne[0];
const int64_t n_batch = src->ne[2] * src->ne[3];
const size_t nb_f32 = sizeof(float);
int64_t ne3d[3] = { S, S, n_batch };
size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 };
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
switch (ttype) {
case GGML_TRI_TYPE_LOWER:
// Tril(-1): preserve row > col (strict lower), zero upper + diagonal.
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get());
break;
case GGML_TRI_TYPE_UPPER_DIAG:
// Triu(0): preserve row <= col (upper + diagonal), zero strict lower.
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get());
break;
case GGML_TRI_TYPE_UPPER:
// Triu(1): preserve row < col (strict upper), zero lower + diagonal.
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get());
break;
case GGML_TRI_TYPE_LOWER_DIAG:
// Tril(0): preserve row >= col (lower + diagonal), zero strict upper.
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get());
break;
default:
GGML_ABORT("unsupported tri type");
}
}
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
@@ -1695,152 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());
}
/**
* @brief Performs index select operation on a 4D tensor using the CANN backend.
*
* This function applies the `IndexSelect` operation along a specific dimension
* of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
* It iterates over the last two dimensions of the source tensor, creates the corresponding
* CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
* operation for each slice.
*
* @param ctx The context for CANN backend operations.
* @param src_buffer The source buffer containing the 4D input tensor data.
* @param src_ne The dimensions of the source tensor.
* @param src_nb The strides (byte offsets) of the source tensor.
* @param dst_buffer The destination buffer where the output tensor data will be written.
* @param dst_ne The dimensions of the destination tensor.
* @param dst_nb The strides (byte offsets) of the destination tensor.
* @param index The index tensor specifying the indices to select from the source tensor.
* @param type The data type of the source and destination tensors.
*/
static void aclnn_index_select_4d(ggml_backend_cann_context & ctx,
void * src_buffer,
int64_t * src_ne,
size_t * src_nb,
void * dst_buffer,
int64_t * dst_ne,
size_t * dst_nb,
ggml_tensor * index,
ggml_type type) {
for (int64_t i = 0; i < src_ne[3]; i++) {
for (int64_t j = 0; j < src_ne[2]; j++) {
// src
acl_tensor_ptr acl_src_tensor =
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
// index
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
// out
acl_tensor_ptr acl_out =
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());
}
}
}
/**
* @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
*
* This function applies the `IndexCopy` operation along a specific dimension of the
* destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
* to positions specified by the index tensor (`index`).
* It iterates over the last two dimensions of the tensors, creates the corresponding
* CANN tensors for source, index, and destination slices, and performs the index copy
* operation for each slice.
*
* @param ctx The context for CANN backend operations.
* @param src_buffer The source buffer containing the 4D input tensor data to be copied.
* @param src_ne The dimensions of the source tensor.
* @param src_nb The strides (byte offsets) of the source tensor.
* @param dst_buffer The destination buffer where values will be copied to.
* @param dst_ne The dimensions of the destination tensor.
* @param dst_nb The strides (byte offsets) of the destination tensor.
* @param index The index tensor specifying target positions in the destination tensor.
* @param type The data type of the source and destination tensors.
*/
static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,
void * src_buffer,
int64_t * src_ne,
size_t * src_nb,
void * dst_buffer,
int64_t * dst_ne,
size_t * dst_nb,
ggml_tensor * index,
ggml_type type) {
for (int64_t i = 0; i < src_ne[3]; i++) {
for (int64_t j = 0; j < src_ne[2]; j++) {
// src
acl_tensor_ptr acl_src_tensor =
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
// index
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
// out
acl_tensor_ptr acl_out =
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());
}
}
}
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
- ggml_tensor * src0 = dst->src[0]; // src
+ ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
// n_idx: number of row indices per (i2, i3) batch slice.
// ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1.
const int64_t n_idx = src1->ne[0];
// Gather all (i2, i3) batch slices from src into dst.
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0].
// GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis).
// nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
// nb[2..3] for computing per-batch-slice base pointer offsets).
auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
const size_t * nb) {
int64_t src_ne[2] = { src0->ne[0], src0->ne[1] };
size_t src_nb_2d[2] = { nb[0], nb[1] };
int64_t dst_ne[2] = { src0->ne[0], n_idx };
size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] };
int64_t idx_ne[1] = { n_idx };
size_t idx_nb[1] = { (size_t)ggml_element_size(src1) };
for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) {
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
(char *)src_base + i3 * nb[3] + i2 * nb[2],
acl_type, type_size, src_ne, src_nb_2d, 2);
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
(char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1],
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
idx_ne, idx_nb, 1);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
acl_type, type_size, dst_ne, dst_nb_2d, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get());
}
}
};
switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if (src0->type == dst->type) {
- aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,
- dst->type);
+ gather_batched(src0->data,
+ ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
+ src0->nb);
} else {
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
- void * src_trans_buffer = src_buffer_allocator.get();
- size_t src_trans_nb[GGML_MAX_DIMS];
- src_trans_nb[0] = dst->nb[0];
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
- }
- acl_tensor_ptr src_trans_tensor =
- ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
- aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
- dst->type);
+ // Cast src0 to dst type, then gather.
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
+ ggml_nelements(src0) * ggml_element_size(dst));
+ size_t src_cast_nb[GGML_MAX_DIMS];
+ src_cast_nb[0] = ggml_type_size(dst->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
+ }
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
+ gather_batched(src_cast_allocator.get(),
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ src_cast_nb);
}
break;
case GGML_TYPE_Q8_0:
{
- // add 1 dim for bcast mul.
+ // Dequantize Q8_0 to dst type, then gather.
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;
- int64_t scale_offset = 0;
- // [3,4,5,64] -> [3,4,5,2,32]
weight_ne[0] = QK8_0;
weight_ne[1] = src0->ne[0] / QK8_0;
weight_nb[0] = sizeof(int8_t);
weight_nb[1] = weight_nb[0] * weight_ne[0];
for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
weight_ne[i] = src0->ne[i - 1];
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
}
- // [3,4,5,64] -> [3,4,5,2,1]
scale_ne[0] = 1;
scale_ne[1] = src0->ne[0] / QK8_0;
scale_nb[0] = sizeof(uint16_t);
@@ -1849,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
scale_ne[i] = src0->ne[i - 1];
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
}
- // [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
- scale_offset = ggml_nelements(src0) * sizeof(int8_t);
+ const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t);
- ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),
+ ggml_cann_pool_alloc dequant_allocator(ctx.pool(),
ggml_nelements(src0) * ggml_type_size(dst->type));
- acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
+ acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
weight_ne, weight_nb, GGML_MAX_DIMS + 1);
- acl_tensor_ptr acl_scale_tensor =
- ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
- GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
+ acl_tensor_ptr acl_scale = ggml_cann_create_tensor(
+ src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
+ GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
- acl_tensor_ptr dequant_tensor =
- ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
+ acl_tensor_ptr acl_dequant = ggml_cann_create_tensor(
+ dequant_allocator.get(), ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
- aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());
+ aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get());
- dequant_nb[0] = ggml_type_size(dst->type);
+ // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides.
dequant_ne = src0->ne;
+ dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
}
- aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,
- dst->nb, src1, dst->type);
+ gather_batched(dequant_allocator.get(),
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ dequant_nb);
break;
}
default:
@@ -1883,31 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
}
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
- ggml_tensor * src0 = dst->src[0]; // src
+ ggml_tensor * src0 = dst->src[0]; // source values
- ggml_tensor * src1 = dst->src[1]; // index
+ ggml_tensor * src1 = dst->src[1]; // row indices
// n_idx: number of source rows to scatter per batch slice.
// ggml guarantees: src0->ne[1] == src1->ne[0].
const int64_t n_idx = src1->ne[0];
// Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index.
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst.
// InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis).
// src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
// nb[2..3] for computing per-batch-slice base pointer offsets).
auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
const size_t * src_nb) {
int64_t d_ne[2] = { dst->ne[0], dst->ne[1] };
size_t d_nb[2] = { dst->nb[0], dst->nb[1] };
int64_t s_ne[2] = { dst->ne[0], n_idx };
size_t s_nb_2d[2] = { src_nb[0], src_nb[1] };
int64_t i_ne[1] = { n_idx };
size_t i_nb[1] = { (size_t)ggml_element_size(src1) };
for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) {
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
acl_type, type_size, d_ne, d_nb, 2);
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
(char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1],
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
i_ne, i_nb, 1);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
(char *)src_base + i3 * src_nb[3] + i2 * src_nb[2],
acl_type, type_size, s_ne, s_nb_2d, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get());
}
}
};
switch (dst->type) {
case GGML_TYPE_F32:
- {
- aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);
- break;
- }
+ scatter_batched(src0->data,
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ src0->nb);
+ break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
- void * src_trans_buffer = src_buffer_allocator.get();
- size_t src_trans_nb[GGML_MAX_DIMS];
- src_trans_nb[0] = sizeof(uint16_t);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
- }
- acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
- src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
- aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
- dst->type);
+ // Cast src0 (F32) to dst type first.
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
+ ggml_nelements(src0) * ggml_type_size(dst->type));
+ size_t src_cast_nb[GGML_MAX_DIMS];
+ src_cast_nb[0] = ggml_type_size(dst->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
+ }
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
+ scatter_batched(src_cast_allocator.get(),
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ src_cast_nb);
break;
}
default:
@@ -3268,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst
int64_t paddingsArray[2] = { opts[0], opts[1] };
acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2);
- for (int64_t i = 0; i < src0->ne[3]; i++) {
- acl_tensor_ptr acl_src =
- ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),
- ggml_element_size(src0), src0->ne, src0->nb, 3);
- acl_tensor_ptr acl_dst =
- ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),
- ggml_element_size(dst), dst->ne, dst->nb, 3);
- GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
- }
+ // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3
+ // is contiguous with respect to dim2 in both src and dst.
+ GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]);
+ GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]);
+ int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] };
+ int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] };
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
+ ggml_element_size(src0), src_ne_3d, src0->nb, 3);
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
+ ggml_element_size(dst), dst_ne_3d, dst->nb, 3);
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
}
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
// Write element-wise equality (0 or 1) into a temporary buffer to avoid
// modifying src0 in-place. Use the same type as src0 so ReduceSum can
// consume it directly without a type cast.
ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0));
size_t eq_nb[GGML_MAX_DIMS];
eq_nb[0] = ggml_element_size(src0);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr acl_eq = ggml_cann_create_tensor(
eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
src0->ne, eq_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1);
+ GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get());
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());
- ggml_cann_sum(ctx, dst);
+ // Sum the 0/1 values into dst.
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
+ int64_t dims[4] = { 0, 1, 2, 3 };
+ acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4);
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true,
+ ggml_cann_type_mapping(dst->type), acl_dst.get());
}
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3306,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());
}
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
float beta_val = 1.0f;
float threshold_val = 20.0f;
acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT);
acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get());
}
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst);
}
/**
 * @brief Performs expert-specific matrix multiplication (MoE) with
 * floating-point precision using the CANN backend.
@@ -3892,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
}
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
- ggml_tensor * src0 = dst->src[0]; // weight
- ggml_tensor * src1 = dst->src[1]; // input
+ ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03]
+ ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13]
GGML_TENSOR_BINARY_OP_LOCALS
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
+ // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T.
+ //
+ // ggml_cann_create_tensor reverses dimension order, so ACL sees:
// acl_src0 slice: ggml[m,K] -> ACL[K,m]
// acl_src1 slice: ggml[n,K] -> ACL[K,n]
// acl_dst slice: ggml[m,n] -> ACL[n,m]
//
// Build a transposed view of src1 by swapping ne[0]/ne[1]:
// src1_t: ggml[K,n] (swapped strides) -> ACL[n,K]
//
// Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓
//
// The outer batch loop is kept because src0 may have fewer batch slices than
// dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported
// by standard CANN Matmul broadcasting.
const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type);
const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type);
const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type);
const size_t src0_type_sz = ggml_type_size(src0->type);
const size_t src1_type_sz = ggml_type_size(src1->type);
const size_t dst_type_sz = ggml_type_size(dst->type);
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
- const int64_t i12 = i2;
- const int64_t i13 = i3;
- acl_tensor_ptr accumulator =
- ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
- // The outer product needs to be accumulated in this dimension.
- for (int64_t i1 = 0; i1 < ne11; i1++) {
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(
- (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
- ggml_type_size(src0->type), src1->ne, src1->nb, 1);
- acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
- (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
- ggml_type_size(src0->type), src0->ne, src0->nb, 1);
- ggml_cann_pool_alloc output_allocator(ctx.pool());
- void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
- acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
- GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
- float alpha_value = 1.0f;
- aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
- }
+ // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m]
+ int64_t src0_ne[2] = { ne00, ne01 };
+ size_t src0_nb[2] = { nb00, nb01 };
+ acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor(
+ (char *) src0->data + i02 * nb02 + i03 * nb03,
+ src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2);
+ // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K]
+ int64_t src1_t_ne[2] = { ne11, ne10 };
+ size_t src1_t_nb[2] = { nb11, nb10 };
+ acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor(
+ (char *) src1->data + i2 * nb12 + i3 * nb13,
+ src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2);
+ // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m]
+ int64_t dst_ne[2] = { ne0, ne1 };
+ size_t dst_nb[2] = { nb0, nb1 };
+ acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor(
+ (char *) dst->data + i2 * nb2 + i3 * nb3,
+ dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2);
+ // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓
+ GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,
+ acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1);
}
}
}
@@ -4170,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
}
}
}
+56
@@ -32,6 +32,9 @@
#include <aclnnop/aclnn_cat.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_cumsum.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_gelu.h>
#include <aclnnop/aclnn_gelu_v2.h>
@@ -47,6 +50,9 @@
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_softplus.h>
#include <aclnnop/aclnn_swi_glu.h>
#include <aclnnop/aclnn_geglu.h>
#include <aclnnop/aclnn_slice.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
@@ -69,6 +75,9 @@
 */
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate);
/**
 * @brief Applies the Leaky ReLU activation function to a tensor using the CANN
 * backend.
@@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the cumulative sum of a ggml tensor along dim 0 using the
* CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`.
*/
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes a triangular mask (tril/triu) of a square ggml tensor
* using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_TRI`.
*/
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Solves a triangular linear system AX=B using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`.
*/
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Creates a diagonal matrix from a vector using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_DIAG`.
*/
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Fills a tensor with a constant scalar value using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_FILL`.
*/
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
 * @brief Upsamples a ggml tensor using nearest neighbor interpolation using
 * the CANN backend.
@@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor *
// @see ggml_cann_dup.
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
// @see ggml_cann_acc, but copies src1 into dst instead of adding.
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
 * @brief Computes the softmax activation with optional masking.
 *
@@ -813,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 * dst->op is expected to be `GGML_OP_STEP`.
 */
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
 * @brief Performs the Flash Attention extended operator using the CANN backend.
+56 -10
@@ -1428,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
return false;
}
/**
* @brief Set a region of a tensor's device memory to a specified value.
*
* @param buffer The CANN buffer containing the tensor.
* @param tensor Pointer to the tensor whose memory will be set.
* @param value The value to which each byte in the region will be set.
* @param offset Byte offset within the tensor's data to start setting.
* @param size Number of bytes to set.
*/
static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
ggml_cann_set_device(ctx->device);
ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size));
}
/**
 * @brief Clear a CANN buffer by setting all its memory to a specified value.
 *
@@ -1454,7 +1470,7 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
/* .get_base = */ ggml_backend_cann_buffer_get_base,
/* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
@@ -1835,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_UNARY_OP_STEP:
ggml_cann_step(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cann_softplus(ctx, dst);
break;
default:
return false;
}
@@ -1845,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
GGML_CANN_CALL_OP_UNARY_GATED(Relu);
break;
case GGML_GLU_OP_GEGLU:
+ ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh
+ break;
case GGML_GLU_OP_GEGLU_ERF:
- // aclnnGelu internally uses the erf-based approximation.
- GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
+ ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf
break;
case GGML_GLU_OP_SWIGLU:
- GGML_CANN_CALL_OP_UNARY_GATED(Silu);
+ ggml_cann_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
- {
- auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
- };
- ggml_cann_op_unary_gated(lambda, ctx, dst);
- }
+ ggml_cann_geglu_quick(ctx, dst);
break;
default:
return false;
@@ -1920,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_CPY:
ggml_cann_cpy(ctx, dst);
break;
case GGML_OP_SET:
ggml_cann_set(ctx, dst);
break;
case GGML_OP_CONT:
ggml_cann_dup(ctx, dst);
break;
@@ -1989,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cann_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cann_tri(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cann_fill(ctx, dst);
break;
case GGML_OP_DIAG:
ggml_cann_diag(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cann_solve_tri(ctx, dst);
break;
default:
return false;
}
@@ -2324,6 +2357,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
if (use_cann_graph) {
// If no matching graph is found, the graph needs to be recaptured.
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
if (graph_capture_required) {
// If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
@@ -2382,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_SOFTPLUS:
return true;
default:
return false;
@@ -2572,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
case GGML_OP_SET:
case GGML_OP_GROUP_NORM:
return true;
case GGML_OP_PAD:
@@ -2649,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_FILL:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIAG:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOLVE_TRI:
return op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}