common : implement new jinja template engine (#18462)
* jinja vm * lexer * add vm types * demo * clean up * parser ok * binary_expression::execute * shadow naming * bin ops works! * fix map object * add string builtins * add more builtins * wip * use mk_val * eval with is_user_input * render gemma tmpl ok * track input string even after transformations * support binded functions * keyword arguments and slicing array * use shared_ptr for values * add mk_stmt * allow print source on exception * fix negate test * testing more templates * mostly works * add filter_statement * allow func to access ctx * add jinja-value.cpp * impl global_from_json * a lot of fixes * more tests * more fix, more tests * more fixes * rm workarounds * demo: type inferrence * add placeholder for tojson * improve function args handling * rm type inference * no more std::regex * trailing spaces * make testing more flexible * make output a bit cleaner * (wip) redirect minja calls * test: add --output * fix crash on macro kwargs * add minimal caps system * add some workarounds * rm caps_apply_workarounds * get rid of preprocessing * more fixes * fix test-chat-template * move test-chat-jinja into test-chat-template * rm test-chat-jinja from cmake * test-chat-template: use common * fix build * fix build (2) * rename vm --> interpreter * improve error reporting * correct lstrip behavior * add tojson * more fixes * disable tests for COMMON_CHAT_FORMAT_GENERIC * make sure tojson output correct order * add object.length * fully functional selectattr / rejectattr * improve error reporting * more builtins added, more fixes * create jinja rendering tests * fix testing.h path * adjust whitespace rules * more fixes * temporary disable test for ibm-granite * r/lstrip behavior matched with hf.js * minimax, glm4.5 ok * add append and pop * kimi-k2 ok * test-chat passed * fix lstrip_block * add more jinja tests * cast to unsigned char * allow dict key to be numeric * nemotron: rm windows newline * tests ok * fix test * rename interpreter --> runtime * fix build * 
add more checks * bring back generic format support * fix Apertus * [json.exception.out_of_range.403] key 'content' not found * rm generic test * refactor input marking * add docs * fix windows build * clarify error message * improved tests * split/rsplit with maxsplit * non-inverse maxsplit forgot to change after simplifying * implement separators for tojson and fix indent * i like to move it move it * rename null -- > none * token::eof * some nits + comments * add exception classes for lexer and parser * null -> none * rename global -> env * rm minja * update docs * docs: add input marking caveats * imlement missing jinja-tests functions * oops * support trim filter with args, remove bogus to_json reference * numerous argument fixes * updated tests * implement optional strip chars parameter * use new chars parameter * float filter also has default * always leave at least one decimal in float string * jinja : static analysis + header cleanup + minor fixes * add fuzz test * add string.cpp * fix chat_template_kwargs * nits * fix build * revert * unrevert sorry :) * add fuzz func_args, refactor to be safer * fix array.map() * loosen ensure_vals max count condition, add not impl for map(int) * hopefully fix windows * check if empty first * normalize newlines --------- Co-authored-by: Alde Rojas <hello@alde.dev> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
+288
-32
@@ -2,6 +2,11 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
@@ -9,6 +14,152 @@
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "chat.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/lexer.h"
|
||||
#include "jinja/caps.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
int main_automated_tests(void);
|
||||
|
||||
void run_multiple(std::string dir_path, bool stop_on_first_failure, json input, bool use_common = false);
|
||||
void run_single(std::string contents, json input, bool use_common = false, const std::string & output_path = "");
|
||||
|
||||
|
||||
|
||||
// Usage text printed for -h/--help and on argument errors.
// static const: these are file-local, read-only globals (previously mutable
// with external linkage).
static const std::string HELP = R"(
Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE
Options:
  -h, --help            Show this help message and exit.
  --json <path>         Path to the JSON input file.
  --stop-on-first-fail  Stop testing on the first failure (default: false).
  --no-common           Use direct Jinja engine instead of common chat templates (default: use common).
  --output <path>       Path to output results (only for single template runs).
If PATH_TO_TEMPLATE is a file, runs that single template.
If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory.
If PATH_TO_TEMPLATE is omitted, runs automated tests (default CI mode).
)";

