common/parser: fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers (#21230)

* Fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers

* Rename

* Update common/chat-auto-parser-generator.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Piotr Wilkin (ilintar)
2026-04-03 17:51:52 +02:00
committed by GitHub
parent af5c13841f
commit f1f793ad06
7 changed files with 242 additions and 153 deletions
+35 -13
View File
@@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB";
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
static const std::string CALL_ID_001 = "call00001";
static const std::string CALL_ID_002 = "call00002";
static const std::string CALL_ID_999 = "call99999";
static std::vector<std::function<void(const common_chat_template & tmpl, autoparser &)>> workarounds(
{ // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to
@@ -131,6 +134,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
analysis.tools.function.name_prefix = "<tool▁sep>";
analysis.tools.format.per_call_end = "<tool▁call▁end>";
analysis.tools.function.close = "```";
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
}
}
});
@@ -158,7 +162,7 @@ static json user_msg = json{
{ "content", USER_MSG }
};
static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") {
static json build_tool_call(const std::string & name, const json & args, const std::string & id = CALL_ID_001) {
return json{
{ "id", id },
{ "type", "function" },
@@ -166,17 +170,17 @@ static json build_tool_call(const std::string & name, const json & args, const s
};
}
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001");
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001");
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001");
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001");
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), CALL_ID_001);
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, CALL_ID_001);
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, CALL_ID_001);
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, CALL_ID_001);
static json first_tool_call =
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001");
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_001);
static json second_tool_call =
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002");
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_002);
static json first_tool_call_alt_id =
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999");
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_999);
template <typename T>
static std::string mode_to_str(T mode) {
@@ -215,6 +219,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str());
LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str());
LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str());
LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str());
LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str());
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
@@ -583,12 +592,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
if (caps.supports_parallel_tool_calls) {
check_per_call_markers();
}
LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET);
extract_function_markers();
LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET);
if (format.mode == tool_format::TAG_WITH_TAGGED) {
analyze_arguments();
}
extract_argument_separator();
extract_args_markers();
LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET);
extract_call_id_markers();
}
}
@@ -979,8 +991,6 @@ void analyze_tools::extract_function_markers() {
}
void analyze_tools::analyze_arguments() {
LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET);
extract_argument_name_markers();
extract_argument_value_markers();
}
@@ -1189,7 +1199,7 @@ void analyze_tools::extract_args_markers() {
const auto & diff = comparison->diff;
if (format.mode != tool_format::JSON_NATIVE) {
if (format.mode == tool_format::JSON_NATIVE) {
std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end;
// these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones
@@ -1211,6 +1221,10 @@ void analyze_tools::extract_args_markers() {
if (find_fun != std::string::npos) {
args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size());
}
size_t find_call_id = args_start.find(CALL_ID_001);
if (find_call_id != std::string::npos) {
args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size());
}
arguments.start = args_start;
arguments.end = args_end;
}
@@ -1250,8 +1264,8 @@ void analyze_tools::extract_call_id_markers() {
return;
}
std::string id_value_1 = "call00001";
std::string id_value_2 = "call99999";
std::string id_value_1 = CALL_ID_001;
std::string id_value_2 = CALL_ID_999;
size_t common_id_prefix_len = 0;
for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) {
@@ -1350,6 +1364,14 @@ void analyze_tools::extract_call_id_markers() {
call_id.suffix = find_first_marker(before_func);
}
if (call_id.prefix == arguments.end) {
call_id.prefix = "";
}
if (call_id.suffix == arguments.start) {
call_id.suffix = "";
}
// When call_id is detected, per_call_end may have been incorrectly set to include
// the call_id_suffix and sample args. Clear it if it starts with call_id_suffix.
if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() &&