model : support Step3.5-Flash (#19283)

* Support Step3.5-Flash

* fix: norm.weight + 1 (HF zero_centered=true)

* step35: simplify GGUF conversion + drop redundant rope KVs

* Address review feedback

* rename limits -> clamp

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Apply suggestion from @CISC

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* rename swiglu limits -> swiglu clamp in LLM_KV

* avoid CI fail

* Apply suggestions from code review

* Apply suggestions from code review

* disabled KV shifting for LLM_ARCH_STEP35

* Apply suggestions from code review

* mistakenly removed cmath

* add model size && apply missed suggestion

* assert partial_rotary_factors

* fix CI errors

* load freq_base_swa

---------

Co-authored-by: lvyichen <lvyichen@stepfun.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
forforever73
2026-02-07 04:06:14 +08:00
committed by GitHub
parent 3228e77287
commit b83111815e
15 changed files with 576 additions and 38 deletions
+41
View File
@@ -13,6 +13,8 @@
#include <cassert>
#include <cmath>
#include <cstring>
#include <numeric>
#include <sstream>
#include <unordered_set>
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
@@ -1014,6 +1016,26 @@ ggml_tensor * llm_graph_context::build_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate && type_gate == LLM_FFN_PAR) {
// Step35: HF clamps gate (after SiLU) and up before multiplication
if (arch == LLM_ARCH_STEP35 && il >= 0) {
const float limit = hparams.swiglu_clamp_shexp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_silu_clamped", il);
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
cb(tmp, "ffn_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, tmp);
cb(cur, "ffn_swiglu_limited", il);
type_gate = LLM_FFN_SEQ;
break;
}
}
cur = ggml_swiglu_split(ctx0, cur, tmp);
cb(cur, "ffn_swiglu", il);
type_gate = LLM_FFN_SEQ;
@@ -1316,6 +1338,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate_exps) {
// Step35: per-layer clamp for routed experts
if (arch == LLM_ARCH_STEP35 && il >= 0) {
const float limit = hparams.swiglu_clamp_exp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_moe_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_moe_silu_clamped", il);
up = ggml_clamp(ctx0, up, -limit, limit);
cb(up, "ffn_moe_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, up);
cb(cur, "ffn_moe_swiglu_limited", il);
break;
}
}
cur = ggml_swiglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_swiglu", il);
} else {