mtmd: Add DeepSeekOCR Support (#17400)

* mtmd: llama.cpp DeepSeekOCR support
init commit

* loading sam tensors

* mtmd: fix vision model processing

* deepseek-ocr clip-vit model impl

* mtmd: add DeepSeek-OCR LM support with standard attention

* mtmd: successfully runs DeepSeek-OCR LM in llama-cli

* mtmd: Fix RoPE type for DeepSeek-OCR LM.

* loading LM
testing Vision model loading

* sam warmup working

* sam erroneous return corrected

* clip-vit: corrected cls_embd concat

* clip-vit: model convert qkv_proj split
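
For reference, the conversion-time split of a fused QKV weight can be sketched as below (a minimal illustration assuming a fused [3 * n_embd, n_embd] weight laid out in q/k/v order; the helper name is hypothetical, not the converter's actual code):

```python
# Minimal sketch of splitting a fused QKV projection during model conversion.
import torch

def split_qkv(qkv_weight: torch.Tensor, n_embd: int):
    # fused weight has shape [3 * n_embd, n_embd]; slice it into three
    # [n_embd, n_embd] projections in q, k, v order
    wq, wk, wv = torch.split(qkv_weight, n_embd, dim=0)
    return wq, wk, wv

qkv = torch.randn(3 * 1024, 1024)
wq, wk, wv = split_qkv(qkv, 1024)
assert wq.shape == wk.shape == wv.shape == (1024, 1024)
```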

* corrected combining of image encoders' results

* fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model

* concat image_newline and image_seperator tokens
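
This follows the usual pattern of appending a learned newline embedding after each row of image tokens; a minimal numpy sketch of that pattern (the exact DeepSeek-OCR ordering of newline and separator tokens is left to the actual implementation):

```python
# Illustrative sketch of appending a learned newline embedding after each
# row of a patch-embedding grid; the real layout may differ.
import numpy as np

def add_newlines(embd: np.ndarray, newline: np.ndarray) -> np.ndarray:
    # embd: [h, w, d] grid of patch embeddings, newline: [d]
    h, w, d = embd.shape
    nl = np.broadcast_to(newline, (h, 1, d))
    return np.concatenate([embd, nl], axis=1).reshape(h * (w + 1), d)

grid = np.zeros((16, 16, 32), dtype=np.float32)
out = add_newlines(grid, np.ones(32, dtype=np.float32))
assert out.shape == (16 * 17, 32)
```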

* visual_model warmup (technically) works

* window partitioning using standard ggml ops
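
Window partitioning for SAM's windowed attention needs only reshapes and permutes, which is what "standard ggml ops" refers to here; a numpy sketch of the same math (padding for non-divisible sizes omitted):

```python
# SAM-style window partitioning expressed with reshape/transpose only,
# mirroring what a ggml graph can do with standard ops.
import numpy as np

def window_partition(x: np.ndarray, win: int) -> np.ndarray:
    # x: [H, W, C] feature map with H and W divisible by the window size
    H, W, C = x.shape
    x = x.reshape(H // win, win, W // win, win, C)
    # -> [num_windows, win, win, C]
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, win, win, C)

feat = np.arange(8 * 8 * 2, dtype=np.float32).reshape(8, 8, 2)
wins = window_partition(feat, 4)
assert wins.shape == (4, 4, 4, 2)
```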

* sam implementation without using CPU only ops

* clip: fixed warnings

* Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr

* mtmd: fix get_rel_pos

* mtmd: fixed the wrong scaler for get_rel_pos

* image encoding technically works but the output can't be checked since image decoding fails

* mtmd: minor changes

* mtmd: add native resolution support

* - image encoding debugged
- issues fixed, mainly related to wrong config values like n_patches etc.
- configs need to be corrected in the converter

* mtmd: correct token order

* - dynamic resizing
- changes concern PR https://github.com/sfallah/llama.cpp/pull/4

* mtmd: quick fix token order

* mtmd: fix dangling pointer

* mtmd: SAM numerically works

* mtmd: debug CLIP-L (vit_pre_ln)

* mtmd: debug CLIP-L & first working DeepSeek-OCR model

* mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution control & all native resolution modes work

* mtmd: simplify SAM patch embedding

* mtmd: adapt Pillow image resizing function

* mtmd:  simplify DeepSeek-OCR dynamic resolution preprocessing

* mtmd: remove --dsocr-mode argument

* mtmd: refactor code & remove unused helper functions

* mtmd: fix tensor names for image newlines and view separator

* clean up

* reverting automatically removed spaces

* reverting automatically removed spaces

* mtmd: fixed bad OCR check in Deepseek2 (LM)

* mtmd: support combined QKV projection in build_vit

* using common build_attn in sam

* corrected code branch when flash-attn is disabled,
enabling usage of the --flash-attn option

* mtmd: minor fix

* minor formatting and style

* fixed flake8 lint issues

* minor editorconfig-check fixes

* minor editorconfig-check fixes

* mtmd: simplify get_rel_pos
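
For context, this mirrors the reference get_rel_pos from Meta's segment-anything, which gathers decomposed relative position embeddings; a numpy sketch of that reference logic (the interpolation branch for mismatched sizes is omitted, and the ggml version in mtmd expresses the same gather with graph ops):

```python
# Reference logic of SAM's get_rel_pos, sketched in numpy.
import numpy as np

def get_rel_pos(q_size: int, k_size: int, rel_pos: np.ndarray) -> np.ndarray:
    # rel_pos: [2 * max(q_size, k_size) - 1, head_dim]
    q_coords = np.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = np.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    rel = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos[rel.astype(np.int64)]  # -> [q_size, k_size, head_dim]

table = np.random.randn(2 * 14 - 1, 64)
assert get_rel_pos(14, 14, table).shape == (14, 14, 64)
```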

* mtmd: make sam hparams configurable

* mtmd: add detailed comments for resize_bicubic_pillow
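
resize_bicubic_pillow reimplements Pillow's bicubic resampling so the C++ preprocessing matches the Python reference; the behavior it targets is simply the following (a sketch, not the actual test harness):

```python
# Pillow reference behavior that resize_bicubic_pillow aims to match.
from PIL import Image
import numpy as np

img = Image.fromarray(np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8))
ref = img.resize((1024, 1024), resample=Image.Resampling.BICUBIC)
print(ref.size)  # (1024, 1024)
```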

* mtmd: fixed wrong input setting

* mtmd: convert model in FP16

* mtmd: minor fix

* mtmd: remove tweak to llama-mtmd-cli & deepseek-ocr template

* fix: test-1.jpg OCR issue with small (640) resolution
setting dynamic-resolution minimum to base (1024) and maximum to large (1280), as sketched below
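
A hypothetical sketch of the clamping described above (function name and exact rule are illustrative, not the actual implementation):

```python
# Hypothetical clamping of the dynamic-resolution target: candidates below
# base (1024) fall back to base, and are capped at large (1280).
def pick_resolution(native: int, base: int = 1024, large: int = 1280) -> int:
    return max(base, min(native, large))

assert pick_resolution(640) == 1024   # small inputs use the base resolution
assert pick_resolution(1536) == 1280  # very large inputs are capped
```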

* minor: editorconfig-check fix

* merge with changes from https://github.com/ggml-org/llama.cpp/pull/17909
added new opt to tests.sh to disable flash-attn

* minor: editorconfig-check fix

* testing deepseek-ocr
quick and dirty test script comparing results of Qwen2.5-VL vs DeepSeek-OCR

* quick and (potentially) dirty merge with https://github.com/ggml-org/llama.cpp/pull/17909

* refactoring, one single builder function and static helpers

* added deepseek-ocr test to tests.sh

* minor formatting fixes

* check with fixed expected results

* minor formatting

* editorconfig-check fix

* merge with changes from https://github.com/ggml-org/llama.cpp/pull/18042

* minor
- added GLM-4.6V to big tests
- added missing deps for python test

* convert: minor fix

* mtmd: format code

* convert: quick fix

* convert: quick fix

* minor python formatting

* fixed merge build issue

* merge resolved
- fixed issues in convert
- tested several deepseek models

* minor fix

* minor

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* - removed clip_is_deepseekocr
- removed redundant RESIZE_ALGO_BICUBIC_PILLOW resize-algo
- simplified image-preprocessing
- removed/simplified debug functions

* - cleaning up commented-out code

* fixing instability issues by reintroducing resize_bicubic_pillow

* - use f16 model for deepseek-ocr test
- ignore llama-arch test for deepseek-ocr

* rename fc_w --> mm_fc_w

* add links to OCR discussion

* cleaner loading code

* add missing .weight to some tensors

* add default jinja template (to be used by server)
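
The template added in the converter, `{% for m in messages %}{{m['content']}}{% endfor %}`, just concatenates message contents; a quick jinja2 check (example messages are illustrative):

```python
# The default chat template simply concatenates message contents.
from jinja2 import Template

tmpl = Template("{% for m in messages %}{{m['content']}}{% endfor %}")
out = tmpl.render(messages=[{"content": "<image>"}, {"content": "Free OCR."}])
assert out == "<image>Free OCR."
```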

* move test model to ggml-org

* rolling back upscale change

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: bluebread <hotbread70127@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
commit a970515bdb (parent 056b50c319)
Author: Saba Fallah
Date: 2026-03-25 19:57:40 +01:00
30 changed files with 1569 additions and 27 deletions

convert_hf_to_gguf.py (+103 -10)
@@ -947,6 +947,9 @@ class ModelBase:
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
+        if "language_config" in config:
+            # rename for DeepSeekOCR
+            config["text_config"] = config["language_config"]
         if "lfm" in config:
             # rename for LFM2-Audio
             config["text_config"] = config["lfm"]
@@ -2074,7 +2077,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]

-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers", "encoder_layers", "vt_num_hidden_layers"]

     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -6938,6 +6941,68 @@ class ConformerAudioModel(MmprojModel):
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("DeepseekOCRForCausalLM")
+class DeepseekOCRVisionModel(MmprojModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
+
+        # calculate proj_scale_factor (used by tinygemma3 test model)
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        n_per_side = int(image_seq_length ** 0.5)
+        image_size = self.hparams["image_size"]
+        patch_size = self.hparams["patch_size"]
+        proj_scale_factor = (image_size // patch_size) // n_per_side
+        if proj_scale_factor > 0 and proj_scale_factor != 4:
+            # we only need to write this if it's not the default value
+            # in this case, we are converting a test model
+            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
+
+        # @bluebread: there's no window_size in config but just add it here anyway
+        self.gguf_writer.add_vision_window_size(self.hparams.get("window_size", 14))
+
+        # SAM configuration
+        sam_hparams = hparams['sam']
+        self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
+        self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])
+        self.gguf_writer.add_vision_sam_head_count(sam_hparams['heads'])
+
+    def get_vision_config(self) -> dict[str, Any]:
+        vision_config: dict[str, Any] | None = self.global_config.get("vision_config")
+
+        if not vision_config:
+            raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found")
+
+        vision_config['sam'] = vision_config['width']['sam_vit_b']
+        vision_config.update(vision_config['width']['clip-l-14-224'])
+        vision_config['hidden_size'] = vision_config['width']
+        vision_config['num_heads'] = vision_config['heads']
+        vision_config['intermediate_size'] = vision_config['heads'] * 4
+
+        return vision_config
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".embeddings." in name or 'pos_embed' in name:
+            return gguf.GGMLQuantizationType.F32
+        if ".rel_pos_h" in name or '.rel_pos_w' in name:
+            return gguf.GGMLQuantizationType.F32
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Only process vision-related tensors, skip language model tensors
+        # Vision components: sam_model, vision_model, projector, image_newline, view_seperator
+        # Language model components to skip: lm_head, embed_tokens, layers, norm
+        if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")):
+            return
+        if name.endswith("pos_embed") or name.endswith("rel_pos_h") or name.endswith("rel_pos_w"):
+            name += ".weight"
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Gemma3nForConditionalGeneration")
 class Gemma3nVisionAudioModel(ConformerAudioModel):
     has_audio_encoder = True
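
A worked example of the proj_scale_factor computation above, with hypothetical config values (a real conversion reads these from the model and preprocessor configs):

```python
# Worked example of the proj_scale_factor arithmetic, hypothetical values.
image_seq_length = 256                       # tokens per image after projection
n_per_side = int(image_seq_length ** 0.5)    # 16
image_size, patch_size = 896, 14             # hypothetical vision encoder config
proj_scale_factor = (image_size // patch_size) // n_per_side  # 64 // 16 = 4
# 4 is the default, so nothing is written; a test model with a smaller
# image_size would yield a non-default value that does get written
print(proj_scale_factor)
```
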
@@ -8283,6 +8348,19 @@ class DeepseekV2Model(TextModel):
     merge_expert = True

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        hparams: dict = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+        self.origin_hf_arch = hparams.get('architectures', [None])[0]
+
+        # special handling for Deepseek OCR
+        if self.origin_hf_arch == "DeepseekOCRForCausalLM":
+            self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
+            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
+            self.gguf_writer.add_architecture()
+            # default jinja template
+            self.gguf_writer.add_chat_template("{% for m in messages %}{{m['content']}}{% endfor %}")
+
     def set_vocab(self):
         try:
             self._set_vocab_gpt2()
@@ -8338,9 +8416,15 @@ class DeepseekV2Model(TextModel):
             raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

     def set_gguf_parameters(self):
+        is_ocr = (self.model_arch == gguf.MODEL_ARCH.DEEPSEEK2OCR)

-        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
-        self.hparams["num_key_value_heads"] = 1
+        if is_ocr:
+            self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
+        else:
+            # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+            self.hparams["num_key_value_heads"] = 1
+
+        self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)

         super().set_gguf_parameters()
         hparams = self.hparams
@@ -8354,16 +8438,18 @@ class DeepseekV2Model(TextModel):
         # Default: if no MoE, all layers are dense; if MoE, none are dense
         first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
         self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+        kv_lora_rank = hparams.get("kv_lora_rank", 512)

         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
-        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
-        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+        if not is_ocr:
+            self.gguf_writer.add_kv_lora_rank(kv_lora_rank)
+            self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length(kv_lora_rank)
+            self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+            self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

         # MoE parameters (required by C++ code for DEEPSEEK2 arch)
         # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
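
To make the non-OCR branch concrete, these are the MQA cache dimensions it writes, computed with DeepSeek-V2-style values (illustrative numbers; the real values come from the model config):

```python
# Worked example of the MQA cache dims written in the non-OCR branch.
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
v_head_dim = 128

key_length = kv_lora_rank + qk_rope_head_dim           # 576: compressed KV + RoPE part
value_length = kv_lora_rank                            # 512
key_length_mla = qk_nope_head_dim + qk_rope_head_dim   # 192 after decompression
value_length_mla = v_head_dim                          # 128
```
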
@@ -8395,8 +8481,15 @@ class DeepseekV2Model(TextModel):
     _experts: list[dict[str, Tensor]] | None = None

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
-        if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
+        # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5, and DeepSeek-OCR
+        if ("vision_tower" in name
+                or "multi_modal_projector" in name
+                or "mm_projector" in name
+                or "vision_model" in name
+                or "image_newline" in name
+                or "model.projector" in name
+                or "sam_model" in name
+                or "view_seperator" in name):
             return
         if name.startswith("siglip2.") or name.startswith("merger."):
             return