common/chat, server: refactor, move all conversion functions to common, add tests (#20690)
* Refactor conversion functions
This commit is contained in:
committed by
GitHub
parent
ca7f7b7b94
commit
134d6e54d4
@@ -155,6 +155,8 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
llama_build_and_test(test-grammar-integration.cpp)
|
||||
llama_build_and_test(test-llama-grammar.cpp)
|
||||
llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
target_include_directories(test-chat PRIVATE ${PROJECT_SOURCE_DIR}/tools/server)
|
||||
target_link_libraries(test-chat PRIVATE server-context)
|
||||
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
|
||||
+128
-2
@@ -7,6 +7,7 @@
|
||||
//
|
||||
#include "../src/llama-grammar.h"
|
||||
#include "../src/unicode.h"
|
||||
#include "../tools/server/server-chat.h"
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
@@ -1514,6 +1515,117 @@ static void test_tools_oaicompat_json_conversion() {
|
||||
common_chat_tools_to_json_oaicompat({ special_function_tool }).dump(2));
|
||||
}
|
||||
|
||||
static void test_convert_responses_to_chatcmpl() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Test basic conversion with input messages (user/assistant alternating)
|
||||
{
|
||||
json input = json::parse(R"({
|
||||
"input": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "hi wassup"
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": "Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?"
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
],
|
||||
"model": "gpt-5-mini",
|
||||
"stream": false,
|
||||
"text": {},
|
||||
"reasoning": {
|
||||
"effort": "medium"
|
||||
}
|
||||
})");
|
||||
|
||||
json result = server_chat_convert_responses_to_chatcmpl(input);
|
||||
|
||||
// Verify messages were converted correctly
|
||||
assert_equals(true, result.contains("messages"));
|
||||
assert_equals(true, result.at("messages").is_array());
|
||||
assert_equals((size_t)3, result.at("messages").size());
|
||||
|
||||
// Check first message (user)
|
||||
const auto & msg0 = result.at("messages")[0];
|
||||
assert_equals(std::string("user"), msg0.at("role").get<std::string>());
|
||||
assert_equals(true, msg0.at("content").is_array());
|
||||
assert_equals(std::string("text"), msg0.at("content")[0].at("type").get<std::string>());
|
||||
assert_equals(std::string("hi wassup"), msg0.at("content")[0].at("text").get<std::string>());
|
||||
|
||||
// Check second message (assistant)
|
||||
const auto & msg1 = result.at("messages")[1];
|
||||
assert_equals(std::string("assistant"), msg1.at("role").get<std::string>());
|
||||
assert_equals(true, msg1.at("content").is_array());
|
||||
assert_equals(std::string("text"), msg1.at("content")[0].at("type").get<std::string>());
|
||||
assert_equals(std::string("Hey! 👋 Not much, just here ready to chat. What's up with you? Anything I can help you with today?"), msg1.at("content")[0].at("text").get<std::string>());
|
||||
|
||||
// Check third message (user)
|
||||
const auto & msg2 = result.at("messages")[2];
|
||||
assert_equals(std::string("user"), msg2.at("role").get<std::string>());
|
||||
assert_equals(true, msg2.at("content").is_array());
|
||||
assert_equals(std::string("text"), msg2.at("content")[0].at("type").get<std::string>());
|
||||
assert_equals(std::string("hi"), msg2.at("content")[0].at("text").get<std::string>());
|
||||
|
||||
// Verify other fields preserved
|
||||
assert_equals(std::string("gpt-5-mini"), result.at("model").get<std::string>());
|
||||
assert_equals(false, result.at("stream").get<bool>());
|
||||
}
|
||||
|
||||
// Test string input
|
||||
{
|
||||
json input = json::parse(R"({
|
||||
"input": "Hello, world!",
|
||||
"model": "test-model"
|
||||
})");
|
||||
|
||||
json result = server_chat_convert_responses_to_chatcmpl(input);
|
||||
|
||||
assert_equals((size_t)1, result.at("messages").size());
|
||||
const auto & msg = result.at("messages")[0];
|
||||
assert_equals(std::string("user"), msg.at("role").get<std::string>());
|
||||
assert_equals(std::string("Hello, world!"), msg.at("content").get<std::string>());
|
||||
}
|
||||
|
||||
// Test with instructions (system message)
|
||||
{
|
||||
json input = json::parse(R"({
|
||||
"input": "Hello",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
"model": "test-model"
|
||||
})");
|
||||
|
||||
json result = server_chat_convert_responses_to_chatcmpl(input);
|
||||
|
||||
assert_equals((size_t)2, result.at("messages").size());
|
||||
const auto & sys_msg = result.at("messages")[0];
|
||||
assert_equals(std::string("system"), sys_msg.at("role").get<std::string>());
|
||||
assert_equals(std::string("You are a helpful assistant."), sys_msg.at("content").get<std::string>());
|
||||
}
|
||||
|
||||
// Test with max_output_tokens conversion
|
||||
{
|
||||
json input = json::parse(R"({
|
||||
"input": "Hello",
|
||||
"model": "test-model",
|
||||
"max_output_tokens": 100
|
||||
})");
|
||||
|
||||
json result = server_chat_convert_responses_to_chatcmpl(input);
|
||||
|
||||
assert_equals(true, result.contains("max_tokens"));
|
||||
assert_equals(false, result.contains("max_output_tokens"));
|
||||
assert_equals(100, result.at("max_tokens").get<int>());
|
||||
}
|
||||
}
|
||||
|
||||
static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
@@ -4291,7 +4403,7 @@ int main(int argc, char ** argv) {
|
||||
bool detailed_debug = false;
|
||||
bool only_run_filtered = false;
|
||||
|
||||
// Check for --template flag
|
||||
// Check for --template and --detailed flags
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
if (arg == "--template" && i + 1 < argc) {
|
||||
@@ -4316,7 +4428,20 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
if (argc > 1) {
|
||||
// Check if any argument is a .jinja file (for template format detection mode)
|
||||
bool has_jinja_files = false;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
if (arg == "--detailed") {
|
||||
continue;
|
||||
}
|
||||
if (arg.size() >= 6 && arg.rfind(".jinja") == arg.size() - 6) {
|
||||
has_jinja_files = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_jinja_files) {
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
@@ -4349,6 +4474,7 @@ int main(int argc, char ** argv) {
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
test_reka_edge_common_path();
|
||||
test_template_output_peg_parsers(detailed_debug);
|
||||
|
||||
Reference in New Issue
Block a user