parser: fix structured output bug (#22302)
* fix very stupid structured output bug * Things just cannot be too easy.
This commit is contained in:
committed by
GitHub
parent
361fe72acb
commit
0adede866d
@@ -3,8 +3,12 @@
|
|||||||
Test structured output capability via chat completions endpoint.
|
Test structured output capability via chat completions endpoint.
|
||||||
|
|
||||||
Each test case contains:
|
Each test case contains:
|
||||||
- response_format: OpenAI-compatible response_format specification
|
- response_format: OpenAI-compatible response_format specification.
|
||||||
(json_schema only — llama.cpp does not support json_object)
|
Both "json_schema" and "json_object" are accepted; with
|
||||||
|
"json_object" a schema can be supplied via extra_body.
|
||||||
|
- extra_body (optional): dict of extra top-level request fields merged into
|
||||||
|
the request payload (mirrors the OpenAI SDK's extra_body
|
||||||
|
feature; llama.cpp reads a top-level "json_schema" here).
|
||||||
- messages: initial conversation messages
|
- messages: initial conversation messages
|
||||||
- tools (optional): tool definitions (for mixed tool + structured tests)
|
- tools (optional): tool definitions (for mixed tool + structured tests)
|
||||||
- mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON)
|
- mock_tool_responses (optional): dict mapping tool_name -> callable(arguments) -> str (JSON)
|
||||||
@@ -81,11 +85,14 @@ def print_info(msg):
|
|||||||
_print(f"{DIM}{msg}{RESET}")
|
_print(f"{DIM}{msg}{RESET}")
|
||||||
|
|
||||||
|
|
||||||
def print_schema_note(label, rf):
|
def print_schema_note(label, rf, extra_body=None):
|
||||||
kind = rf.get("type", "?")
|
kind = rf.get("type", "?")
|
||||||
name = ""
|
name = ""
|
||||||
if kind == "json_schema":
|
if kind == "json_schema":
|
||||||
name = rf.get("json_schema", {}).get("name", "")
|
name = rf.get("json_schema", {}).get("name", "")
|
||||||
|
elif kind == "json_object" and extra_body and "json_schema" in extra_body:
|
||||||
|
extra_schema = extra_body["json_schema"] or {}
|
||||||
|
name = extra_schema.get("title") or "extra_body.json_schema"
|
||||||
_print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}"
|
_print(f"{DIM}{MAGENTA} ⟐ response_format [{label}]: {kind}"
|
||||||
f"{(' / ' + name) if name else ''}{RESET}")
|
f"{(' / ' + name) if name else ''}{RESET}")
|
||||||
|
|
||||||
@@ -95,17 +102,20 @@ def print_schema_note(label, rf):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def chat_completion(url, messages, tools=None, response_format=None, stream=False):
|
def chat_completion(url, messages, tools=None, response_format=None, stream=False,
|
||||||
|
extra_body=None):
|
||||||
payload = {
|
payload = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"max_tokens": 4096,
|
"max_tokens": 8192,
|
||||||
}
|
}
|
||||||
if tools:
|
if tools:
|
||||||
payload["tools"] = tools
|
payload["tools"] = tools
|
||||||
payload["tool_choice"] = "auto"
|
payload["tool_choice"] = "auto"
|
||||||
if response_format is not None:
|
if response_format is not None:
|
||||||
payload["response_format"] = response_format
|
payload["response_format"] = response_format
|
||||||
|
if extra_body:
|
||||||
|
payload.update(extra_body)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(url, json=payload, stream=stream)
|
response = requests.post(url, json=payload, stream=stream)
|
||||||
@@ -180,7 +190,7 @@ def chat_completion(url, messages, tools=None, response_format=None, stream=Fals
|
|||||||
|
|
||||||
def run_tool_loop(
|
def run_tool_loop(
|
||||||
url, messages, tools, mock_tool_responses, stream, response_format=None,
|
url, messages, tools, mock_tool_responses, stream, response_format=None,
|
||||||
max_turns=6,
|
extra_body=None, max_turns=6,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Drive the tool-call loop. If response_format is provided it is applied to
|
Drive the tool-call loop. If response_format is provided it is applied to
|
||||||
@@ -191,7 +201,8 @@ def run_tool_loop(
|
|||||||
|
|
||||||
for _ in range(max_turns):
|
for _ in range(max_turns):
|
||||||
result = chat_completion(
|
result = chat_completion(
|
||||||
url, msgs, tools=tools, response_format=response_format, stream=stream
|
url, msgs, tools=tools, response_format=response_format, stream=stream,
|
||||||
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
return all_tool_calls, msgs, None
|
return all_tool_calls, msgs, None
|
||||||
@@ -274,7 +285,8 @@ def run_test(url, test_case, stream):
|
|||||||
print_header(f"{name} [{mode}] ({apply_stage})")
|
print_header(f"{name} [{mode}] ({apply_stage})")
|
||||||
|
|
||||||
response_format = test_case["response_format"]
|
response_format = test_case["response_format"]
|
||||||
print_schema_note(apply_stage, response_format)
|
extra_body = test_case.get("extra_body")
|
||||||
|
print_schema_note(apply_stage, response_format, extra_body)
|
||||||
|
|
||||||
tools = test_case.get("tools")
|
tools = test_case.get("tools")
|
||||||
mocks = test_case.get("mock_tool_responses") or {}
|
mocks = test_case.get("mock_tool_responses") or {}
|
||||||
@@ -290,6 +302,7 @@ def run_test(url, test_case, stream):
|
|||||||
mock_tool_responses=mocks,
|
mock_tool_responses=mocks,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
elif apply_stage == "after_tools":
|
elif apply_stage == "after_tools":
|
||||||
# Phase 1: plain tool loop, no response_format applied yet.
|
# Phase 1: plain tool loop, no response_format applied yet.
|
||||||
@@ -314,7 +327,8 @@ def run_test(url, test_case, stream):
|
|||||||
# model focuses on producing the schema-constrained answer.
|
# model focuses on producing the schema-constrained answer.
|
||||||
_print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}")
|
_print(f"\n{DIM}{MAGENTA} ⟐ follow-up turn with response_format applied{RESET}")
|
||||||
result = chat_completion(
|
result = chat_completion(
|
||||||
url, msgs, tools=None, response_format=response_format, stream=stream
|
url, msgs, tools=None, response_format=response_format, stream=stream,
|
||||||
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
final_content = result["content"] if result else None
|
final_content = result["content"] if result else None
|
||||||
else:
|
else:
|
||||||
@@ -481,6 +495,51 @@ def _validate_sentiment(parsed):
|
|||||||
return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}"
|
return True, f"sentiment={parsed['sentiment']} conf={conf} kws={kws}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Test: json_object + extra_body.json_schema (always) ----
|
||||||
|
#
|
||||||
|
# Exercises the llama.cpp-specific path where the OpenAI SDK would send
|
||||||
|
# response_format={"type": "json_object"} and tunnel the schema through
|
||||||
|
# extra_body.json_schema (which becomes a top-level "json_schema" field on
|
||||||
|
# the request body).
|
||||||
|
|
||||||
|
_PRODUCT_JSON_OBJECT_SCHEMA = {
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"$id": "https://example.com/product.schema.json",
|
||||||
|
"title": "Product",
|
||||||
|
"description": "A product in the catalog",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
PRODUCT_JSON_OBJECT_TEST_CASE = {
|
||||||
|
"name": "json_object response_format with extra_body json_schema",
|
||||||
|
"response_format": {"type": "json_object"},
|
||||||
|
"extra_body": {"json_schema": _PRODUCT_JSON_OBJECT_SCHEMA},
|
||||||
|
"apply_stage": "always",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"Extract structured data from the provided text according to the "
|
||||||
|
"JSON schema. Return only valid JSON matching the schema exactly."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Product: Wireless Headphones, ID: 101, In Stock: Yes",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"validate": lambda parsed, tcs, raw: _validate_product_json_object(parsed),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_product_json_object(parsed):
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
return False, f"expected JSON object, got {type(parsed).__name__}: {parsed!r}"
|
||||||
|
if not parsed:
|
||||||
|
return False, f"expected non-empty object, got {parsed!r}"
|
||||||
|
return True, f"product object with {len(parsed)} field(s): {sorted(parsed.keys())}"
|
||||||
|
|
||||||
|
|
||||||
# ---- Test 3: Nested recipe schema (always) ----
|
# ---- Test 3: Nested recipe schema (always) ----
|
||||||
|
|
||||||
_RECIPE_SCHEMA = {
|
_RECIPE_SCHEMA = {
|
||||||
@@ -915,6 +974,7 @@ def _validate_country_report(parsed, tcs):
|
|||||||
ALL_TEST_CASES = [
|
ALL_TEST_CASES = [
|
||||||
BOOK_TEST_CASE,
|
BOOK_TEST_CASE,
|
||||||
SENTIMENT_TEST_CASE,
|
SENTIMENT_TEST_CASE,
|
||||||
|
PRODUCT_JSON_OBJECT_TEST_CASE,
|
||||||
RECIPE_TEST_CASE,
|
RECIPE_TEST_CASE,
|
||||||
SHOP_COMPARISON_TEST_CASE,
|
SHOP_COMPARISON_TEST_CASE,
|
||||||
COUNTRY_REPORT_TEST_CASE,
|
COUNTRY_REPORT_TEST_CASE,
|
||||||
|
|||||||
@@ -947,7 +947,9 @@ json oaicompat_chat_params_parse(
|
|||||||
json response_format = json_value(body, "response_format", json::object());
|
json response_format = json_value(body, "response_format", json::object());
|
||||||
std::string response_type = json_value(response_format, "type", std::string());
|
std::string response_type = json_value(response_format, "type", std::string());
|
||||||
if (response_type == "json_object") {
|
if (response_type == "json_object") {
|
||||||
json_schema = json_value(response_format, "schema", json::object());
|
if (response_format.contains("schema") || json_schema.empty()) {
|
||||||
|
json_schema = json_value(response_format, "schema", json::object());
|
||||||
|
}
|
||||||
} else if (response_type == "json_schema") {
|
} else if (response_type == "json_schema") {
|
||||||
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
|
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
|
||||||
json_schema = json_value(schema_wrapper, "schema", json::object());
|
json_schema = json_value(schema_wrapper, "schema", json::object());
|
||||||
|
|||||||
Reference in New Issue
Block a user