// Default rendering input used when --json is not given: a minimal two-turn
// conversation with BOS/EOS tokens, no tools, and a generation prompt.
static const std::string DEFAULT_JSON = R"({
    "messages": [
        {
            "role": "user",
            "content": "Hello, how are you?"
        },
        {
            "role": "assistant",
            "content": "I am fine, thank you!"
        }
    ],
    "bos_token": "<s>",
    "eos_token": "</s>",
    "tools": [],
    "add_generation_prompt": true
})";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::vector<std::string> args(argv, argv + argc);
|
||||
|
||||
std::string tmpl_path;
|
||||
std::string json_path;
|
||||
std::string output_path;
|
||||
bool stop_on_first_fail = false;
|
||||
bool use_common = true;
|
||||
|
||||
for (size_t i = 1; i < args.size(); i++) {
|
||||
if (args[i] == "--help" || args[i] == "-h") {
|
||||
std::cout << HELP << "\n";
|
||||
return 0;
|
||||
} else if (args[i] == "--json" && i + 1 < args.size()) {
|
||||
json_path = args[i + 1];
|
||||
i++;
|
||||
} else if (args[i] == "--stop-on-first-fail") {
|
||||
stop_on_first_fail = true;
|
||||
} else if (args[i] == "--output" && i + 1 < args.size()) {
|
||||
output_path = args[i + 1];
|
||||
i++;
|
||||
} else if (args[i] == "--no-common") {
|
||||
use_common = true;
|
||||
} else if (tmpl_path.empty()) {
|
||||
tmpl_path = args[i];
|
||||
} else {
|
||||
std::cerr << "Unknown argument: " << args[i] << "\n";
|
||||
std::cout << HELP << "\n";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (tmpl_path.empty()) {
|
||||
return main_automated_tests();
|
||||
}
|
||||
|
||||
json input_json;
|
||||
if (!json_path.empty()) {
|
||||
std::ifstream json_file(json_path);
|
||||
if (!json_file) {
|
||||
std::cerr << "Error: Could not open JSON file: " << json_path << "\n";
|
||||
return 1;
|
||||
}
|
||||
std::string content = std::string(
|
||||
std::istreambuf_iterator<char>(json_file),
|
||||
std::istreambuf_iterator<char>());
|
||||
input_json = json::parse(content);
|
||||
} else {
|
||||
input_json = json::parse(DEFAULT_JSON);
|
||||
}
|
||||
|
||||
std::filesystem::path p(tmpl_path);
|
||||
if (std::filesystem::is_directory(p)) {
|
||||
run_multiple(tmpl_path, stop_on_first_fail, input_json, use_common);
|
||||
} else if (std::filesystem::is_regular_file(p)) {
|
||||
std::ifstream infile(tmpl_path);
|
||||
std::string contents = std::string(
|
||||
std::istreambuf_iterator<char>(infile),
|
||||
std::istreambuf_iterator<char>());
|
||||
run_single(contents, input_json, use_common, output_path);
|
||||
} else {
|
||||
std::cerr << "Error: PATH_TO_TEMPLATE is not a valid file or directory: " << tmpl_path << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void run_multiple(std::string dir_path, bool stop_on_first_fail, json input, bool use_common) {
|
||||
std::vector<std::string> failed_tests;
|
||||
|
||||
// list all files in models/templates/ and run each
|
||||
size_t test_count = 0;
|
||||
|
||||
for (const auto & entry : std::filesystem::directory_iterator(dir_path)) {
|
||||
// only process .jinja files
|
||||
if (entry.path().extension() == ".jinja" && entry.is_regular_file()) {
|
||||
test_count++;
|
||||
std::cout << "\n\n=== RUNNING TEMPLATE FILE: " << entry.path().string() << " ===\n";
|
||||
std::ifstream infile(entry.path());
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
|
||||
try {
|
||||
run_single(contents, input, use_common);
|
||||
} catch (const std::exception & e) {
|
||||
std::cout << "Exception: " << e.what() << "\n";
|
||||
std::cout << "=== ERROR WITH TEMPLATE FILE: " << entry.path().string() << " ===\n";
|
||||
failed_tests.push_back(entry.path().string());
|
||||
if (stop_on_first_fail) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n\n=== TEST SUMMARY ===\n";
|
||||
std::cout << "Total tests run: " << test_count << "\n";
|
||||
std::cout << "Total failed tests: " << failed_tests.size() << "\n";
|
||||
for (const auto & test : failed_tests) {
|
||||
std::cout << "FAILED TEST: " << test << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static std::string normalize_newlines(const std::string & s) {
|
||||
#ifdef _WIN32
|
||||
@@ -19,6 +170,105 @@ static std::string normalize_newlines(const std::string & s) {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
static std::string format_using_common(
|
||||
const std::string & template_str,
|
||||
const std::string & bos_token,
|
||||
const std::string & eos_token,
|
||||
std::vector<common_chat_msg> & messages,
|
||||
std::vector<common_chat_tool> tools = {}) {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, template_str, bos_token, eos_token);
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.use_jinja = true;
|
||||
inputs.messages = messages;
|
||||
inputs.tools = tools;
|
||||
inputs.add_generation_prompt = true;
|
||||
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
|
||||
output = normalize_newlines(output);
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
// skip libcommon, use direct jinja engine
|
||||
static jinja::value_string format_using_direct_engine(
|
||||
const std::string & template_str,
|
||||
json & input) {
|
||||
// lexing
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(template_str);
|
||||
|
||||
// compile to AST
|
||||
jinja::program ast = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
// check caps for workarounds
|
||||
jinja::caps_get(ast);
|
||||
|
||||
std::cout << "\n=== RUN ===\n";
|
||||
jinja::context ctx(template_str);
|
||||
|
||||
jinja::global_from_json(ctx, input, true);
|
||||
|
||||
jinja::runtime runtime(ctx);
|
||||
const jinja::value results = runtime.execute(ast);
|
||||
auto parts = runtime.gather_string_parts(results);
|
||||
|
||||
std::cout << "\n=== RESULTS ===\n";
|
||||
for (const auto & part : parts->as_string().parts) {
|
||||
std::cout << (part.is_input ? "DATA" : "TMPL") << ": " << part.val << "\n";
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
|
||||
void run_single(std::string contents, json input, bool use_common, const std::string & output_path) {
|
||||
jinja::enable_debug(true);
|
||||
|
||||
jinja::value_string output_parts;
|
||||
|
||||
if (use_common) {
|
||||
std::string bos_token = "<s>";
|
||||
std::string eos_token = "</s>";
|
||||
if (input.contains("bos_token")) {
|
||||
bos_token = input["bos_token"].get<std::string>();
|
||||
}
|
||||
if (input.contains("eos_token")) {
|
||||
eos_token = input["eos_token"].get<std::string>();
|
||||
}
|
||||
nlohmann::ordered_json msgs_json = input["messages"];
|
||||
nlohmann::ordered_json tools_json = input["tools"];
|
||||
auto messages = common_chat_msgs_parse_oaicompat(msgs_json);
|
||||
auto tools = common_chat_tools_parse_oaicompat(tools_json);
|
||||
auto output = format_using_common(contents, bos_token, eos_token, messages, tools);
|
||||
std::cout << "\n=== OUTPUT ===\n";
|
||||
std::cout << output << "\n";
|
||||
output_parts = jinja::mk_val<jinja::value_string>(output);
|
||||
|
||||
} else {
|
||||
output_parts = format_using_direct_engine(contents, input);
|
||||
std::cout << "\n=== OUTPUT ===\n";
|
||||
std::cout << output_parts->as_string().str() << "\n";
|
||||
}
|
||||
|
||||
if (!output_path.empty()) {
|
||||
std::ofstream outfile(output_path);
|
||||
if (!outfile) {
|
||||
throw std::runtime_error("Could not open output file: " + output_path);
|
||||
}
|
||||
outfile << output_parts->as_string().str();
|
||||
outfile.close();
|
||||
std::cout << "\n=== OUTPUT WRITTEN TO " << output_path << " ===\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
//
|
||||
// Automated tests for chat templates
|
||||
//
|
||||
|
||||
#define U8C(x) (const char*)(u8##x)
|
||||
|
||||
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
|
||||
@@ -28,7 +278,9 @@ static common_chat_msg simple_msg(const std::string & role, const std::string &
|
||||
return msg;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
int main_automated_tests(void) {
|
||||
// jinja::enable_debug(true);
|
||||
|
||||
std::vector<llama_chat_message> conversation {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
@@ -61,8 +313,8 @@ int main(void) {
|
||||
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
|
||||
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||
/* .expected_output_jinja= */ "<s>[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||
/* .bos_token= */ "<s>",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "</s>",
|
||||
},
|
||||
{
|
||||
@@ -177,7 +429,7 @@ int main(void) {
|
||||
/* .name= */ "ChatGLM3",
|
||||
/* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
/* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
|
||||
/* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
|
||||
/* .expected_output_jinja= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
|
||||
},
|
||||
{
|
||||
/* .name= */ "ChatGLM4",
|
||||
@@ -221,7 +473,7 @@ int main(void) {
|
||||
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
|
||||
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
|
||||
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
|
||||
/* .expected_output_jinja= */ "",
|
||||
/* .expected_output_jinja= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
|
||||
/* .bos_token= */ "",
|
||||
/* .eos_token= */ "</s>",
|
||||
},
|
||||
@@ -308,9 +560,9 @@ int main(void) {
|
||||
assert(res > 0);
|
||||
supported_tmpl.resize(res);
|
||||
res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
|
||||
printf("Built-in chat templates:\n");
|
||||
std::cout << "Built-in chat templates:\n";
|
||||
for (auto tmpl : supported_tmpl) {
|
||||
printf(" %s\n", tmpl);
|
||||
std::cout << " " << tmpl << "\n";
|
||||
}
|
||||
|
||||
// test invalid chat template
|
||||
@@ -319,7 +571,7 @@ int main(void) {
|
||||
const auto add_generation_prompt = true;
|
||||
|
||||
for (const auto & test_case : test_cases) {
|
||||
printf("\n\n=== %s ===\n\n", test_case.name.c_str());
|
||||
std::cout << "\n\n=== " << test_case.name << " ===\n\n";
|
||||
formatted_chat.resize(1024);
|
||||
res = llama_chat_apply_template(
|
||||
test_case.template_str.c_str(),
|
||||
@@ -332,10 +584,10 @@ int main(void) {
|
||||
formatted_chat.resize(res);
|
||||
std::string output(formatted_chat.data(), formatted_chat.size());
|
||||
if (output != test_case.expected_output) {
|
||||
printf("Expected:\n%s\n", test_case.expected_output.c_str());
|
||||
printf("-------------------------\n");
|
||||
printf("Actual:\n%s\n", output.c_str());
|
||||
fflush(stdout);
|
||||
std::cout << "Expected:\n" << test_case.expected_output << "\n";
|
||||
std::cout << "-------------------------\n";
|
||||
std::cout << "Actual:\n" << output << "\n";
|
||||
std::cout.flush();
|
||||
assert(output == test_case.expected_output);
|
||||
}
|
||||
}
|
||||
@@ -348,39 +600,41 @@ int main(void) {
|
||||
if (!test_case.supported_with_jinja) {
|
||||
continue;
|
||||
}
|
||||
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
|
||||
std::cout << "\n\n=== " << test_case.name << " (jinja) ===\n\n";
|
||||
try {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.use_jinja = true;
|
||||
inputs.messages = messages;
|
||||
inputs.add_generation_prompt = add_generation_prompt;
|
||||
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
|
||||
output = normalize_newlines(output);
|
||||
auto output = format_using_common(
|
||||
test_case.template_str,
|
||||
test_case.bos_token,
|
||||
test_case.eos_token,
|
||||
messages);
|
||||
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
|
||||
if (output != expected_output) {
|
||||
printf("Expected:\n%s\n", expected_output.c_str());
|
||||
printf("-------------------------\n");
|
||||
printf("Actual:\n%s\n", output.c_str());
|
||||
fflush(stdout);
|
||||
std::cout << "Template:```\n" << test_case.template_str << "\n```";
|
||||
std::cout << "-------------------------\n";
|
||||
std::cout << "Expected:```\n" << expected_output << "\n```";
|
||||
std::cout << "-------------------------\n";
|
||||
std::cout << "Actual:```\n" << output << "\n```";
|
||||
std::cout.flush();
|
||||
assert(output == expected_output);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
printf("ERROR: %s\n", e.what());
|
||||
std::cerr << "ERROR: " << e.what() << "\n";
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: llama_chat_format_single will be deprecated, remove these tests later
|
||||
|
||||
// test llama_chat_format_single for system message
|
||||
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
|
||||
std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n";
|
||||
std::vector<common_chat_msg> chat2;
|
||||
auto sys_msg = simple_msg("system", "You are a helpful assistant");
|
||||
|
||||
auto fmt_sys = [&](std::string tmpl_str) {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
|
||||
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
|
||||
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||
printf("-------------------------\n");
|
||||
std::cout << "fmt_sys(" << tmpl_str << ") : " << output << "\n";
|
||||
std::cout << "-------------------------\n";
|
||||
return output;
|
||||
};
|
||||
assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
|
||||
@@ -397,7 +651,7 @@ int main(void) {
|
||||
|
||||
|
||||
// test llama_chat_format_single for user message
|
||||
printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
|
||||
std::cout << "\n\n=== llama_chat_format_single (user message) ===\n\n";
|
||||
chat2.push_back(simple_msg("system", "You are a helpful assistant"));
|
||||
chat2.push_back(simple_msg("user", "Hello"));
|
||||
chat2.push_back(simple_msg("assistant", "I am assistant"));
|
||||
@@ -406,8 +660,8 @@ int main(void) {
|
||||
auto fmt_single = [&](const std::string & tmpl_str) {
|
||||
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
|
||||
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
|
||||
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
|
||||
printf("-------------------------\n");
|
||||
std::cout << "fmt_single(" << tmpl_str << ") : " << output << "\n";
|
||||
std::cout << "-------------------------\n";
|
||||
return output;
|
||||
};
|
||||
assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
|
||||
@@ -419,7 +673,9 @@ int main(void) {
|
||||
assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
|
||||
assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
|
||||
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
|
||||
assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
|
||||
// assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
|
||||
|
||||
std::cout << "\nOK: All tests passed successfully.\n";
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user