tests : add unit test coverage for llama_tensor_get_type (#20112)
* Add unit test coverage for llama_tensor_get_type * Fix merge conflicts, add more schemas * clang formatter changes * Trailing whitespace * Update name * Start rebase * Updating files with upstream changes prior to rebase * Changes needed from rebase * Update attn_qkv schema, change throw behaviour * Fix merge conflicts * White space * Update with latest changes to state counters * Revert accidental personal CLAUDE.md changes * Change quotation mark * Reuse metadata.name since we have it * Move test-only stuff out of llama-quant.cpp * Hide the regex functionality back in llama-quant.cpp, use a unique pointer to a new struct 'compiled_tensor_type_patterns' which contains the patterns * cont : inital deslop guidelines * Cleanup based on review comments * Continue cleanup * Small cleanup * Manually set proper ordering of tensors, mostly applies to gemma * Formatting * Update tests/test-quant-type-selection.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Fix merge conflicts --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
+56
-15
@@ -125,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
|
||||
}
|
||||
|
||||
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
|
||||
// Handle array-valued fields (e.g. per-layer head counts in hybrid models)
|
||||
// by reading the first element as a representative value.
|
||||
if (vtype == GGUF_TYPE_ARRAY) {
|
||||
int32_t elem_type;
|
||||
uint64_t count;
|
||||
if (!r.read_val(elem_type)) {
|
||||
return false;
|
||||
}
|
||||
if (!r.read_val(count)) {
|
||||
return false;
|
||||
}
|
||||
if (count == 0) {
|
||||
return false;
|
||||
}
|
||||
// Read first element, skip the rest
|
||||
if (!gguf_read_uint32_val(r, elem_type, out)) {
|
||||
return false;
|
||||
}
|
||||
for (uint64_t i = 1; i < count; i++) {
|
||||
size_t sz = gguf_val_type_size(elem_type);
|
||||
if (sz == 0) {
|
||||
return false;
|
||||
}
|
||||
if (!r.skip(sz)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT8) {
|
||||
uint8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
@@ -487,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
|
||||
static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cache_path) {
|
||||
const std::string & cache_path,
|
||||
bool verbose) {
|
||||
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
|
||||
|
||||
// Progressive download inspired by RangeView.fetchChunk()
|
||||
@@ -496,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const size_t max_chunk = 64 * 1024 * 1024;
|
||||
|
||||
while (chunk_size <= max_chunk) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
}
|
||||
|
||||
char range_buf[64];
|
||||
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
|
||||
@@ -542,7 +574,8 @@ static std::optional<gguf_remote_model> fetch_or_cached(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cdir,
|
||||
const std::string & repo_part) {
|
||||
const std::string & repo_part,
|
||||
bool verbose) {
|
||||
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
|
||||
|
||||
{
|
||||
@@ -550,20 +583,23 @@ static std::optional<gguf_remote_model> fetch_or_cached(
|
||||
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
|
||||
auto result = gguf_parse_meta(cached);
|
||||
if (result.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fs_create_directory_with_parents(cdir);
|
||||
return fetch_and_parse(repo, filename, cache_path);
|
||||
return fetch_and_parse(repo, filename, cache_path, verbose);
|
||||
}
|
||||
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir) {
|
||||
const std::string & cache_dir,
|
||||
bool verbose) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
@@ -573,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return std::nullopt;
|
||||
@@ -588,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
}
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
@@ -597,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return std::nullopt;
|
||||
@@ -620,7 +658,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir) {
|
||||
const std::string & cache_dir,
|
||||
bool verbose) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
@@ -631,7 +670,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return nullptr;
|
||||
@@ -659,8 +698,10 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
if (verbose) {
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
}
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
@@ -668,7 +709,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user