common : implement new jinja template engine (#18462)

* jinja vm

* lexer

* add vm types

* demo

* clean up

* parser ok

* binary_expression::execute

* shadow naming

* bin ops works!

* fix map object

* add string builtins

* add more builtins

* wip

* use mk_val

* eval with is_user_input

* render gemma tmpl ok

* track input string even after transformations

* support binded functions

* keyword arguments and slicing array

* use shared_ptr for values

* add mk_stmt

* allow print source on exception

* fix negate test

* testing more templates

* mostly works

* add filter_statement

* allow func to access ctx

* add jinja-value.cpp

* impl global_from_json

* a lot of fixes

* more tests

* more fix, more tests

* more fixes

* rm workarounds

* demo: type inferrence

* add placeholder for tojson

* improve function args handling

* rm type inference

* no more std::regex

* trailing spaces

* make testing more flexible

* make output a bit cleaner

* (wip) redirect minja calls

* test: add --output

* fix crash on macro kwargs

* add minimal caps system

* add some workarounds

* rm caps_apply_workarounds

* get rid of preprocessing

* more fixes

* fix test-chat-template

* move test-chat-jinja into test-chat-template

* rm test-chat-jinja from cmake

* test-chat-template: use common

* fix build

* fix build (2)

* rename vm --> interpreter

* improve error reporting

* correct lstrip behavior

* add tojson

* more fixes

* disable tests for COMMON_CHAT_FORMAT_GENERIC

* make sure tojson output correct order

* add object.length

* fully functional selectattr / rejectattr

* improve error reporting

* more builtins added, more fixes

* create jinja rendering tests

* fix testing.h path

* adjust whitespace rules

* more fixes

* temporary disable test for ibm-granite

* r/lstrip behavior matched with hf.js

* minimax, glm4.5 ok

* add append and pop

* kimi-k2 ok

* test-chat passed

* fix lstrip_block

* add more jinja tests

* cast to unsigned char

* allow dict key to be numeric

* nemotron: rm windows newline

* tests ok

* fix test

* rename interpreter --> runtime

* fix build

* add more checks

* bring back generic format support

* fix Apertus

* [json.exception.out_of_range.403] key 'content' not found

* rm generic test

* refactor input marking

* add docs

* fix windows build

* clarify error message

* improved tests

* split/rsplit with maxsplit

* non-inverse maxsplit

forgot to change after simplifying

* implement separators for tojson and fix indent

* i like to move it move it

* rename null -- > none

* token::eof

* some nits + comments

* add exception classes for lexer and parser

* null -> none

* rename global -> env

* rm minja

* update docs

* docs: add input marking caveats

* imlement missing jinja-tests functions

* oops

* support trim filter with args, remove bogus to_json reference

* numerous argument fixes

* updated tests

* implement optional strip chars parameter

* use new chars parameter

* float filter also has default

* always leave at least one decimal in float string

* jinja : static analysis + header cleanup + minor fixes

* add fuzz test

* add string.cpp

* fix chat_template_kwargs

* nits

* fix build

* revert

* unrevert

sorry :)

* add fuzz func_args, refactor to be safer

* fix array.map()

* loosen ensure_vals max count condition, add not impl for map(int)

* hopefully fix windows

* check if empty first

* normalize newlines

---------

