ggml: add GATED_DELTA_NET op (#19504)
* ggml: add GATED_DELTA_NET op * remove the transpose * add KDA * add qwen35 dense * llama : check for fused gated delta net backend support --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@@ -3665,6 +3665,51 @@ struct test_rwkv_wkv6 : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_GATED_DELTA_NET
|
||||
struct test_gated_delta_net : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
const int64_t head_size;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
const int v_repeat;
|
||||
const bool permuted;
|
||||
const bool kda;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda);
|
||||
}
|
||||
|
||||
test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1,
|
||||
int v_repeat = 1, bool permuted = false, bool kda = false)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs),
|
||||
v_repeat(v_repeat), permuted(permuted), kda(kda) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q;
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
if (permuted) {
|
||||
// create with dims 1 and 2 swapped, then permute back to get non-contiguous layout
|
||||
q = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count * v_repeat, n_seqs), 0, 2, 1, 3);
|
||||
} else {
|
||||
q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
|
||||
k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
|
||||
v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
}
|
||||
const int64_t g_ne0 = kda ? head_size : 1;
|
||||
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
|
||||
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_GATED_LINEAR_ATTN
|
||||
struct test_gla : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -8405,6 +8450,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1, 1, true));
|
||||
// KDA (vector gate)
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
||||
Reference in New Issue
Block a user