CANN: support gated linear attn (#18653)

* CANN: support gated linear attn

This change adds support for the GGML_OP_GATED_LINEAR_ATTN operator.
The feature was implemented by YushengZhao. Because the previous
submission was based on an outdated codebase, this PR was rebased to
merge.

Co-authored-by: YushengZhao <yusheng.chao@outlook.com>
Co-authored-by: hipudding <huafengchun@gmail.com>

* CANN: optimize OP gla

Optimize gla for high preformance

* Remove unused comments

---------

Co-authored-by: 赵禹昇 <2501112001@cninfer02.localdomain>
Co-authored-by: YushengZhao <yusheng.chao@outlook.com>
This commit is contained in:
hipudding
2026-01-16 16:18:49 +08:00
committed by GitHub
parent 785a710085
commit baa4ba0aec
3 changed files with 186 additions and 161 deletions
+107 -41
View File
@@ -58,6 +58,7 @@
#include <aclnnop/aclnn_mean.h> #include <aclnnop/aclnn_mean.h>
#include <aclnnop/aclnn_mm.h> #include <aclnnop/aclnn_mm.h>
#include <aclnnop/aclnn_mul.h> #include <aclnnop/aclnn_mul.h>
#include <aclnnop/aclnn_mv.h>
#include <aclnnop/aclnn_permute.h> #include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow.h> #include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h> #include <aclnnop/aclnn_pow_tensor_tensor.h>
@@ -2346,12 +2347,13 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
if (ctx.rope_cache.yarn_ramp_cache != nullptr) { if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache)); ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
} }
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
ACL_MEM_MALLOC_HUGE_FIRST));
// -rope_yarn_ramp // -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1; // return MIN(1, MAX(0, y)) - 1;
acl_yarn_ramp_tensor = acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1; float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
} else { } else {
acl_yarn_ramp_tensor = acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); theta_scale_ne, theta_scale_nb, 1);
} }
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) { if (ext_factor != 0) {
@@ -2991,12 +2993,12 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get()); GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
} }
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1]; ggml_tensor * src1 = dst->src[1];
// stride // stride
int64_t s0 = ((const int32_t*)(dst->op_params))[0]; int64_t s0 = ((const int32_t *) (dst->op_params))[0];
acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
@@ -3017,9 +3019,9 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
int64_t strideVal[1]; int64_t strideVal[1];
strideVal[0] = s0; strideVal[0] = s0;
acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
int64_t paddingVal[] = {0}; int64_t paddingVal[] = { 0 };
acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
int64_t dilationVal[] = {1}; int64_t dilationVal[] = { 1 };
acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
bool transposed = true; bool transposed = true;
int64_t groups = 1; int64_t groups = 1;
@@ -3049,19 +3051,18 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
// set zero to destination // set zero to destination
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
for(int k = 0; k < part_num; k++){ for (int k = 0; k < part_num; k++) {
// create part kernel tensor and slice from big kernel // create part kernel tensor and slice from big kernel
slice_start = max_kernel_size * k; slice_start = max_kernel_size * k;
if(k == part_num - 1){ if (k == part_num - 1) {
slice_end = kernel_size; slice_end = kernel_size;
interval = kernel_size - max_kernel_size * k; interval = kernel_size - max_kernel_size * k;
}else{ } else {
slice_end = max_kernel_size * (k+1); slice_end = max_kernel_size * (k + 1);
} }
int64_t part_ne[4]; int64_t part_ne[4];
for(int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
part_ne[i] = *(src0->ne + i); part_ne[i] = *(src0->ne + i);
} }
part_ne[0] = interval; part_ne[0] = interval;
@@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
ggml_cann_pool_alloc part_kernel_allocator; ggml_cann_pool_alloc part_kernel_allocator;
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]); part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
void* part_kernel_buf = part_kernel_allocator.get(); void * part_kernel_buf = part_kernel_allocator.get();
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL); part_ne, part_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get()); GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
part_kernel.get());
// create the part conv result tensor // create the part conv result tensor
int64_t part_dst_ne[4]; int64_t part_dst_ne[4];
for(int i = 0; i < 4; i++){ for (int i = 0; i < 4; i++) {
part_dst_ne[i] = *(dst->ne + i); part_dst_ne[i] = *(dst->ne + i);
} }
part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1; part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,7 +3097,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
} }
ggml_cann_pool_alloc part_dst_allocator; ggml_cann_pool_alloc part_dst_allocator;
part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]); part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
void* part_dst_buf = part_dst_allocator.get(); void * part_dst_buf = part_dst_allocator.get();
acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst), acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
@@ -3103,7 +3105,8 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
// compute part conv transpose 1d // compute part conv transpose 1d
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(), GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType); padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
cubeMathType);
// compute the position of part result in final result // compute the position of part result in final result
int64_t global_start = slice_start; int64_t global_start = slice_start;
@@ -3112,7 +3115,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
left_pad_len = global_start; left_pad_len = global_start;
right_pad_len = dst_len - global_end; right_pad_len = dst_len - global_end;
std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len}; std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
acl_scalar_ptr pad_value = nullptr; acl_scalar_ptr pad_value = nullptr;
@@ -3120,7 +3123,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
int64_t conv_result_ne[4]; int64_t conv_result_ne[4];
for(int i = 0; i < 4; i++){ for (int i = 0; i < 4; i++) {
conv_result_ne[i] = *(dst->ne + i); conv_result_ne[i] = *(dst->ne + i);
} }
@@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
ggml_cann_pool_alloc conv_result_allocator; ggml_cann_pool_alloc conv_result_allocator;
conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]); conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
void* conv_result_buf = conv_result_allocator.get(); void * conv_result_buf = conv_result_allocator.get();
acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst), acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get()); GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
conv_result.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
} }
} }
@@ -3749,8 +3753,8 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// so we can reuse nb0 and nb1, and set nb2 = nb1. // so we can reuse nb0 and nb1, and set nb2 = nb1.
size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
acl_tensor_ptr acl_w = ggml_cann_create_tensor( acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
// 3) Output: dst is { d_inner, n_t, n_s } (CLN) // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
// //
@@ -3769,10 +3773,11 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// nb_y[1] = sizeof(float); // step in C // nb_y[1] = sizeof(float); // step in C
// nb_y[2] = nr * n_t * sizeof(float); // step in N // nb_y[2] = nr * n_t * sizeof(float); // step in N
int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t] size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
dst->nb[3] }; // [nr, 1, nr * n_t]
acl_tensor_ptr acl_y = ggml_cann_create_tensor( acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
// --- Conv1d parameters: depthwise, stride 1, no padding ("valid") --- // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
int64_t strideVal[1] = { 1 }; int64_t strideVal[1] = { 1 };
@@ -3791,22 +3796,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
cubeMathType = 1; cubeMathType = 1;
#endif #endif
GGML_CANN_CALL_ACLNN_OP(ctx, GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
Convolution,
acl_x.get(), // input: N, C, L_in = ncs acl_x.get(), // input: N, C, L_in = ncs
acl_w.get(), // weight: [C, 1, K] with groups=nr acl_w.get(), // weight: [C, 1, K] with groups=nr
nullptr, // bias nullptr, // bias
stride.get(), stride.get(), padding.get(), dilation.get(), transposed,
padding.get(),
dilation.get(),
transposed,
padding.get(), // output padding (unused for non-transposed) padding.get(), // output padding (unused for non-transposed)
groups, groups, acl_y.get(), cubeMathType);
acl_y.get(),
cubeMathType);
} }
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
ggml_tensor * add_node, ggml_tensor * add_node,
ggml_tensor * rms_norm_node) { ggml_tensor * rms_norm_node) {
@@ -3860,3 +3858,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
eps, // double type eps, // double type
acl_yout.get(), acl_rstd.get(), acl_xout.get()); acl_yout.get(), acl_rstd.get(), acl_xout.get());
} }
void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * k = dst->src[0];
ggml_tensor * v = dst->src[1];
ggml_tensor * q = dst->src[2];
ggml_tensor * g = dst->src[3];
ggml_tensor * s = dst->src[4];
int64_t B = dst->src[4]->ne[1];
int64_t T = dst->src[0]->ne[2];
int64_t H = dst->src[0]->ne[1];
int64_t C = dst->ne[0];
int64_t D = C / H;
int64_t L = T / B;
int64_t ne_qkg[2] = { 1, D };
int64_t ne_s[2] = { D, D };
int64_t ne_st[2] = { ne_s[1], ne_s[0] };
int64_t ne_vo[2] = { D, 1 };
int64_t ne_q[1] = { D };
size_t nb_base = ggml_type_size(k->type);
size_t nb_qkg[2] = { nb_base, nb_base };
size_t nb_s[2] = { nb_base, D * nb_base };
size_t nb_st[2] = { nb_s[1], nb_s[0] };
size_t nb_vo[2] = { nb_base, D * nb_base };
size_t nb_q[1] = { nb_base };
const float scale = ggml_get_op_params_f32(dst, 0);
acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
cann_copy(ctx, acl_s.get(), new_state.get());
for (int64_t b = 0; b < B; b++) {
for (int64_t h = 0; h < H; h++) {
size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
// D * D
acl_tensor_ptr acl_s_new =
ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
acl_tensor_ptr acl_s_new_t =
ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
for (int64_t l = 0; l < L; l++) {
size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
// D * 1
acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
// D
acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
// 1 * D
acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
// D
acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
// k ⊗ v
size_t buf_size = D * D * nb_base;
ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor(
buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
//s_new = g ⊗ s_old + k ⊗ v
aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
// compute output
GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
}
}
}
}
+15 -60
View File
@@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/ */
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/*
* @brief A generic wrapper for ACL resources with custom deleter support.
*/
using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
/** /**
* @brief Trait structure used to define how to destroy a given ACL resource type. * @brief Forward Gated Linear Attention on the CANN backend.
* *
* @tparam T ACL resource type. * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:
*/ * k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H
template <typename T> struct acl_resource_traits; * s: initial state [B, H, D, D], where B is batch and D=C/H
* dst holds both outputs (o) and updated state; a scale factor is read from op params.
/**
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
*/
template <> struct acl_resource_traits<aclTensor> {
static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
};
/**
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
*/
template <> struct acl_resource_traits<aclIntArray> {
static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
};
/**
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
*/
template <> struct acl_resource_traits<aclScalar> {
static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
};
/**
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
*/
template <> struct acl_resource_traits<aclTensorList> {
static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
};
/**
* @brief Creates a generic ACL resource wrapper with proper destruction logic.
* *
* @tparam T ACL resource type. * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.
* @param ptr Raw pointer to ACL resource.
* @return any_acl_resource Smart pointer that handles destruction.
*/
template <typename T> any_acl_resource make_acl_resource(T * ptr) {
return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
}
/**
* @brief Registers multiple ACL resources into a vector for lifetime management.
* *
* @tparam Args Variadic list of ACL resource types. * @param ctx Backend context providing stream/allocator utilities.
* @param vec Target vector to hold ACL resources. * @param dst Output tensor; src deps are k, v, q, g, s as above.
* @param args Raw pointers to ACL resources.
*/ */
template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) { void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);
(vec.emplace_back(make_acl_resource(args)), ...);
}
/** /**
* @brief Launches an asynchronous task using the memory allocator. * @brief Launches an asynchronous task using the memory allocator.
@@ -894,7 +847,7 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
* same stream are executed in queue order. * same stream are executed in queue order.
*/ */
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ # define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \ do { \
uint64_t workspaceSize = 0; \ uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \ aclOpExecutor * executor; \
@@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
* @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
* and epsilon parameter. * and epsilon parameter.
*/ */
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node); void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
ggml_tensor * add_node,
ggml_tensor * rms_norm_node);
/** /**
* @brief Check whether a tensor is a weight tensor for matrix multiplication. * @brief Check whether a tensor is a weight tensor for matrix multiplication.
@@ -1104,7 +1059,7 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
* @see ggml_cann_op_unary * @see ggml_cann_op_unary
* @see GGML_CANN_CALL_ACLNN_OP * @see GGML_CANN_CALL_ACLNN_OP
*/ */
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ # define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \ do { \
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
@@ -1133,7 +1088,7 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
* @see ggml_cann_op_unary_gated * @see ggml_cann_op_unary_gated
* @see GGML_CANN_CALL_ACLNN_OP * @see GGML_CANN_CALL_ACLNN_OP
*/ */
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ # define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
do { \ do { \
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
+4
View File
@@ -1889,6 +1889,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst); ggml_cann_out_prod(ctx, dst);
break; break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cann_gated_linear_attn(ctx, dst);
break;
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst); ggml_cann_ssm_conv(ctx, dst);
break; break;
@@ -2454,6 +2457,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_GATED_LINEAR_ATTN:
return true; return true;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
{ {