common/grammar : replace problematic backtracking regex [\s\S]* (#18342)

* grammar : add support for std::regex_search() with trigger patterns

* common : update hermes2 pro trigger to search instead of match

* common : use regex_search with anchoring for partial matching

* common : adjust regex partial tests to use new pattern

* grammar : check pattern directly instead of adding a type

* common : adjust existing patterns to match new semantics
This commit is contained in:
Aldehir Rojas
2026-01-03 16:02:43 -06:00
committed by GitHub
parent c69c7ebc90
commit cef1d23c5a
6 changed files with 83 additions and 52 deletions
+4 -4
View File
@@ -2065,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
// Trigger on tool calls that appear in the commentary channel // Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({ data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|channel\\|>(commentary|analysis) to" "<\\|channel\\|>(?:commentary|analysis) to"
}); });
// Trigger tool calls that appear in the role section, either at the // Trigger tool calls that appear in the role section, either at the
@@ -2398,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
data.grammar_triggers.push_back({ data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
// If thinking_forced_open, then we capture the </think> tag in the grammar, // If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + ( std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
"\\s*(" "\\s*("
"(?:<tool_call>" "(?:<tool_call>"
"|<function" "|<function"
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?" "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
")" ")"
")[\\s\\S]*" ")"
), ),
}); });
data.preserved_tokens = { data.preserved_tokens = {
+13 -13
View File
@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
return res; return res;
} }
std::match_results<std::string::const_reverse_iterator> srmatch; std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str(); auto group = srmatch[1].str();
if (group.length() != 0) { if (group.length() != 0) {
auto it = srmatch[1].second.base(); auto it = srmatch[1].second.base();
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
to see if a string ends with a partial regex match, but but it's not in std::regex yet. to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
- /a|b/ -> (a|b).* - /a|b/ -> ^(a|b)
- /a*?/ -> error, could match "" - /a*?/ -> error, could match ""
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
- /.*?ab/ -> ((?:b)?a).* (merge .*) - /.*?ab/ -> ^((?:b)?a) (omit .*)
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/ */
std::string regex_to_reversed_partial_regex(const std::string & pattern) { std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin(); auto it = pattern.begin();
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
} }
} }
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function. // We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts; std::vector<std::string> res_alts;
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
throw std::runtime_error("Unmatched '(' in pattern"); throw std::runtime_error("Unmatched '(' in pattern");
} }
return "(" + res + ")[\\s\\S]*"; return "^(" + res + ")";
} }
+10 -8
View File
@@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
} else { } else {
std::vector<std::string> trigger_patterns; std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens; std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) { for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) { switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{ {
const auto & word = trigger.value; const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word)); trigger_patterns.push_back(regex_escape(word));
break; break;
} }
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{ {
patterns_anywhere.push_back(trigger.value); trigger_patterns.push_back(trigger.value);
break; break;
} }
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{ {
trigger_patterns.push_back(trigger.value); const auto & pattern = trigger.value;
std::string anchored = "^$";
if (!pattern.empty()) {
anchored = (pattern.front() != '^' ? "^" : "")
+ pattern
+ (pattern.back() != '$' ? "$" : "");
}
trigger_patterns.push_back(anchored);
break; break;
} }
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
} }
} }
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char *> trigger_patterns_c; std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size()); trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) { for (const auto & regex : trigger_patterns) {
+40 -13
View File
@@ -369,6 +369,44 @@ static void print_rule(
fprintf(file, "\n"); fprintf(file, "\n");
} }
//
// Regex utilities
//
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
auto find_start_pos = [](const std::smatch & match) {
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
return start;
};
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
// match against the entire input
std::smatch match;
if (std::regex_match(input, match, regex)) {
return find_start_pos(match);
}
}
// search anywhere
std::smatch match;
if (std::regex_search(input, match, regex)) {
return find_start_pos(match);
}
return std::string::npos;
}
// //
// implementation // implementation
// //
@@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece; grammar.trigger_buffer += piece;
std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) { for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { auto start = trigger_pattern.find(grammar.trigger_buffer);
if (start != std::string::npos) {
grammar.awaiting_trigger = false; grammar.awaiting_trigger = false;
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
// replay tokens that overlap with [start, end) // replay tokens that overlap with [start, end)
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
+2
View File
@@ -119,6 +119,8 @@ struct llama_grammar_parser {
struct llama_grammar_trigger_pattern { struct llama_grammar_trigger_pattern {
std::string pattern; std::string pattern;
std::regex regex; std::regex regex;
size_t find(const std::string & input) const;
}; };
struct llama_grammar { struct llama_grammar {
+14 -14
View File
@@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__); printf("[%s]\n", __func__);
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*", "^((?:(?:c)?b)?a)",
regex_to_reversed_partial_regex("abc")); regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>( assert_equals<std::string>(
"(a+)[\\s\\S]*", "^(a+)",
regex_to_reversed_partial_regex("a+")); regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>( assert_equals<std::string>(
"(a*)[\\s\\S]*", "^(a*)",
regex_to_reversed_partial_regex("a*")); regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>( assert_equals<std::string>(
"(a?)[\\s\\S]*", "^(a?)",
regex_to_reversed_partial_regex("a?")); regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>( assert_equals<std::string>(
"([a-z])[\\s\\S]*", "^([a-z])",
regex_to_reversed_partial_regex("[a-z]")); regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*", "^((?:\\w+)?[a-z])",
regex_to_reversed_partial_regex("[a-z]\\w+")); regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:a|b))[\\s\\S]*", "^((?:a|b))",
regex_to_reversed_partial_regex("(?:a|b)")); regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*", "^((?:(?:(?:d)?c)?b)?a)",
regex_to_reversed_partial_regex("abcd")); regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ?? "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b")); regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*", "^((?:(?:b)?a)?.*)",
regex_to_reversed_partial_regex(".*?ab")); regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*", "^((?:(?:b)?.*)?a)",
regex_to_reversed_partial_regex("a.*?b")); regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*", "^((?:(?:d)?(?:(?:c)?b))?a)",
regex_to_reversed_partial_regex("a(bc)d")); regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*", "^((?:(?:(?:c)?b|(?:e)?d))?a)",
regex_to_reversed_partial_regex("a(bc|de)")); regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>( assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*", "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
regex_to_reversed_partial_regex("ab{2,4}c")); regex_to_reversed_partial_regex("ab{2,4}c"));
} }