From e1a9a6dcbefccb4b864d9385ce8494f2a7f2ffcd Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 15 Apr 2026 10:51:50 +0200 Subject: [PATCH] autoparser: support case of JSON_NATIVE with per-call markers (test case: Reka-Edge) (#21892) --- common/chat-auto-parser-generator.cpp | 17 ++- common/chat-auto-parser.h | 8 +- common/chat-diff-analyzer.cpp | 48 +++++++- common/chat-peg-parser.cpp | 4 +- models/templates/Reka-Edge.jinja | 161 ++++++++++++++++++++++++++ tests/test-chat.cpp | 90 +++++++++++++- 6 files changed, 314 insertions(+), 14 deletions(-) create mode 100644 models/templates/Reka-Edge.jinja diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 3eb1fa9a9..c6431b898 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -198,10 +198,19 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont args_field = format.function_field + "." + args_field; } - auto tools_parser = p.standard_json_tools( - format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls, - inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, - format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + auto tools_parser = p.eps(); + if (format.section_start.empty() && !format.per_call_start.empty()) { + auto single_tool_parser = p.standard_json_tools( + format.per_call_start, format.per_call_end, inputs.tools, inputs.parallel_tool_calls, + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + tools_parser = p.trigger_rule("tool-calls", p.one_or_more(single_tool_parser + p.space())); + } else { + tools_parser = p.standard_json_tools( + format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls, + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + } // Handle content wrappers if present if (ctx.content && ctx.content->is_always_wrapped()) { diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 99dd9f063..6c5474097 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -308,19 +308,23 @@ struct analyze_tools : analyze_base { private: // Extract tool calling 'haystack' for further analysis and delegate further analysis based on format - void analyze_tool_calls(const analyze_reasoning & reasoning); + void analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls); // Analyze format based on position of function and argument name in needle void analyze_tool_call_format(const std::string & haystack, const std::string & fun_name_needle, const std::string & arg_name_needle, - const analyze_reasoning & reasoning); + const analyze_reasoning & reasoning, + bool supports_parallel_tool_calls); // Analyze specifics of JSON native format (entire tool call is a JSON object) void analyze_tool_call_format_json_native(const std::string & clean_haystack, const std::string & fun_name_needle, const std::string & arg_name_needle); + // Check if parallel calls in JSON native format array wrapped or tag wrapped + void analyze_json_native_parallel_calls(); + // Analyze specifics of non-JSON native format (tags for function name or for function name and arguments) void analyze_tool_call_format_non_json(const std::string & clean_haystack, const std::string & fun_name_needle); diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index fa3e36809..2f0bd14af 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -558,7 +558,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl, : analyze_base(tmpl) { LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET); - analyze_tool_calls(reasoning); + analyze_tool_calls(reasoning, caps.supports_parallel_tool_calls); if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) { if (caps.supports_parallel_tool_calls) { @@ -577,7 +577,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl, } } -void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) { +void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls) { json assistant_no_tools = json{ { "role", "assistant" }, { "content", ASSISTANT_MSG } @@ -611,13 +611,14 @@ void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) { return; } - analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning); + analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning, supports_parallel_tool_calls); } void analyze_tools::analyze_tool_call_format(const std::string & haystack, const std::string & fun_name_needle, const std::string & arg_name_needle, - const analyze_reasoning & reasoning) { + const analyze_reasoning & reasoning, + bool supports_parallel_tool_calls) { if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) { return; } @@ -660,6 +661,9 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack, if (format.mode == tool_format::JSON_NATIVE) { analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle); + if (supports_parallel_tool_calls) { + analyze_json_native_parallel_calls(); + } } else { analyze_tool_call_format_non_json(clean_haystack, fun_name_needle); } @@ -668,6 +672,42 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack, format.per_call_end = trim_whitespace(format.per_call_end); } +void analyze_tools::analyze_json_native_parallel_calls() { + json assistant_one_tool = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_two_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call, second_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_one_tool }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + std::string & second_call = comparison->diff.right; + if (!format.section_start.empty() && second_call.find(format.section_start) != std::string::npos) { + format.per_call_start = format.section_start; + format.per_call_end = format.section_end; + format.section_start.clear(); + format.section_end.clear(); + } +} + void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack, const std::string & fun_name_needle, const std::string & arg_name_needle) { diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 624dee22f..56eb567df 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -676,7 +676,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object(); auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() + - literal("\"") + tool_name(literal(name)) + literal("\""); + atomic(literal("\"") + tool_name(literal(name)) + literal("\"")); auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() + tool_args(schema(json(), "tool-" + name + "-schema", params)); @@ -744,7 +744,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object(); auto tool_name_ = name_key_parser + space() + literal(":") + space() + - literal("\"") + tool_name(literal(name)) + literal("\""); + atomic(literal("\"") + tool_name(literal(name)) + literal("\"")); auto tool_args_ = args_key_parser + space() + literal(":") + space() + tool_args(schema(json(), "tool-" + name + "-schema", params)); diff --git a/models/templates/Reka-Edge.jinja b/models/templates/Reka-Edge.jinja new file mode 100644 index 000000000..76bb21f8a --- /dev/null +++ b/models/templates/Reka-Edge.jinja @@ -0,0 +1,161 @@ +{%- macro render_content(content, num_img_tokens, num_video_frames) -%} + {%- if content is string -%} + {{- content -}} + {%- elif content is sequence -%} + {%- set ns = namespace(out="", prev_was_text=false) -%} + {%- for item in content -%} + {%- set item_type = item.get("type") -%} + {%- if item_type == "text" or item.get("text") is not none -%} + {%- set text = item.get("text", "") -%} + {%- if text -%} + {%- if ns.prev_was_text -%} + {%- set ns.out = ns.out ~ " " -%} + {%- endif -%} + {%- set ns.out = ns.out ~ text -%} + {%- endif -%} + {%- set ns.prev_was_text = text != "" -%} + {%- elif item_type in ["image", "image_url"] or item.get("image") is not none or item.get("image_url") is not none -%} + {%- set ns.out = ns.out ~ "" ~ ("" * num_img_tokens) ~ "" -%} + {%- set ns.prev_was_text = false -%} + {%- elif item_type in ["video", "video_url"] or item.get("video") is not none or item.get("video_url") is not none -%} + {%- set repeat_tokens = num_img_tokens * num_video_frames -%} + {%- set ns.out = ns.out ~ "" -%} + {%- set ns.prev_was_text = false -%} + {%- endif -%} + {%- endfor -%} + {{- ns.out -}} + {%- endif -%} +{%- endmacro -%} +{%- set ns = namespace(out="", last_query_index=messages|length - 1) -%} +{%- for msg in messages[::-1] -%} + {%- set idx = messages|length - 1 - loop.index0 -%} + {%- if msg.get("role") == "user" -%} + {%- set content = msg.get("content", "") -%} + {%- if not (content is string and content.startswith("") and content.endswith("")) -%} + {%- set ns.last_query_index = idx -%} + {%- break -%} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- set last_query_index = ns.last_query_index -%} +{%- set num_img_tokens = num_img_tokens | default(64, true) | int -%} +{%- set num_video_frames = num_video_frames | default(6, true) | int -%} +{%- set start_idx = 0 -%} +{%- set system_text = "" -%} +{%- if messages|length > 0 and messages[0].get("role") in ["system", "developer"] -%} + {%- set system_text = render_content(messages[0].get("content", ""), num_img_tokens, num_video_frames) -%} + {%- set start_idx = 1 -%} +{%- endif -%} +{%- if tools or system_text -%} + {%- set preamble_ns = namespace(text="") -%} + {%- if system_text -%} + {%- set preamble_ns.text = "system: " ~ system_text -%} + {%- endif -%} + {%- if tools -%} + {%- if preamble_ns.text -%} + {%- set preamble_ns.text = preamble_ns.text ~ "\n\n" -%} + {%- else -%} + {%- set preamble_ns.text = "system: " -%} + {%- endif -%} + {%- set preamble_ns.text = preamble_ns.text + ~ "# Tools\n\n" + ~ "You may call one or more functions to assist with the user query.\n\n" + ~ "You are provided with function signatures within XML tags:\n" + ~ "" -%} + {%- for tool in tools -%} + {%- set preamble_ns.text = preamble_ns.text ~ "\n" ~ (tool | tojson(ensure_ascii=True)) -%} + {%- endfor -%} + {%- set preamble_ns.text = preamble_ns.text + ~ "\n\n\n" + ~ "For each function call, return a json object with function name and arguments " + ~ "within XML tags:\n" + ~ "\n{\"name\": , \"arguments\": }\n" -%} + {%- endif -%} + {%- set ns.out = ns.out ~ preamble_ns.text ~ "\n\n" -%} +{%- endif -%} +{%- for idx in range(start_idx, messages|length) -%} + {%- set message = messages[idx] -%} + {%- set role = message.get("role") -%} + {%- set content = message.get("content") -%} + {%- if role == "user" -%} + {%- set prefix_ns = namespace(value="human: ") -%} + {%- if content is sequence and content is not string -%} + {%- for item in content -%} + {%- if item.get("type") == "text" or item.get("text") is not none -%} + {%- set text = item.get("text", "") -%} + {%- if text -%} + {%- break -%} + {%- endif -%} + {%- elif item.get("type") in ["image", "image_url", "video", "video_url"] -%} + {%- set prefix_ns.value = "human:" -%} + {%- break -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set ns.out = ns.out ~ prefix_ns.value ~ render_content(content, num_img_tokens, num_video_frames) ~ "" -%} + {%- elif role == "assistant" -%} + {%- set tool_calls = message.get("tool_calls") -%} + {%- set content_text = render_content(content, num_img_tokens, num_video_frames) -%} + {%- set reasoning_text = "" -%} + {%- if message.get("reasoning_content") is string -%} + {%- set reasoning_text = message.get("reasoning_content") -%} + {%- elif "" in content_text -%} + {%- set reasoning_text = content_text.split("", 1)[0].rstrip("\n").split("")[-1].lstrip("\n") -%} + {%- set content_text = content_text.split("", 1)[1].lstrip("\n") -%} + {%- endif -%} + {%- set ns.out = ns.out ~ "assistant: " -%} + {%- set include_thinking = enable_thinking is true + and idx > last_query_index + and (idx == messages|length - 1 or reasoning_text) + -%} + {%- if include_thinking -%} + {%- set ns.out = ns.out ~ "\n" ~ (reasoning_text.strip() ) ~ "\n\n\n" -%} + {%- endif -%} + {%- set ns.out = ns.out ~ content_text -%} + {%- if tool_calls -%} + {%- if content_text and not ns.out.endswith("\n") -%} + {%- set ns.out = ns.out ~ "\n" -%} + {%- endif -%} + {%- for tool_call in tool_calls -%} + {%- if tool_call.get("function") is not none -%} + {%- set tool_call = tool_call.get("function") -%} + {%- endif -%} + {%- set arguments = tool_call.get("arguments", {}) -%} + {%- if arguments is string -%} + {%- set arguments_json = arguments -%} + {%- elif arguments is mapping -%} + {%- set arguments_json = arguments | tojson(ensure_ascii=True) -%} + {%- else -%} + {%- set arguments_json = arguments | tojson(ensure_ascii=True) -%} + {%- endif -%} + {%- set ns.out = ns.out + ~ "\n" + ~ "{\"name\": \"" ~ tool_call.get("name", "") ~ "\", \"arguments\": " + ~ arguments_json + ~ "}\n" -%} + {%- endfor -%} + {%- endif -%} + {%- if not (continue_final_message and idx == messages|length - 1) -%} + {%- set ns.out = ns.out ~ "\n\n" -%} + {%- endif -%} + {%- elif role == "tool" -%} + {%- if idx == start_idx or messages[idx - 1].get("role") != "tool" -%} + {%- set ns.out = ns.out ~ "human: " -%} + {%- endif -%} + {%- set response_text = render_content(content, num_img_tokens, num_video_frames) -%} + {%- set ns.out = ns.out ~ "\n" ~ response_text ~ "\n" -%} + {%- if idx == messages|length - 1 or messages[idx + 1].get("role") != "tool" -%} + {%- set ns.out = ns.out ~ "" -%} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt + and (messages|length == 0 or messages[-1].get("role") != "assistant") +-%} + {%- if enable_thinking is true -%} + {%- set ns.out = ns.out ~ "assistant: \n" -%} + {%- else -%} + {%- set ns.out = ns.out ~ "assistant:" -%} + {%- endif -%} +{%- endif -%} +{{- ns.out -}} \ No newline at end of file diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 8438a5eaf..3b8de5ce0 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -2164,7 +2164,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test( "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}" "") .tools({ special_function_tool }) .expect(message_assist_call) @@ -2172,7 +2172,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test( "Hello, world!\nWhat's up?\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}" "") .tools({ special_function_tool }) .expect(message_assist_call_content) @@ -3329,6 +3329,92 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // Reka-Edge tests - uses native JSON format with per-call wrapper + { + auto tst = peg_tester("models/templates/Reka-Edge.jinja", detailed_debug); + + // Basic content only + tst.test("Hello, world!\nWhat's up?").enable_thinking(false).expect(message_assist).run(); + + // Single tool call without reasoning + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with string argument + tst.test("\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}") + .enable_thinking(false) + .tools({ get_time_tool }) + .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) + .run(); + + // Tool call with reasoning (enable_thinking=true) + tst.test("I'm\nthinking\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + // Multiple tool calls (parallel) + tst.test( + "\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}" + "\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}" + ) + .enable_thinking(false) + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + // Tool call with reasoning and content + tst.test("I need to call a function" + "Let me check the time.\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ get_time_tool }) + .expect(message_with_reasoning_content_and_multiple_tool_calls( + "I need to call a function", "Let me check the time.", { { "get_time", "{\"city\":\"XYZCITY\"}" } } + )) + .run(); + + // Partial tool call (streaming) + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\":") + .tools({ special_function_tool }) + .enable_thinking(false) + .is_partial(true) + .expect(simple_assist_msg("", "", "special_function", "{\"arg1\": ")) + .run(); + + // Tool call with empty arguments + tst.test("\n{\"name\": \"empty_args\", \"arguments\": {}}") + .enable_thinking(false) + .tools({ empty_args_tool }) + .expect(simple_assist_msg("", "", "empty_args", "{}")) + .run(); + + // fake tool call marker in reasoning + tst.test( + "Let me think about \n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}} hmm" + "\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect_reasoning("Let me think about \n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}} hmm") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + }) + .run(); + } + + // Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format // Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|> {