mtmd: add clip_graph::build_mm() (#20751)
* clip: add build_mm() * apply to all models * add TODO for bias overload
This commit is contained in:
+21
-21
@@ -70,17 +70,17 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
|
||||
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
|
||||
if (layer.q_b) {
|
||||
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
|
||||
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
|
||||
if (layer.k_b) {
|
||||
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
|
||||
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
|
||||
if (layer.v_b) {
|
||||
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
|
||||
}
|
||||
@@ -164,17 +164,17 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
|
||||
// llava projector
|
||||
if (proj_type == PROJECTOR_TYPE_MLP) {
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = build_mm(model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
if (model.mm_2_w) {
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||
embeddings = build_mm(model.mm_2_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||
}
|
||||
}
|
||||
else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = build_mm(model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
|
||||
// First LayerNorm
|
||||
@@ -186,7 +186,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
|
||||
embeddings = build_mm(model.mm_3_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
|
||||
|
||||
// Second LayerNorm
|
||||
@@ -197,10 +197,10 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
else if (proj_type == PROJECTOR_TYPE_LDP) {
|
||||
// MobileVLM projector
|
||||
int n_patch = 24;
|
||||
ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
|
||||
ggml_tensor * mlp_1 = build_mm(model.mm_model_mlp_1_w, embeddings);
|
||||
mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
|
||||
mlp_1 = ggml_gelu(ctx0, mlp_1);
|
||||
ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
|
||||
ggml_tensor * mlp_3 = build_mm(model.mm_model_mlp_3_w, mlp_1);
|
||||
mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
|
||||
// mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
|
||||
|
||||
@@ -229,10 +229,10 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
|
||||
// pointwise conv
|
||||
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_1_block_1_fc1_w, block_1);
|
||||
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
|
||||
block_1 = ggml_relu(ctx0, block_1);
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_1_block_1_fc2_w, block_1);
|
||||
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
|
||||
block_1 = ggml_hardsigmoid(ctx0, block_1);
|
||||
// block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
|
||||
@@ -244,7 +244,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
|
||||
|
||||
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_1_block_2_0_w, block_1);
|
||||
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
|
||||
|
||||
// block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
|
||||
@@ -277,10 +277,10 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
// block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
|
||||
// pointwise conv
|
||||
block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_2_block_1_fc1_w, block_1);
|
||||
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
|
||||
block_1 = ggml_relu(ctx0, block_1);
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_2_block_1_fc2_w, block_1);
|
||||
block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
|
||||
block_1 = ggml_hardsigmoid(ctx0, block_1);
|
||||
|
||||
@@ -292,7 +292,7 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
|
||||
block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
|
||||
// block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
|
||||
block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
|
||||
block_1 = build_mm(model.mm_model_block_2_block_2_0_w, block_1);
|
||||
block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
|
||||
|
||||
|
||||
@@ -307,10 +307,10 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
else if (proj_type == PROJECTOR_TYPE_LDPV2)
|
||||
{
|
||||
int n_patch = 24;
|
||||
ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
|
||||
ggml_tensor * mlp_0 = build_mm(model.mm_model_mlp_0_w, embeddings);
|
||||
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
|
||||
mlp_0 = ggml_gelu(ctx0, mlp_0);
|
||||
ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
|
||||
ggml_tensor * mlp_2 = build_mm(model.mm_model_mlp_2_w, mlp_0);
|
||||
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
|
||||
// mlp_2 ne = [2048, 576, 1, 1]
|
||||
// // AVG Pool Layer 2*2, strides = 2
|
||||
@@ -344,15 +344,15 @@ ggml_cgraph * clip_graph_llava::build() {
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
|
||||
// GLU
|
||||
{
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
|
||||
embeddings = build_mm(model.mm_model_mlp_0_w, embeddings);
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
|
||||
embeddings = ggml_gelu_inplace(ctx0, embeddings);
|
||||
ggml_tensor * x = embeddings;
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
|
||||
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
|
||||
embeddings = build_mm(model.mm_model_mlp_2_w, embeddings);
|
||||
x = build_mm(model.mm_model_mlp_1_w,x);
|
||||
embeddings = ggml_swiglu_split(ctx0, embeddings, x);
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
|
||||
embeddings = build_mm(model.mm_model_mlp_3_w, embeddings);
|
||||
}
|
||||
// arrangement of BOI/EOI token embeddings
|
||||
// note: these embeddings are not present in text model, hence we cannot process them as text tokens
|
||||
|
||||
Reference in New Issue
Block a user