Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 31ce8b1ae5 | |||
| 06f05e71c1 | |||
| eeb79b026b |
@@ -394,20 +394,30 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
|
|||||||
// fuse only ops that start with these operations
|
// fuse only ops that start with these operations
|
||||||
// can be expanded when needed
|
// can be expanded when needed
|
||||||
if (node.op() == GGML_OP_ADD ||
|
if (node.op() == GGML_OP_ADD ||
|
||||||
|
node.op() == GGML_OP_SUB ||
|
||||||
|
node.op() == GGML_OP_MUL ||
|
||||||
|
node.op() == GGML_OP_DIV ||
|
||||||
node.op() == GGML_OP_NORM ||
|
node.op() == GGML_OP_NORM ||
|
||||||
node.op() == GGML_OP_RMS_NORM) {
|
node.op() == GGML_OP_RMS_NORM) {
|
||||||
ops[0] = node.op();
|
ops[0] = node.op();
|
||||||
|
|
||||||
int f = i + 1;
|
int f = i + 1;
|
||||||
while (f < n && f < i + MAX_FUSE) {
|
while (f < n && f < i + MAX_FUSE) {
|
||||||
// conservatively allow fusing only these ops
|
// a chain started by a bin op (ADD/SUB/MUL/DIV) only fuses following nodes that have that same op
|
||||||
// can be expanded when needed
|
// NORM/RMS_NORM can chain with MUL/ADD
|
||||||
|
if (node.op() == GGML_OP_ADD ||
|
||||||
|
node.op() == GGML_OP_SUB ||
|
||||||
|
node.op() == GGML_OP_MUL ||
|
||||||
|
node.op() == GGML_OP_DIV) {
|
||||||
|
if (gf->nodes[f]->op != node.op()) break;
|
||||||
|
} else {
|
||||||
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
if (gf->nodes[f]->op != GGML_OP_ADD &&
|
||||||
gf->nodes[f]->op != GGML_OP_MUL &&
|
gf->nodes[f]->op != GGML_OP_MUL &&
|
||||||
gf->nodes[f]->op != GGML_OP_NORM &&
|
gf->nodes[f]->op != GGML_OP_NORM &&
|
||||||
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
ops[f - i] = gf->nodes[f]->op;
|
ops[f - i] = gf->nodes[f]->op;
|
||||||
f++;
|
f++;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -732,6 +732,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
|||||||
int nr1 = 1; // number of src1 rows per threadgroup
|
int nr1 = 1; // number of src1 rows per threadgroup
|
||||||
|
|
||||||
size_t smem = 0; // shared memory
|
size_t smem = 0; // shared memory
|
||||||
|
bool contig = false;
|
||||||
|
|
||||||
const ggml_type tsrc0 = op->src[0]->type;
|
const ggml_type tsrc0 = op->src[0]->type;
|
||||||
const ggml_type tsrc1 = op->src[1]->type;
|
const ggml_type tsrc1 = op->src[1]->type;
|
||||||
@@ -766,6 +767,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
|||||||
{
|
{
|
||||||
nsg = N_SG_Q4_0;
|
nsg = N_SG_Q4_0;
|
||||||
nr0 = N_R0_Q4_0;
|
nr0 = N_R0_Q4_0;
|
||||||
|
contig = ne00 >= 256;
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
{
|
{
|
||||||
@@ -877,7 +879,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
snprintf(base, 256, "kernel_mul_mv_%s_%s%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), contig ? "_c" : "", suffix);
|
||||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||||
|
|
||||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||||
|
|||||||
@@ -3118,19 +3118,15 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
int n_fuse = 1;
|
int n_fuse = 1;
|
||||||
|
|
||||||
// c[0] = add(a, b[0])
|
// c[0] = op(a, b[0])
|
||||||
// c[1] = add(c[0], b[1])
|
// c[1] = op(c[0], b[1])
|
||||||
// c[2] = add(c[1], b[2])
|
// c[2] = op(c[1], b[2])
|
||||||
// ...
|
// ...
|
||||||
if (use_fusion) {
|
if (use_fusion) {
|
||||||
fops[0] = GGML_OP_ADD;
|
ggml_op cur_op = op->op;
|
||||||
fops[1] = GGML_OP_ADD;
|
for (int i = 0; i < 8; ++i) {
|
||||||
fops[2] = GGML_OP_ADD;
|
fops[i] = cur_op;
|
||||||
fops[3] = GGML_OP_ADD;
|
}
|
||||||
fops[4] = GGML_OP_ADD;
|
|
||||||
fops[5] = GGML_OP_ADD;
|
|
||||||
fops[6] = GGML_OP_ADD;
|
|
||||||
fops[7] = GGML_OP_ADD;
|
|
||||||
|
|
||||||
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
||||||
// across splits. idx_end indicates the last node in the current split
|
// across splits. idx_end indicates the last node in the current split
|
||||||
@@ -3165,7 +3161,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||||||
++n_fuse;
|
++n_fuse;
|
||||||
|
|
||||||
if (debug_fusion > 1 && n_fuse > 1) {
|
if (debug_fusion > 1 && n_fuse > 1) {
|
||||||
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
|
GGML_LOG_DEBUG("%s: fuse: %s x %d\n", __func__, ggml_op_name(cur_op), n_fuse);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3533,6 +3533,94 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|||||||
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Q4_0 x F32 matrix-vector product that fetches the packed 4-bit weights as
// 32-bit words instead of 16-bit halves.
//
// Work split, as implemented below:
//   - every simdgroup produces NR0 consecutive output rows starting at r0
//   - the lanes of a simdgroup are split into NQ groups of NW/NQ lanes; group ix
//     walks blocks ix, ix + NQ, ix + 2*NQ, ...
//   - inside a block each lane covers 8 consecutive quant bytes starting at byte
//     il (two 32-bit words), i.e. y[il .. il+7] and y[il+16 .. il+23]
//   - partial sums are combined with simd_sum and lane 0 stores the row result
//
// Assumes the standard Q4_0 packing: byte j of qs stores element j in its low
// nibble and element j + QK4_0/2 in its high nibble, values biased by 8
// - TODO confirm against the block_q4_0 definition.
//
// NOTE(review): qs is reinterpreted as uint32_t. If block_q4_0 is a 2-byte scale
// followed by the quant bytes (18-byte stride), these loads are only 2-byte
// aligned - confirm that this is acceptable on all targeted GPUs.
//
// shmem is unused; it is kept so the signature matches the other mul_mv kernels.
kernel void kernel_mul_mv_q4_0_f32_c(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const short NSG = FC_mul_mv_nsg;

    constexpr short NR0 = N_R0_Q4_0;
    constexpr short NW  = N_SIMDWIDTH;
    constexpr short NQ  = 16; // number of blocks processed concurrently by one simdgroup

    // number of Q4_0 blocks per src0 row
    const int nb = args.ne00 / QK4_0;

    const int r0 = (tgpig.x * NSG + sgitg) * NR0; // first output row of this simdgroup
    const int r1 = tgpig.y;                       // src1 row
    const int im = tgpig.z;                       // flattened (i12, i13) batch index

    const uint i12 = im % args.ne12;
    const uint i13 = im / args.ne12;

    const uint64_t offset1 = r1 * args.nb11 + (i12) * args.nb12 + (i13) * args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // base pointers of the NR0 weight rows handled by this simdgroup
    // note: rows past ne01 are still addressed here; only the final store is guarded
    device const block_q4_0 * ax[NR0];
    FOR_UNROLL (int row = 0; row < NR0; ++row) {
        const uint64_t offset0 = (r0 + row) * args.nb01 + (i12 / args.r2) * args.nb02 + (i13 / args.r3) * args.nb03;
        ax[row] = (device const block_q4_0 *) (src0 + offset0);
    }

    float sumf[NR0] = {0.f};

    const short ix = (tiisg / (NW / NQ));     // which block of the current group of NQ blocks
    const short il = (tiisg % (NW / NQ)) * 8; // first quant byte / first y element inside the block

    const int ib0 = ix;

    // index of the first of the two 32-bit words that hold quant bytes il .. il+7
    const short iw = il / 4;

    device const float * yb = y + ib0 * QK4_0 + il;

    for (int ib = ib0; ib < nb; ib += NQ) {
        // sum of the 16 activations covered by this lane; used to fold in the -8 bias
        float sumy = 0.f;

        FOR_UNROLL (short i = 0; i < 8; i += 2) {
            sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
        }

        FOR_UNROLL (short row = 0; row < NR0; row++) {
            const float d = ax[row][ib].d;

            device const uint32_t * qs = (device const uint32_t *) (ax[row][ib].qs);

            const uint32_t q0 = qs[iw + 0]; // quant bytes il+0 .. il+3
            const uint32_t q1 = qs[iw + 1]; // quant bytes il+4 .. il+7

            float acc = 0.f;

            // byte k of a word: low nibble pairs with the first-half activation,
            // high nibble with the activation 16 elements further
            FOR_UNROLL (short k = 0; k < 4; ++k) {
                const short sh = 8 * k;

                acc += ((q0 >> (sh    )) & 0xF) * yb[k +  0]
                     + ((q0 >> (sh + 4)) & 0xF) * yb[k + 16]
                     + ((q1 >> (sh    )) & 0xF) * yb[k +  4]
                     + ((q1 >> (sh + 4)) & 0xF) * yb[k + 20];
            }

            sumf[row] += d * (acc + sumy * -8.f);
        }

        // advance to the next block owned by this lane group
        yb += QK4_0 * NQ;
    }

    // widen before multiplying so large batches/outputs cannot overflow 32-bit arithmetic
    device float * dst_f32 = (device float *) dst + (uint64_t) im * args.ne0 * args.ne1 + (uint64_t) r1 * args.ne0;

    for (int row = 0; row < NR0; ++row) {
        const float tot = simd_sum(sumf[row]);

        if (tiisg == 0 && r0 + row < args.ne01) {
            dst_f32[r0 + row] = tot;
        }
    }
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q4_1_f32(
|
kernel void kernel_mul_mv_q4_1_f32(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
|
|||||||
Reference in New Issue
Block a user