CANN: support gated linear attn (#18653)
* CANN: support gated linear attn This change adds support for the GGML_OP_GATED_LINEAR_ATTN operator. The feature was implemented by YushengZhao. Because the previous submission was based on an outdated codebase, this PR was rebased to merge. Co-authored-by: YushengZhao <yusheng.chao@outlook.com> Co-authored-by: hipudding <huafengchun@gmail.com> * CANN: optimize OP gla Optimize gla for high performance * Remove unused comments --------- Co-authored-by: 赵禹昇 <2501112001@cninfer02.localdomain> Co-authored-by: YushengZhao <yusheng.chao@outlook.com>
This commit is contained in:
@@ -58,6 +58,7 @@
|
|||||||
#include <aclnnop/aclnn_mean.h>
|
#include <aclnnop/aclnn_mean.h>
|
||||||
#include <aclnnop/aclnn_mm.h>
|
#include <aclnnop/aclnn_mm.h>
|
||||||
#include <aclnnop/aclnn_mul.h>
|
#include <aclnnop/aclnn_mul.h>
|
||||||
|
#include <aclnnop/aclnn_mv.h>
|
||||||
#include <aclnnop/aclnn_permute.h>
|
#include <aclnnop/aclnn_permute.h>
|
||||||
#include <aclnnop/aclnn_pow.h>
|
#include <aclnnop/aclnn_pow.h>
|
||||||
#include <aclnnop/aclnn_pow_tensor_tensor.h>
|
#include <aclnnop/aclnn_pow_tensor_tensor.h>
|
||||||
@@ -2346,12 +2347,13 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
|||||||
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
|
if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
|
||||||
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
|
ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
|
||||||
}
|
}
|
||||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
|
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
|
||||||
|
ACL_MEM_MALLOC_HUGE_FIRST));
|
||||||
// -rope_yarn_ramp
|
// -rope_yarn_ramp
|
||||||
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||||
// return MIN(1, MAX(0, y)) - 1;
|
// return MIN(1, MAX(0, y)) - 1;
|
||||||
acl_yarn_ramp_tensor =
|
acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
|
||||||
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
theta_scale_ne, theta_scale_nb, 1);
|
||||||
float zero_value = 0, one_value = 1;
|
float zero_value = 0, one_value = 1;
|
||||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||||
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||||
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
|
|||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
|
||||||
} else {
|
} else {
|
||||||
acl_yarn_ramp_tensor =
|
acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
|
||||||
ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
|
theta_scale_ne, theta_scale_nb, 1);
|
||||||
}
|
}
|
||||||
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
|
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
|
||||||
if (ext_factor != 0) {
|
if (ext_factor != 0) {
|
||||||
@@ -3050,7 +3052,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
|
||||||
|
|
||||||
for (int k = 0; k < part_num; k++) {
|
for (int k = 0; k < part_num; k++) {
|
||||||
|
|
||||||
// create part kernel tensor and slice from big kernel
|
// create part kernel tensor and slice from big kernel
|
||||||
slice_start = max_kernel_size * k;
|
slice_start = max_kernel_size * k;
|
||||||
if (k == part_num - 1) {
|
if (k == part_num - 1) {
|
||||||
@@ -3076,10 +3077,11 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||||||
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
|
part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
|
||||||
void * part_kernel_buf = part_kernel_allocator.get();
|
void * part_kernel_buf = part_kernel_allocator.get();
|
||||||
|
|
||||||
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
|
acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
|
||||||
ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
|
part_ne, part_nb, 3, ACL_FORMAT_NCL);
|
||||||
|
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
|
||||||
|
part_kernel.get());
|
||||||
|
|
||||||
// create the part conv result tensor
|
// create the part conv result tensor
|
||||||
int64_t part_dst_ne[4];
|
int64_t part_dst_ne[4];
|
||||||
@@ -3103,7 +3105,8 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||||||
|
|
||||||
// compute part conv transpose 1d
|
// compute part conv transpose 1d
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
|
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
|
||||||
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
|
padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
|
||||||
|
cubeMathType);
|
||||||
|
|
||||||
// compute the position of part result in final result
|
// compute the position of part result in final result
|
||||||
int64_t global_start = slice_start;
|
int64_t global_start = slice_start;
|
||||||
@@ -3138,7 +3141,8 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||||||
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
|
conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
|
||||||
|
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
|
||||||
|
conv_result.get());
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
|
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3749,8 +3753,8 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||||||
// so we can reuse nb0 and nb1, and set nb2 = nb1.
|
// so we can reuse nb0 and nb1, and set nb2 = nb1.
|
||||||
size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
|
size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
|
||||||
|
|
||||||
acl_tensor_ptr acl_w = ggml_cann_create_tensor(
|
acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
|
||||||
src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
|
ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
|
||||||
|
|
||||||
// 3) Output: dst is { d_inner, n_t, n_s } (CLN)
|
// 3) Output: dst is { d_inner, n_t, n_s } (CLN)
|
||||||
//
|
//
|
||||||
@@ -3769,10 +3773,11 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||||||
// nb_y[1] = sizeof(float); // step in C
|
// nb_y[1] = sizeof(float); // step in C
|
||||||
// nb_y[2] = nr * n_t * sizeof(float); // step in N
|
// nb_y[2] = nr * n_t * sizeof(float); // step in N
|
||||||
int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
|
int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
|
||||||
size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
|
size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
|
||||||
|
dst->nb[3] }; // [nr, 1, nr * n_t]
|
||||||
|
|
||||||
acl_tensor_ptr acl_y = ggml_cann_create_tensor(
|
acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
|
||||||
dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
|
ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
|
||||||
|
|
||||||
// --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
|
// --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
|
||||||
int64_t strideVal[1] = { 1 };
|
int64_t strideVal[1] = { 1 };
|
||||||
@@ -3791,22 +3796,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||||||
cubeMathType = 1;
|
cubeMathType = 1;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
GGML_CANN_CALL_ACLNN_OP(ctx,
|
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
|
||||||
Convolution,
|
|
||||||
acl_x.get(), // input: N, C, L_in = ncs
|
acl_x.get(), // input: N, C, L_in = ncs
|
||||||
acl_w.get(), // weight: [C, 1, K] with groups=nr
|
acl_w.get(), // weight: [C, 1, K] with groups=nr
|
||||||
nullptr, // bias
|
nullptr, // bias
|
||||||
stride.get(),
|
stride.get(), padding.get(), dilation.get(), transposed,
|
||||||
padding.get(),
|
|
||||||
dilation.get(),
|
|
||||||
transposed,
|
|
||||||
padding.get(), // output padding (unused for non-transposed)
|
padding.get(), // output padding (unused for non-transposed)
|
||||||
groups,
|
groups, acl_y.get(), cubeMathType);
|
||||||
acl_y.get(),
|
|
||||||
cubeMathType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
|
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
|
||||||
ggml_tensor * add_node,
|
ggml_tensor * add_node,
|
||||||
ggml_tensor * rms_norm_node) {
|
ggml_tensor * rms_norm_node) {
|
||||||
@@ -3860,3 +3858,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
|
|||||||
eps, // double type
|
eps, // double type
|
||||||
acl_yout.get(), acl_rstd.get(), acl_xout.get());
|
acl_yout.get(), acl_rstd.get(), acl_xout.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_tensor * k = dst->src[0];
|
||||||
|
ggml_tensor * v = dst->src[1];
|
||||||
|
ggml_tensor * q = dst->src[2];
|
||||||
|
ggml_tensor * g = dst->src[3];
|
||||||
|
ggml_tensor * s = dst->src[4];
|
||||||
|
|
||||||
|
int64_t B = dst->src[4]->ne[1];
|
||||||
|
int64_t T = dst->src[0]->ne[2];
|
||||||
|
int64_t H = dst->src[0]->ne[1];
|
||||||
|
int64_t C = dst->ne[0];
|
||||||
|
int64_t D = C / H;
|
||||||
|
int64_t L = T / B;
|
||||||
|
|
||||||
|
int64_t ne_qkg[2] = { 1, D };
|
||||||
|
int64_t ne_s[2] = { D, D };
|
||||||
|
int64_t ne_st[2] = { ne_s[1], ne_s[0] };
|
||||||
|
int64_t ne_vo[2] = { D, 1 };
|
||||||
|
int64_t ne_q[1] = { D };
|
||||||
|
size_t nb_base = ggml_type_size(k->type);
|
||||||
|
size_t nb_qkg[2] = { nb_base, nb_base };
|
||||||
|
size_t nb_s[2] = { nb_base, D * nb_base };
|
||||||
|
size_t nb_st[2] = { nb_s[1], nb_s[0] };
|
||||||
|
size_t nb_vo[2] = { nb_base, D * nb_base };
|
||||||
|
size_t nb_q[1] = { nb_base };
|
||||||
|
|
||||||
|
const float scale = ggml_get_op_params_f32(dst, 0);
|
||||||
|
|
||||||
|
acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
|
||||||
|
acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
|
||||||
|
cann_copy(ctx, acl_s.get(), new_state.get());
|
||||||
|
|
||||||
|
for (int64_t b = 0; b < B; b++) {
|
||||||
|
for (int64_t h = 0; h < H; h++) {
|
||||||
|
size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
|
||||||
|
// D * D
|
||||||
|
acl_tensor_ptr acl_s_new =
|
||||||
|
ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
|
||||||
|
acl_tensor_ptr acl_s_new_t =
|
||||||
|
ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
|
||||||
|
for (int64_t l = 0; l < L; l++) {
|
||||||
|
size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
|
||||||
|
// D * 1
|
||||||
|
acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
|
||||||
|
acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
|
||||||
|
// D
|
||||||
|
acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
|
||||||
|
// 1 * D
|
||||||
|
acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
|
||||||
|
// D
|
||||||
|
acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
|
||||||
|
// k ⊗ v
|
||||||
|
size_t buf_size = D * D * nb_base;
|
||||||
|
ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
|
||||||
|
acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor(
|
||||||
|
buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
|
||||||
|
aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
|
||||||
|
//s_new = g ⊗ s_old + k ⊗ v
|
||||||
|
aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
|
||||||
|
aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
|
||||||
|
// compute output
|
||||||
|
GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
|
||||||
|
aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
|||||||
*/
|
*/
|
||||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
/*
|
|
||||||
* @brief A generic wrapper for ACL resources with custom deleter support.
|
|
||||||
*/
|
|
||||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Trait structure used to define how to destroy a given ACL resource type.
|
* @brief Forward Gated Linear Attention on the CANN backend.
|
||||||
*
|
*
|
||||||
* @tparam T ACL resource type.
|
* Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:
|
||||||
*/
|
* k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H
|
||||||
template <typename T> struct acl_resource_traits;
|
* s: initial state [B, H, D, D], where B is batch and D=C/H
|
||||||
|
* dst holds both outputs (o) and updated state; a scale factor is read from op params.
|
||||||
/**
|
|
||||||
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
|
|
||||||
*/
|
|
||||||
template <> struct acl_resource_traits<aclTensor> {
|
|
||||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
|
|
||||||
*/
|
|
||||||
template <> struct acl_resource_traits<aclIntArray> {
|
|
||||||
static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
|
|
||||||
*/
|
|
||||||
template <> struct acl_resource_traits<aclScalar> {
|
|
||||||
static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
|
|
||||||
*/
|
|
||||||
template <> struct acl_resource_traits<aclTensorList> {
|
|
||||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Creates a generic ACL resource wrapper with proper destruction logic.
|
|
||||||
*
|
*
|
||||||
* @tparam T ACL resource type.
|
* The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.
|
||||||
* @param ptr Raw pointer to ACL resource.
|
|
||||||
* @return any_acl_resource Smart pointer that handles destruction.
|
|
||||||
*/
|
|
||||||
template <typename T> any_acl_resource make_acl_resource(T * ptr) {
|
|
||||||
return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Registers multiple ACL resources into a vector for lifetime management.
|
|
||||||
*
|
*
|
||||||
* @tparam Args Variadic list of ACL resource types.
|
* @param ctx Backend context providing stream/allocator utilities.
|
||||||
* @param vec Target vector to hold ACL resources.
|
* @param dst Output tensor; src deps are k, v, q, g, s as above.
|
||||||
* @param args Raw pointers to ACL resources.
|
|
||||||
*/
|
*/
|
||||||
template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
|
void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Launches an asynchronous task using the memory allocator.
|
* @brief Launches an asynchronous task using the memory allocator.
|
||||||
@@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
|||||||
* @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
|
* @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
|
||||||
* and epsilon parameter.
|
* and epsilon parameter.
|
||||||
*/
|
*/
|
||||||
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
|
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
|
||||||
|
ggml_tensor * add_node,
|
||||||
|
ggml_tensor * rms_norm_node);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
|
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
|
||||||
|
|||||||
@@ -1889,6 +1889,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
|||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
ggml_cann_out_prod(ctx, dst);
|
ggml_cann_out_prod(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GATED_LINEAR_ATTN:
|
||||||
|
ggml_cann_gated_linear_attn(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
ggml_cann_ssm_conv(ctx, dst);
|
ggml_cann_ssm_conv(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -2454,6 +2457,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_PAD_REFLECT_1D:
|
case GGML_OP_PAD_REFLECT_1D:
|
||||||
case GGML_OP_COUNT_EQUAL:
|
case GGML_OP_COUNT_EQUAL:
|
||||||
|
case GGML_OP_GATED_LINEAR_ATTN:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user