common : implement new jinja template engine (#18462)

* jinja vm

* lexer

* add vm types

* demo

* clean up

* parser ok

* binary_expression::execute

* shadow naming

* bin ops works!

* fix map object

* add string builtins

* add more builtins

* wip

* use mk_val

* eval with is_user_input

* render gemma tmpl ok

* track input string even after transformations

* support bound functions

* keyword arguments and slicing array

* use shared_ptr for values

* add mk_stmt

* allow print source on exception

* fix negate test

* testing more templates

* mostly works

* add filter_statement

* allow func to access ctx

* add jinja-value.cpp

* impl global_from_json

* a lot of fixes

* more tests

* more fix, more tests

* more fixes

* rm workarounds

* demo: type inference

* add placeholder for tojson

* improve function args handling

* rm type inference

* no more std::regex

* trailing spaces

* make testing more flexible

* make output a bit cleaner

* (wip) redirect minja calls

* test: add --output

* fix crash on macro kwargs

* add minimal caps system

* add some workarounds

* rm caps_apply_workarounds

* get rid of preprocessing

* more fixes

* fix test-chat-template

* move test-chat-jinja into test-chat-template

* rm test-chat-jinja from cmake

* test-chat-template: use common

* fix build

* fix build (2)

* rename vm --> interpreter

* improve error reporting

* correct lstrip behavior

* add tojson

* more fixes

* disable tests for COMMON_CHAT_FORMAT_GENERIC

* make sure tojson output correct order

* add object.length

* fully functional selectattr / rejectattr

* improve error reporting

* more builtins added, more fixes

* create jinja rendering tests

* fix testing.h path

* adjust whitespace rules

* more fixes

* temporary disable test for ibm-granite

* r/lstrip behavior matched with hf.js

* minimax, glm4.5 ok

* add append and pop

* kimi-k2 ok

* test-chat passed

* fix lstrip_block

* add more jinja tests

* cast to unsigned char

* allow dict key to be numeric

* nemotron: rm windows newline

* tests ok

* fix test

* rename interpreter --> runtime

* fix build

* add more checks

* bring back generic format support

* fix Apertus

* [json.exception.out_of_range.403] key 'content' not found

* rm generic test

* refactor input marking

* add docs

* fix windows build

* clarify error message

* improved tests

* split/rsplit with maxsplit

* non-inverse maxsplit

forgot to change after simplifying

* implement separators for tojson and fix indent

* i like to move it move it

* rename null --> none

* token::eof

* some nits + comments

* add exception classes for lexer and parser

* null -> none

* rename global -> env

* rm minja

* update docs

* docs: add input marking caveats

* implement missing jinja-tests functions

* oops

* support trim filter with args, remove bogus to_json reference

* numerous argument fixes

* updated tests

* implement optional strip chars parameter

* use new chars parameter

* float filter also has default

* always leave at least one decimal in float string

* jinja : static analysis + header cleanup + minor fixes

* add fuzz test

* add string.cpp

* fix chat_template_kwargs

* nits

* fix build

* revert

* unrevert

sorry :)

* add fuzz func_args, refactor to be safer

* fix array.map()

* loosen ensure_vals max count condition, add not impl for map(int)

* hopefully fix windows

* check if empty first

* normalize newlines

---------

Co-authored-by: Alde Rojas <hello@alde.dev>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Xuan-Son Nguyen (committed by GitHub)
Date: 2026-01-16 11:22:06 +01:00
Commit: c15395f73c (parent: aa1dc3770a)
30 changed files with 7159 additions and 3926 deletions
@@ -7,8 +7,13 @@
#include "log.h"
#include "regex-partial.h"
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>
// #include <minja/chat-template.hpp>
// #include <minja/minja.hpp>
#include "jinja/parser.h"
#include "jinja/value.h"
#include "jinja/runtime.h"
#include "jinja/caps.h"
#include <algorithm>
#include <cstdio>
@@ -135,7 +140,68 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
return diffs;
}
typedef minja::chat_template common_chat_template;
using chat_template_caps = jinja::caps;
struct common_chat_template {
jinja::program prog;
std::string bos_tok;
std::string eos_tok;
std::string src;
chat_template_caps caps;
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(src);
this->prog = jinja::parse_from_tokens(lexer_res);
this->src = lexer_res.source;
this->bos_tok = bos_token;
this->eos_tok = eos_token;
this->caps = jinja::caps_get(prog);
// LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str());
}
const std::string & source() const { return src; }
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }
// TODO: this is ugly, refactor it somehow
json add_system(const json & messages, const std::string & system_prompt) const {
GGML_ASSERT(messages.is_array());
auto msgs_copy = messages;
if (!caps.supports_system_role) {
if (msgs_copy.empty()) {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "user"},
{"content", system_prompt}
});
} else {
auto & first_msg = msgs_copy[0];
if (!first_msg.contains("content")) {
first_msg["content"] = "";
}
first_msg["content"] = system_prompt + "\n\n"
+ first_msg["content"].get<std::string>();
}
} else {
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
msgs_copy.insert(msgs_copy.begin(), json{
{"role", "system"},
{"content", system_prompt}
});
} else if (msgs_copy[0].at("role") == "system") {
msgs_copy[0]["content"] = system_prompt;
}
}
return msgs_copy;
}
chat_template_caps original_caps() const {
return caps;
}
};
struct common_chat_templates {
bool add_bos;
@@ -161,6 +227,7 @@ struct templates_params {
bool add_bos;
bool add_eos;
bool is_inference = true;
bool mark_input = true; // whether to mark input strings in the jinja context
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -627,14 +694,16 @@ common_chat_templates_ptr common_chat_templates_init(
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
LOG_ERR("%s: error: %s\n", __func__, e.what());
LOG_ERR("%s: failed to initialize chat template\n", __func__);
LOG_ERR("%s: please consider disabling jinja via --no-jinja, or using another chat template\n", __func__);
throw e;
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
@@ -739,27 +808,43 @@ static std::string apply(
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt)
{
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
if (tools_override) {
tmpl_inputs.tools = *tools_override;
} else {
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
}
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
tmpl_inputs.extra_context = inputs.extra_context;
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
if (additional_context) {
tmpl_inputs.extra_context.merge_patch(*additional_context);
}
// TODO: add flag to control date/time, if only for testing purposes.
// tmpl_inputs.now = std::chrono::system_clock::now();
jinja::context ctx(tmpl.source());
minja::chat_template_options tmpl_opts;
// To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
nlohmann::ordered_json inp = nlohmann::ordered_json{
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
{"bos_token", tmpl.bos_token()},
{"eos_token", tmpl.eos_token()},
};
if (inputs.extra_context.is_object()) {
// TODO: do we need to merge, or replacing is fine?
for (const auto & [k, v] : inputs.extra_context.items()) {
inp[k] = v;
}
}
if (additional_context.has_value()) {
// TODO: merge properly instead of overwriting (matching old behavior)
for (const auto & [k, v] : additional_context->items()) {
inp[k] = v;
}
}
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp["tools"].is_null()) {
inp["tools"] = json::array();
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
// render
jinja::runtime runtime(ctx);
const jinja::value results = runtime.execute(tmpl.prog);
auto parts = runtime.gather_string_parts(results);
std::string result = parts->as_string().str();
// TODO: improve this later
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
@@ -846,10 +931,17 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
builder.add_schema("root", schema);
});
auto tweaked_messages = common_chat_template::add_system(
auto tweaked_messages = tmpl.add_system(
inputs.messages,
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
// ensure all messages have a "content" field
for (auto & message : tweaked_messages) {
if (!message.contains("content") || message["content"].is_null()) {
message["content"] = "";
}
}
data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
data.format = COMMON_CHAT_FORMAT_GENERIC;
return data;
@@ -1364,7 +1456,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json {
{"date_string", format_time(inputs.now, "%d %b %Y")},
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
{"builtin_tools", builtin_tools},
});
return data;
}
@@ -2669,6 +2761,107 @@ static common_chat_params common_chat_params_init_seed_oss(
return data;
}
// various workarounds for known issues with certain templates or model behaviors
// TODO @ngxson : improve this (how?)
namespace workaround {
// if first message is system and template does not support it, merge it with next message
static void system_message_not_supported(json & messages) {
if (!messages.empty() && messages.front().at("role") == "system") {
if (messages.size() > 1) {
LOG_DBG("Merging system prompt into next message\n");
auto & first_msg = messages.front();
auto & second_msg = messages[1];
second_msg["content"] = first_msg.at("content").get<std::string>()
+ "\n" + second_msg.at("content").get<std::string>();
messages.erase(messages.begin());
} else {
LOG_WRN("Removing system prompt due to template not supporting system role\n");
messages.erase(messages.begin());
}
}
}
static void func_args_not_string(json & messages) {
GGML_ASSERT(messages.is_array());
for (auto & message : messages) {
if (message.contains("tool_calls")) {
for (auto & tool_call : message["tool_calls"]) {
if (tool_call.contains("function") && tool_call["function"].contains("arguments")) {
auto & args = tool_call["function"]["arguments"];
if (args.is_string()) {
try {
args = json::parse(args.get<std::string>());
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse tool call arguments as JSON: " + std::string(e.what()));
}
}
}
}
}
}
}
static void move_tool_calls_to_content(json & messages, int indent_spaces = 2) {
GGML_ASSERT(messages.is_array());
for (auto & message : messages) {
if (message.contains("tool_calls")) {
auto tool_calls_new = json{
{"tool_calls", message.at("tool_calls")}
};
message.erase("tool_calls");
auto content = message.at("content");
std::string content_new = content.is_null() ? "" : content.get<std::string>();
message["content"] = content_new + tool_calls_new.dump(indent_spaces, ' ', false, json::error_handler_t::replace);
}
}
}
// TODO @ngxson : we may remove support for generic schema in the future
static void use_generic_schema(json & messages) {
GGML_ASSERT(messages.is_array());
for (auto & message : messages) {
if (message.contains("tool_calls") && message.at("tool_calls").is_array()) {
auto & tool_calls = message.at("tool_calls");
for (auto & tool_call : tool_calls) {
if (tool_call.contains("type") && tool_call.at("type") == "function" &&
tool_call.contains("function") && tool_call.at("function").is_object()) {
// Copy values before erasing to avoid use-after-free
json name_value;
json arguments_value;
json id_value;
const auto & function = tool_call.at("function");
if (function.contains("name")) {
name_value = function.at("name");
}
if (function.contains("arguments")) {
arguments_value = function.at("arguments");
}
if (tool_call.contains("id")) {
id_value = tool_call.at("id");
}
// Now safely erase and assign in the correct order
tool_call.erase("type");
tool_call.erase("function");
tool_call.erase("id");
// Reassign in desired order: name, arguments, id
if (!name_value.is_null()) {
tool_call["name"] = name_value;
}
if (!arguments_value.is_null()) {
tool_call["arguments"] = arguments_value;
}
if (!id_value.is_null()) {
tool_call["id"] = id_value;
}
}
}
}
}
}
} // namespace workaround
static common_chat_params common_chat_templates_apply_jinja(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
@@ -2690,6 +2883,10 @@ static common_chat_params common_chat_templates_apply_jinja(
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (!tmpl.original_caps().supports_system_role) {
workaround::system_message_not_supported(params.messages);
}
params.extra_context = json::object();
for (auto el : inputs.chat_template_kwargs) {
params.extra_context[el.first] = json::parse(el.second);
@@ -2728,11 +2925,15 @@ static common_chat_params common_chat_templates_apply_jinja(
// Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
workaround::func_args_not_string(params.messages);
return common_chat_params_init_command_r7b(tmpl, params);
}
// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
workaround::func_args_not_string(params.messages);
workaround::use_generic_schema(params.messages);
workaround::move_tool_calls_to_content(params.messages);
return common_chat_params_init_granite(tmpl, params);
}
@@ -2741,6 +2942,7 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("<arg_key>") != std::string::npos &&
src.find("<arg_value>") != std::string::npos &&
params.json_schema.is_null()) {
workaround::func_args_not_string(params.messages);
return common_chat_params_init_glm_4_5(tmpl, params);
}
@@ -2752,6 +2954,7 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("<function=") != std::string::npos &&
src.find("<parameters>") != std::string::npos &&
src.find("<parameter=") != std::string::npos) {
workaround::func_args_not_string(params.messages);
// Nemotron 3 Nano 30B A3B
if (src.find("<think>") != std::string::npos) {
return common_chat_params_init_nemotron_v3(tmpl, params);
@@ -2788,6 +2991,7 @@ static common_chat_params common_chat_templates_apply_jinja(
// Seed-OSS
if (src.find("<seed:think>") != std::string::npos) {
workaround::func_args_not_string(params.messages);
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}
@@ -2809,6 +3013,7 @@ static common_chat_params common_chat_templates_apply_jinja(
// MiniMax-M2 format detection
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
workaround::func_args_not_string(params.messages);
return common_chat_params_init_minimax_m2(tmpl, params);
}
@@ -2855,6 +3060,7 @@ static common_chat_params common_chat_templates_apply_jinja(
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
workaround::func_args_not_string(params.messages);
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
@@ -2883,10 +3089,14 @@ static common_chat_params common_chat_templates_apply_jinja(
// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
workaround::func_args_not_string(params.messages);
return common_chat_params_init_mistral_nemo(tmpl, params);
}
// Generic fallback
workaround::func_args_not_string(params.messages);
workaround::use_generic_schema(params.messages);
workaround::move_tool_calls_to_content(params.messages);
return common_chat_params_init_generic(tmpl, params);
}
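
For reference, the end-to-end rendering flow that replaces the old minja path can be read off the new common_chat_template constructor and apply() above: tokenize the template source, parse it into a jinja::program, expose the request JSON as globals in a jinja::context, then execute the program with jinja::runtime and join the emitted string parts. Below is a minimal sketch of that flow, assembled only from the calls visible in this diff; the render_template helper name and the explicit json include are illustrative assumptions, not part of the commit, and exact signatures may differ.

#include "jinja/parser.h"
#include "jinja/value.h"
#include "jinja/runtime.h"

#include <nlohmann/json.hpp> // assumption: pulled in transitively in the real file
#include <string>

// Sketch: lex + parse a template once, then render it against a JSON context.
static std::string render_template(const std::string & tmpl_src, const nlohmann::ordered_json & inp) {
    // lex + parse the template source into an executable program
    jinja::lexer lexer;
    auto lexer_res = lexer.tokenize(tmpl_src);
    jinja::program prog = jinja::parse_from_tokens(lexer_res);

    // build an execution context and expose the JSON input as globals
    // (mark_input = true tags user-provided strings, mirroring templates_params.mark_input)
    jinja::context ctx(lexer_res.source);
    jinja::global_from_json(ctx, inp, /* mark_input = */ true);

    // execute the program and join the emitted string parts into the final prompt
    jinja::runtime runtime(ctx);
    const jinja::value results = runtime.execute(prog);
    auto parts = runtime.gather_string_parts(results);
    return parts->as_string().str();
}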