common/grammar : replace problematic backtracking regex [\s\S]* (#18342)
* grammar : add support for std::regex_search() with trigger patterns * common : update hermes2 pro trigger to search instead of match * common : use regex_search with anchoring for partial matching * common : adjust regex partial tests to use new pattern * grammar : check pattern directly instead of adding a type * common : adjust existing patterns to match new semantics
This commit is contained in:
+4
-4
@@ -2065,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||||||
// Trigger on tool calls that appear in the commentary channel
|
// Trigger on tool calls that appear in the commentary channel
|
||||||
data.grammar_triggers.push_back({
|
data.grammar_triggers.push_back({
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||||
"<\\|channel\\|>(commentary|analysis) to"
|
"<\\|channel\\|>(?:commentary|analysis) to"
|
||||||
});
|
});
|
||||||
|
|
||||||
// Trigger tool calls that appear in the role section, either at the
|
// Trigger tool calls that appear in the role section, either at the
|
||||||
@@ -2398,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|||||||
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
|
(inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
|
||||||
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
||||||
data.grammar_triggers.push_back({
|
data.grammar_triggers.push_back({
|
||||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||||
// If thinking_forced_open, then we capture the </think> tag in the grammar,
|
// If thinking_forced_open, then we capture the </think> tag in the grammar,
|
||||||
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
|
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
|
||||||
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
|
std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
|
||||||
"\\s*("
|
"\\s*("
|
||||||
"(?:<tool_call>"
|
"(?:<tool_call>"
|
||||||
"|<function"
|
"|<function"
|
||||||
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
|
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
|
||||||
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
|
"\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
|
||||||
")"
|
")"
|
||||||
")[\\s\\S]*"
|
")"
|
||||||
),
|
),
|
||||||
});
|
});
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
|||||||
+13
-13
@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||||
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
|
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
|
||||||
auto group = srmatch[1].str();
|
auto group = srmatch[1].str();
|
||||||
if (group.length() != 0) {
|
if (group.length() != 0) {
|
||||||
auto it = srmatch[1].second.base();
|
auto it = srmatch[1].second.base();
|
||||||
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
|
|||||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||||
|
|
||||||
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
|
||||||
- /a|b/ -> (a|b).*
|
- /a|b/ -> ^(a|b)
|
||||||
- /a*?/ -> error, could match ""
|
- /a*?/ -> error, could match ""
|
||||||
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
|
||||||
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
- /.*?ab/ -> ^((?:b)?a) (omit .*)
|
||||||
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
|
||||||
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
|
||||||
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
|
||||||
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
|
||||||
|
|
||||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
|
||||||
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
|
||||||
*/
|
*/
|
||||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||||
auto it = pattern.begin();
|
auto it = pattern.begin();
|
||||||
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
|
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
|
||||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||||
std::vector<std::string> res_alts;
|
std::vector<std::string> res_alts;
|
||||||
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
|||||||
throw std::runtime_error("Unmatched '(' in pattern");
|
throw std::runtime_error("Unmatched '(' in pattern");
|
||||||
}
|
}
|
||||||
|
|
||||||
return "(" + res + ")[\\s\\S]*";
|
return "^(" + res + ")";
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-8
@@ -179,24 +179,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
} else {
|
} else {
|
||||||
std::vector<std::string> trigger_patterns;
|
std::vector<std::string> trigger_patterns;
|
||||||
std::vector<std::string> patterns_anywhere;
|
|
||||||
std::vector<llama_token> trigger_tokens;
|
std::vector<llama_token> trigger_tokens;
|
||||||
for (const auto & trigger : params.grammar_triggers) {
|
for (const auto & trigger : params.grammar_triggers) {
|
||||||
switch (trigger.type) {
|
switch (trigger.type) {
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||||
{
|
{
|
||||||
const auto & word = trigger.value;
|
const auto & word = trigger.value;
|
||||||
patterns_anywhere.push_back(regex_escape(word));
|
trigger_patterns.push_back(regex_escape(word));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||||
{
|
{
|
||||||
patterns_anywhere.push_back(trigger.value);
|
trigger_patterns.push_back(trigger.value);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||||
{
|
{
|
||||||
trigger_patterns.push_back(trigger.value);
|
const auto & pattern = trigger.value;
|
||||||
|
std::string anchored = "^$";
|
||||||
|
if (!pattern.empty()) {
|
||||||
|
anchored = (pattern.front() != '^' ? "^" : "")
|
||||||
|
+ pattern
|
||||||
|
+ (pattern.back() != '$' ? "$" : "");
|
||||||
|
}
|
||||||
|
trigger_patterns.push_back(anchored);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||||
@@ -210,10 +216,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!patterns_anywhere.empty()) {
|
|
||||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const char *> trigger_patterns_c;
|
std::vector<const char *> trigger_patterns_c;
|
||||||
trigger_patterns_c.reserve(trigger_patterns.size());
|
trigger_patterns_c.reserve(trigger_patterns.size());
|
||||||
for (const auto & regex : trigger_patterns) {
|
for (const auto & regex : trigger_patterns) {
|
||||||
|
|||||||
+40
-13
@@ -369,6 +369,44 @@ static void print_rule(
|
|||||||
fprintf(file, "\n");
|
fprintf(file, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Regex utilities
|
||||||
|
//
|
||||||
|
|
||||||
|
size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
|
||||||
|
auto find_start_pos = [](const std::smatch & match) {
|
||||||
|
// get from the first matched capturing group to the end of the string
|
||||||
|
size_t start = std::string::npos;
|
||||||
|
for (auto i = 1u; i < match.size(); i++) {
|
||||||
|
if (match.length(i) > 0) {
|
||||||
|
start = match.position(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (start == std::string::npos) {
|
||||||
|
start = match.position(0);
|
||||||
|
}
|
||||||
|
return start;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
|
||||||
|
// match against the entire input
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_match(input, match, regex)) {
|
||||||
|
return find_start_pos(match);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// search anywhere
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(input, match, regex)) {
|
||||||
|
return find_start_pos(match);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// implementation
|
// implementation
|
||||||
//
|
//
|
||||||
@@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||||
grammar.trigger_buffer += piece;
|
grammar.trigger_buffer += piece;
|
||||||
|
|
||||||
std::smatch match;
|
|
||||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
auto start = trigger_pattern.find(grammar.trigger_buffer);
|
||||||
|
if (start != std::string::npos) {
|
||||||
grammar.awaiting_trigger = false;
|
grammar.awaiting_trigger = false;
|
||||||
// get from the first matched capturing group to the end of the string
|
|
||||||
size_t start = std::string::npos;
|
|
||||||
for (auto i = 1u; i < match.size(); i++) {
|
|
||||||
if (match.length(i) > 0) {
|
|
||||||
start = match.position(i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (start == std::string::npos) {
|
|
||||||
start = match.position(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// replay tokens that overlap with [start, end)
|
// replay tokens that overlap with [start, end)
|
||||||
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
|
||||||
|
|||||||
@@ -119,6 +119,8 @@ struct llama_grammar_parser {
|
|||||||
struct llama_grammar_trigger_pattern {
|
struct llama_grammar_trigger_pattern {
|
||||||
std::string pattern;
|
std::string pattern;
|
||||||
std::regex regex;
|
std::regex regex;
|
||||||
|
|
||||||
|
size_t find(const std::string & input) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
|
|||||||
@@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() {
|
|||||||
printf("[%s]\n", __func__);
|
printf("[%s]\n", __func__);
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:c)?b)?a)[\\s\\S]*",
|
"^((?:(?:c)?b)?a)",
|
||||||
regex_to_reversed_partial_regex("abc"));
|
regex_to_reversed_partial_regex("abc"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"(a+)[\\s\\S]*",
|
"^(a+)",
|
||||||
regex_to_reversed_partial_regex("a+"));
|
regex_to_reversed_partial_regex("a+"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"(a*)[\\s\\S]*",
|
"^(a*)",
|
||||||
regex_to_reversed_partial_regex("a*"));
|
regex_to_reversed_partial_regex("a*"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"(a?)[\\s\\S]*",
|
"^(a?)",
|
||||||
regex_to_reversed_partial_regex("a?"));
|
regex_to_reversed_partial_regex("a?"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"([a-z])[\\s\\S]*",
|
"^([a-z])",
|
||||||
regex_to_reversed_partial_regex("[a-z]"));
|
regex_to_reversed_partial_regex("[a-z]"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:\\w+)?[a-z])[\\s\\S]*",
|
"^((?:\\w+)?[a-z])",
|
||||||
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
||||||
|
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:a|b))[\\s\\S]*",
|
"^((?:a|b))",
|
||||||
regex_to_reversed_partial_regex("(?:a|b)"));
|
regex_to_reversed_partial_regex("(?:a|b)"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
|
"^((?:(?:(?:d)?c)?b)?a)",
|
||||||
regex_to_reversed_partial_regex("abcd"));
|
regex_to_reversed_partial_regex("abcd"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
|
"^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
|
||||||
regex_to_reversed_partial_regex("a*b"));
|
regex_to_reversed_partial_regex("a*b"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:b)?a)?.*)[\\s\\S]*",
|
"^((?:(?:b)?a)?.*)",
|
||||||
regex_to_reversed_partial_regex(".*?ab"));
|
regex_to_reversed_partial_regex(".*?ab"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:b)?.*)?a)[\\s\\S]*",
|
"^((?:(?:b)?.*)?a)",
|
||||||
regex_to_reversed_partial_regex("a.*?b"));
|
regex_to_reversed_partial_regex("a.*?b"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
|
"^((?:(?:d)?(?:(?:c)?b))?a)",
|
||||||
regex_to_reversed_partial_regex("a(bc)d"));
|
regex_to_reversed_partial_regex("a(bc)d"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
|
"^((?:(?:(?:c)?b|(?:e)?d))?a)",
|
||||||
regex_to_reversed_partial_regex("a(bc|de)"));
|
regex_to_reversed_partial_regex("a(bc|de)"));
|
||||||
assert_equals<std::string>(
|
assert_equals<std::string>(
|
||||||
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
|
"^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
|
||||||
regex_to_reversed_partial_regex("ab{2,4}c"));
|
regex_to_reversed_partial_regex("ab{2,4}c"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user