tests : add unit test coverage for llama_tensor_get_type (#20112)

* Add unit test coverage for llama_tensor_get_type

* Fix merge conflicts, add more schemas

* clang formatter changes

* Trailing whitespace

* Update name

* Start rebase

* Updating files with upstream changes prior to rebase

* Changes needed from rebase

* Update attn_qkv schema, change throw behaviour

* Fix merge conflicts

* White space

* Update with latest changes to state counters

* Revert accidental personal CLAUDE.md changes

* Change quotation mark

* Reuse metadata.name since we have it

* Move test-only stuff out of llama-quant.cpp

* Hide the regex functionality back in llama-quant.cpp, use a unique pointer to a new struct 'compiled_tensor_type_patterns' which contains the patterns

* cont : inital deslop guidelines

* Cleanup based on review comments

* Continue cleanup

* Small cleanup

* Manually set proper ordering of tensors, mostly applies to gemma

* Formatting

* Update tests/test-quant-type-selection.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix merge conflicts

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Bartowski
2026-04-02 16:53:58 -04:00
committed by GitHub
parent a1cfb64530
commit 7992aa7c8e
20 changed files with 35301 additions and 51 deletions
+56 -15
View File
@@ -125,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
}
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
// Handle array-valued fields (e.g. per-layer head counts in hybrid models)
// by reading the first element as a representative value.
if (vtype == GGUF_TYPE_ARRAY) {
int32_t elem_type;
uint64_t count;
if (!r.read_val(elem_type)) {
return false;
}
if (!r.read_val(count)) {
return false;
}
if (count == 0) {
return false;
}
// Read first element, skip the rest
if (!gguf_read_uint32_val(r, elem_type, out)) {
return false;
}
for (uint64_t i = 1; i < count; i++) {
size_t sz = gguf_val_type_size(elem_type);
if (sz == 0) {
return false;
}
if (!r.skip(sz)) {
return false;
}
}
return true;
}
if (vtype == GGUF_TYPE_UINT8) {
uint8_t v;
if (!r.read_val(v)) {
@@ -487,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
static std::optional<gguf_remote_model> fetch_and_parse(
const std::string & repo,
const std::string & filename,
const std::string & cache_path) {
const std::string & cache_path,
bool verbose) {
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
// Progressive download inspired by RangeView.fetchChunk()
@@ -496,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
const size_t max_chunk = 64 * 1024 * 1024;
while (chunk_size <= max_chunk) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
}
char range_buf[64];
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
@@ -542,7 +574,8 @@ static std::optional<gguf_remote_model> fetch_or_cached(
const std::string & repo,
const std::string & filename,
const std::string & cdir,
const std::string & repo_part) {
const std::string & repo_part,
bool verbose) {
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
{
@@ -550,20 +583,23 @@ static std::optional<gguf_remote_model> fetch_or_cached(
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
auto result = gguf_parse_meta(cached);
if (result.has_value()) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
}
return result;
}
}
}
fs_create_directory_with_parents(cdir);
return fetch_and_parse(repo, filename, cache_path);
return fetch_and_parse(repo, filename, cache_path, verbose);
}
std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
@@ -573,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return std::nullopt;
@@ -588,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
@@ -597,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return std::nullopt;
@@ -620,7 +658,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
gguf_context_ptr gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
@@ -631,7 +670,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
return nullptr;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return nullptr;
@@ -659,8 +698,10 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
return nullptr;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
@@ -668,7 +709,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return nullptr;