vocab: fix Gemma4 tokenizer (#21343)
* seems to work

* fix case with new line

Co-authored-by: sayap <sokann@gmail.com>

* gemma 4: fix pre tok regex

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: sayap <sokann@gmail.com>
This commit is contained in:
committed by
GitHub
parent
0c58ba3365
commit
b069b10ab4
@@ -7464,9 +7464,6 @@ class Gemma4Model(Gemma3Model):
|
|||||||
|
|
||||||
assert len(tokens) == vocab.vocab_size
|
assert len(tokens) == vocab.vocab_size
|
||||||
|
|
||||||
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
|
|
||||||
# but I don't have time to dive into them right now;
|
|
||||||
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
|
|
||||||
self.gguf_writer.add_tokenizer_model("gemma4")
|
self.gguf_writer.add_tokenizer_model("gemma4")
|
||||||
self.gguf_writer.add_token_list(tokens)
|
self.gguf_writer.add_token_list(tokens)
|
||||||
self.gguf_writer.add_token_scores(scores)
|
self.gguf_writer.add_token_scores(scores)
|
||||||
|
|||||||
+61
-3
@@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
|
case LLAMA_VOCAB_PRE_TYPE_GEMMA4:
|
||||||
|
// Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the
|
||||||
|
// normalizer, then BPE merges run on the whole text without
|
||||||
|
// word-level pre-splitting. We only need to split on newlines
|
||||||
|
// since BPE merge lookup asserts no newlines in tokens.
|
||||||
|
regex_exprs = {
|
||||||
|
"[^\\n]+|[\\n]+",
|
||||||
|
};
|
||||||
|
byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
// default regex for BPE tokenization pre-processing
|
// default regex for BPE tokenization pre-processing
|
||||||
regex_exprs = {
|
regex_exprs = {
|
||||||
@@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> regex_exprs;
|
std::vector<std::string> regex_exprs;
|
||||||
|
bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8)
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llm_tokenizer_bpe_session {
|
struct llm_tokenizer_bpe_session {
|
||||||
@@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session {
|
|||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_token> & output) {
|
void tokenize(const std::string & text, std::vector<llama_token> & output) {
|
||||||
int final_prev_index = -1;
|
int final_prev_index = -1;
|
||||||
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
|
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode);
|
||||||
|
|
||||||
symbols_final.clear();
|
symbols_final.clear();
|
||||||
|
auto tok_pre = vocab.get_pre_type();
|
||||||
|
|
||||||
for (const auto & word : word_collection) {
|
for (const auto & word : word_collection) {
|
||||||
work_queue = llm_bigram_bpe::queue();
|
work_queue = llm_bigram_bpe::queue();
|
||||||
@@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session {
|
|||||||
if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
|
if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
|
||||||
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
||||||
offset = word.size();
|
offset = word.size();
|
||||||
|
} else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) {
|
||||||
|
// fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343
|
||||||
|
auto tok = vocab.text_to_token(word);
|
||||||
|
if (tok != LLAMA_TOKEN_NULL) {
|
||||||
|
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
||||||
|
offset = word.size();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
while (offset < word.size()) {
|
while (offset < word.size()) {
|
||||||
@@ -1864,7 +1883,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||||||
special_pad_id = 3; // <|plamo:pad|>
|
special_pad_id = 3; // <|plamo:pad|>
|
||||||
special_mask_id = LLAMA_TOKEN_NULL;
|
special_mask_id = LLAMA_TOKEN_NULL;
|
||||||
} else if (tokenizer_model == "gemma4") {
|
} else if (tokenizer_model == "gemma4") {
|
||||||
type = LLAMA_VOCAB_TYPE_SPM;
|
type = LLAMA_VOCAB_TYPE_BPE;
|
||||||
|
|
||||||
|
// read bpe merges and populate bpe ranks
|
||||||
|
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
|
||||||
|
if (merges_keyidx == -1) {
|
||||||
|
throw std::runtime_error("cannot find tokenizer merges in model file\n");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
|
||||||
|
for (int i = 0; i < n_merges; i++) {
|
||||||
|
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
|
||||||
|
|
||||||
|
std::string first;
|
||||||
|
std::string second;
|
||||||
|
|
||||||
|
const size_t pos = word.find(' ', 1);
|
||||||
|
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
first = word.substr(0, pos);
|
||||||
|
second = word.substr(pos + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
bpe_ranks.emplace(std::make_pair(first, second), i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// default special tokens (to be read from GGUF)
|
// default special tokens (to be read from GGUF)
|
||||||
special_bos_id = LLAMA_TOKEN_NULL;
|
special_bos_id = LLAMA_TOKEN_NULL;
|
||||||
@@ -1874,7 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||||||
special_pad_id = LLAMA_TOKEN_NULL;
|
special_pad_id = LLAMA_TOKEN_NULL;
|
||||||
special_mask_id = LLAMA_TOKEN_NULL;
|
special_mask_id = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
tokenizer_pre = "gemma4";
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||||
}
|
}
|
||||||
@@ -1882,6 +1925,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||||||
// for now, only BPE models have pre-tokenizers
|
// for now, only BPE models have pre-tokenizers
|
||||||
if (type == LLAMA_VOCAB_TYPE_BPE) {
|
if (type == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
add_space_prefix = false;
|
add_space_prefix = false;
|
||||||
|
escape_whitespaces = false;
|
||||||
clean_spaces = true;
|
clean_spaces = true;
|
||||||
if (tokenizer_pre.empty()) {
|
if (tokenizer_pre.empty()) {
|
||||||
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
@@ -1948,6 +1992,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "jais-2") {
|
tokenizer_pre == "jais-2") {
|
||||||
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
|
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
|
||||||
|
} else if (
|
||||||
|
tokenizer_pre == "gemma4") {
|
||||||
|
pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
|
||||||
|
escape_whitespaces = true;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "jina-v1-en" ||
|
tokenizer_pre == "jina-v1-en" ||
|
||||||
tokenizer_pre == "jina-v2-code" ||
|
tokenizer_pre == "jina-v2-code" ||
|
||||||
@@ -3045,6 +3093,10 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
|
|||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
|
|
||||||
|
if (escape_whitespaces) {
|
||||||
|
llama_escape_whitespace(text);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
|
||||||
#endif
|
#endif
|
||||||
@@ -3224,6 +3276,12 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
|
|||||||
return _try_copy(token_text.data(), token_text.size());
|
return _try_copy(token_text.data(), token_text.size());
|
||||||
}
|
}
|
||||||
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
||||||
|
if (escape_whitespaces) {
|
||||||
|
// SPM-style BPE: tokens contain ▁ for spaces
|
||||||
|
std::string result = token_text;
|
||||||
|
llama_unescape_whitespace(result);
|
||||||
|
return _try_copy(result.data(), result.size());
|
||||||
|
}
|
||||||
std::string result = llama_decode_text(token_text);
|
std::string result = llama_decode_text(token_text);
|
||||||
return _try_copy(result.data(), result.size());
|
return _try_copy(result.data(), result.size());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
|
|||||||
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
|
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
|
||||||
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
|
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
|
||||||
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
|
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LLM_KV;
|
struct LLM_KV;
|
||||||
|
|||||||
+6
-2
@@ -912,7 +912,7 @@ bool unicode_cpt_is_han(uint32_t cpt) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) {
|
||||||
// unicode categories
|
// unicode categories
|
||||||
static const std::map<std::string, int> k_ucat_enum = {
|
static const std::map<std::string, int> k_ucat_enum = {
|
||||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||||
@@ -1099,5 +1099,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||||||
start += offset;
|
start += offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
return unicode_byte_encoding_process(bpe_words);
|
if (byte_encode) {
|
||||||
|
return unicode_byte_encoding_process(bpe_words);
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpe_words;
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt);
|
|||||||
|
|
||||||
bool unicode_cpt_is_han(uint32_t cpt);
|
bool unicode_cpt_is_han(uint32_t cpt);
|
||||||
|
|
||||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true);
|
||||||
|
|||||||
Reference in New Issue
Block a user