Co-authored-by: Alde Rojas <hello@alde.dev>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Xuan-Son Nguyen
2026-01-16 11:22:06 +01:00
committed by GitHub
parent aa1dc3770a
commit c15395f73c
30 changed files with 7159 additions and 3926 deletions
+288 -32
View File
@@ -2,6 +2,11 @@
#include <vector>
#include <sstream>
#include <regex>
#include <iostream>
#include <fstream>
#include <filesystem>
#include <nlohmann/json.hpp>
#undef NDEBUG
#include <cassert>
@@ -9,6 +14,152 @@
#include "llama.h"
#include "common.h"
#include "chat.h"
#include "jinja/runtime.h"
#include "jinja/parser.h"
#include "jinja/lexer.h"
#include "jinja/caps.h"
using json = nlohmann::ordered_json;
int main_automated_tests(void);
void run_multiple(std::string dir_path, bool stop_on_first_failure, json input, bool use_common = false);
void run_single(std::string contents, json input, bool use_common = false, const std::string & output_path = "");
std::string HELP = R"(
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
Options:
-h, --help Show this help message and exit.
--json <path> Path to the JSON input file.
--stop-on-first-fail Stop testing on the first failure (default: false).
--no-common Use direct Jinja engine instead of common chat templates (default: use common).
--output <path> Path to output results (only for single template runs).
If PATH_TO_TEMPLATE is a file, runs that single template.
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
If PATH_TO_TEMPLATE is omitted, runs automated tests (default CI mode).
)";
std::string DEFAULT_JSON = R"({
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
},
{
"role": "assistant",
"content": "I am fine, thank you!"
}
],
"bos_token": "<s>",
"eos_token": "</s>",
"tools": [],
"add_generation_prompt": true
})";
int main(int argc, char ** argv) {
std::vector<std::string> args(argv, argv + argc);
std::string tmpl_path;
std::string json_path;
std::string output_path;
bool stop_on_first_fail = false;
bool use_common = true;
for (size_t i = 1; i < args.size(); i++) {
if (args[i] == "--help" || args[i] == "-h") {
std::cout << HELP << "\n";
return 0;
} else if (args[i] == "--json" && i + 1 < args.size()) {
json_path = args[i + 1];
i++;
} else if (args[i] == "--stop-on-first-fail") {
stop_on_first_fail = true;
} else if (args[i] == "--output" && i + 1 < args.size()) {
output_path = args[i + 1];
i++;
} else if (args[i] == "--no-common") {
use_common = true;
} else if (tmpl_path.empty()) {
tmpl_path = args[i];
} else {
std::cerr << "Unknown argument: " << args[i] << "\n";
std::cout << HELP << "\n";
return 1;
}
}
if (tmpl_path.empty()) {
return main_automated_tests();
}
json input_json;
if (!json_path.empty()) {
std::ifstream json_file(json_path);
if (!json_file) {
std::cerr << "Error: Could not open JSON file: " << json_path << "\n";
return 1;
}
std::string content = std::string(
std::istreambuf_iterator<char>(json_file),
std::istreambuf_iterator<char>());
input_json = json::parse(content);
} else {
input_json = json::parse(DEFAULT_JSON);
}
std::filesystem::path p(tmpl_path);
if (std::filesystem::is_directory(p)) {
run_multiple(tmpl_path, stop_on_first_fail, input_json, use_common);
} else if (std::filesystem::is_regular_file(p)) {
std::ifstream infile(tmpl_path);
std::string contents = std::string(
std::istreambuf_iterator<char>(infile),
std::istreambuf_iterator<char>());
run_single(contents, input_json, use_common, output_path);
} else {
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
return 1;
}
return 0;
}
void run_multiple(std::string dir_path, bool stop_on_first_fail, json input, bool use_common) {
std::vector<std::string> failed_tests;
// list all files in models/templates/ and run each
size_t test_count = 0;
for (const auto & entry : std::filesystem::directory_iterator(dir_path)) {
// only process .jinja files
if (entry.path().extension() == ".jinja" && entry.is_regular_file()) {
test_count++;
std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n";
std::ifstream infile(entry.path());
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
try {
run_single(contents, input, use_common);
} catch (const std::exception & e) {
std::cout << "Exception: " << e.what() << "\n";
std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n";
failed_tests.push_back(entry.path().string());
if (stop_on_first_fail) {
break;
}
}
}
}
std::cout << "\n\n=== TEST SUMMARY ===\n";
std::cout << "Total tests run: " << test_count << "\n";
std::cout << "Total failed tests: " << failed_tests.size() << "\n";
for (const auto & test : failed_tests) {
std::cout << "FAILED TEST: " << test << "\n";
}
}
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
@@ -19,6 +170,105 @@ static std::string normalize_newlines(const std::string & s) {
#endif
}
static std::string format_using_common(
const std::string & template_str,
const std::string & bos_token,
const std::string & eos_token,
std::vector<common_chat_msg> & messages,
std::vector<common_chat_tool> tools = {}) {
auto tmpls = common_chat_templates_init(/* model= */ nullptr, template_str, bos_token, eos_token);
common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
inputs.tools = tools;
inputs.add_generation_prompt = true;
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
output = normalize_newlines(output);
return output;
}
// skip libcommon, use direct jinja engine
static jinja::value_string format_using_direct_engine(
const std::string & template_str,
json & input) {
// lexing
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(template_str);
// compile to AST
jinja::program ast = jinja::parse_from_tokens(lexer_res);
// check caps for workarounds
jinja::caps_get(ast);
std::cout << "\n=== RUN ===\n";
jinja::context ctx(template_str);
jinja::global_from_json(ctx, input, true);
jinja::runtime runtime(ctx);
const jinja::value results = runtime.execute(ast);
auto parts = runtime.gather_string_parts(results);
std::cout << "\n=== RESULTS ===\n";
for (const auto & part : parts->as_string().parts) {
std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n";
}
return parts;
}
void run_single(std::string contents, json input, bool use_common, const std::string & output_path) {
jinja::enable_debug(true);
jinja::value_string output_parts;
if (use_common) {
std::string bos_token = "<s>";
std::string eos_token = "</s>";
if (input.contains("bos_token")) {
bos_token = input["bos_token"].get<std::string>();
}
if (input.contains("eos_token")) {
eos_token = input["eos_token"].get<std::string>();
}
nlohmann::ordered_json msgs_json = input["messages"];
nlohmann::ordered_json tools_json = input["tools"];
auto messages = common_chat_msgs_parse_oaicompat(msgs_json);
auto tools = common_chat_tools_parse_oaicompat(tools_json);
auto output = format_using_common(contents, bos_token, eos_token, messages, tools);
std::cout << "\n=== OUTPUT ===\n";
std::cout << output << "\n";
output_parts = jinja::mk_val<jinja::value_string>(output);
} else {
output_parts = format_using_direct_engine(contents, input);
std::cout << "\n=== OUTPUT ===\n";
std::cout << output_parts->as_string().str() << "\n";
}
if (!output_path.empty()) {
std::ofstream outfile(output_path);
if (!outfile) {
throw std::runtime_error("Could not open output file: " + output_path);
}
outfile << output_parts->as_string().str();
outfile.close();
std::cout << "\n=== OUTPUT WRITTEN TO " << output_path << " ===\n";
}
}
//
// Automated tests for chat templates
//
#define U8C(x) (const char*)(u8##x)
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
@@ -28,7 +278,9 @@ static common_chat_msg simple_msg(const std::string & role, const std::string &
return msg;
}
int main(void) {
int main_automated_tests(void) {
// jinja::enable_debug(true);
std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
@@ -61,8 +313,8 @@ int main(void) {
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "<s>[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .bos_token= */ "<s>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
@@ -177,7 +429,7 @@ int main(void) {
/* .name= */ "ChatGLM3",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
/* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
/* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
/* .expected_output_jinja= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
},
{
/* .name= */ "ChatGLM4",
@@ -221,7 +473,7 @@ int main(void) {
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
/* .expected_output_jinja= */ "",
/* .expected_output_jinja= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
@@ -308,9 +560,9 @@ int main(void) {
assert(res > 0);
supported_tmpl.resize(res);
res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
printf("Built-in chat templates:\n");
std::cout << "Built-in chat templates:\n";
for (auto tmpl : supported_tmpl) {
printf(" %s\n", tmpl);
std::cout << " " << tmpl << "\n";
}
// test invalid chat template
@@ -319,7 +571,7 @@ int main(void) {
const auto add_generation_prompt = true;
for (const auto & test_case : test_cases) {
printf("\n\n=== %s ===\n\n", test_case.name.c_str());
std::cout << "\n\n=== " << test_case.name << " ===\n\n";
formatted_chat.resize(1024);
res = llama_chat_apply_template(
test_case.template_str.c_str(),
@@ -332,10 +584,10 @@ int main(void) {
formatted_chat.resize(res);
std::string output(formatted_chat.data(), formatted_chat.size());
if (output != test_case.expected_output) {
printf("Expected:\n%s\n", test_case.expected_output.c_str());
printf("-------------------------\n");
printf("Actual:\n%s\n", output.c_str());
fflush(stdout);
std::cout << "Expected:\n" << test_case.expected_output << "\n";
std::cout << "-------------------------\n";
std::cout << "Actual:\n" << output << "\n";
std::cout.flush();
assert(output == test_case.expected_output);
}
}
@@ -348,39 +600,41 @@ int main(void) {
if (!test_case.supported_with_jinja) {
continue;
}
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
std::cout << "\n\n=== " << test_case.name << " (jinja) ===\n\n";
try {
auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
inputs.add_generation_prompt = add_generation_prompt;
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
output = normalize_newlines(output);
auto output = format_using_common(
test_case.template_str,
test_case.bos_token,
test_case.eos_token,
messages);
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str());
printf("-------------------------\n");
printf("Actual:\n%s\n", output.c_str());
fflush(stdout);
std::cout << "Template:```\n" << test_case.template_str << "\n```";
std::cout << "-------------------------\n";
std::cout << "Expected:```\n" << expected_output << "\n```";
std::cout << "-------------------------\n";
std::cout << "Actual:```\n" << output << "\n```";
std::cout.flush();
assert(output == expected_output);
}
} catch (const std::exception & e) {
printf("ERROR: %s\n", e.what());
std::cerr << "ERROR: " << e.what() << "\n";
assert(false);
}
}
// TODO: llama_chat_format_single will be deprecated, remove these tests later
// test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
std::vector<common_chat_msg> chat2;
auto sys_msg = simple_msg("system", "You are a helpful assistant");
auto fmt_sys = [&](std::string tmpl_str) {
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
std::cout << "fmt_sys(" << tmpl_str << ") : " << output << "\n";
std::cout << "-------------------------\n";
return output;
};
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
@@ -397,7 +651,7 @@ int main(void) {
// test llama_chat_format_single for user message
printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
std::cout << "\n\n=== llama_chat_format_single (user message) ===\n\n";
chat2.push_back(simple_msg("system", "You are a helpful assistant"));
chat2.push_back(simple_msg("user", "Hello"));
chat2.push_back(simple_msg("assistant", "I am assistant"));
@@ -406,8 +660,8 @@ int main(void) {
auto fmt_single = [&](const std::string & tmpl_str) {
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
std::cout << "fmt_single(" << tmpl_str << ") : " << output << "\n";
std::cout << "-------------------------\n";
return output;
};
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
@@ -419,7 +673,9 @@ int main(void) {
assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
// assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
std::cout << "\nOK: All tests passed successfully.\n";
return 0;
}