diff --git a/common/debug.cpp b/common/debug.cpp index 0df409a79..102c6924d 100644 --- a/common/debug.cpp +++ b/common/debug.cpp @@ -1,9 +1,38 @@ #include "debug.h" +#include "common.h" #include "log.h" #include +#include #include +#include + +struct common_debug_cb_user_data::impl { + std::vector data; + std::vector tensor_filters; + bool abort_on_nan{false}; +}; + +common_debug_cb_user_data::common_debug_cb_user_data() : pimpl(std::make_unique()) {} +common_debug_cb_user_data::~common_debug_cb_user_data() = default; + +common_debug_cb_user_data::common_debug_cb_user_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan) + : pimpl(std::make_unique()) +{ + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + pimpl->tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + pimpl->abort_on_nan = abort_on_nan; + + params.cb_eval = common_debug_cb_eval; + params.cb_eval_user_data = this; +} static std::string common_ggml_ne_string(const ggml_tensor * t) { std::string str; @@ -47,8 +76,7 @@ static float common_ggml_get_float_value(const uint8_t * data, #define INDENT " " -template -void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { +static void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool abort_on_nan) { GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { @@ -94,7 +122,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n LOG(INDENT "sum = %f\n", sum); } - if constexpr (abort) { + if (abort_on_nan) { if (std::isnan(sum)) { LOG("encountered NaN - aborting\n"); exit(0); @@ -112,8 +140,9 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n * @param user_data user data to pass at each call back * @return true to receive data or continue the graph, false otherwise */ -template bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { - auto * cb_data = (base_callback_data *) user_data; +bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (common_debug_cb_user_data *) user_data; + auto * pimpl = cb_data->pimpl.get(); const struct ggml_tensor * src0 = t->src[0]; const struct ggml_tensor * src1 = t->src[1]; @@ -122,10 +151,10 @@ template bool common_debug_cb_eval(struct ggml_tensor * t, b return true; // Always retrieve data } - bool matches_filter = cb_data->tensor_filters.empty(); + bool matches_filter = pimpl->tensor_filters.empty(); if (!matches_filter) { - for (const auto & filter : cb_data->tensor_filters) { + for (const auto & filter : pimpl->tensor_filters) { if (std::regex_search(t->name, filter)) { matches_filter = true; break; @@ -148,20 +177,14 @@ template bool common_debug_cb_eval(struct ggml_tensor * t, b if (!is_host) { auto n_bytes = ggml_nbytes(t); - cb_data->data.resize(n_bytes); - ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + pimpl->data.resize(n_bytes); + ggml_backend_tensor_get(t, pimpl->data.data(), 0, n_bytes); } if (!ggml_is_quantized(t->type) && matches_filter) { - uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - common_debug_print_tensor(data, t->type, t->ne, t->nb, 3); + uint8_t * data = is_host ? (uint8_t *) t->data : pimpl->data.data(); + common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, pimpl->abort_on_nan); } return true; } - -// Explicit template instantiations -template bool common_debug_cb_eval(ggml_tensor *, bool, void *); -template bool common_debug_cb_eval(ggml_tensor *, bool, void *); -template void common_debug_print_tensor(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t); -template void common_debug_print_tensor(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t); diff --git a/common/debug.h b/common/debug.h index e563b40d6..8b8f8c7aa 100644 --- a/common/debug.h +++ b/common/debug.h @@ -1,43 +1,31 @@ #pragma once -#include "common.h" + +#include #include #include -#include // common debug functions and structs -// Print a tensor's detailed data -// data - the tensor's data in byte format -// type - the tensor's quantization type -// ne - the tensor dimensions array -// nb - the tensor strides array -// n - the number of rows/columns to fully print -template void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n); +struct common_params; // Intended to use as callback for ggml_backend_sched_eval_callback // prints tensors that are processed in the computation graph -// by default prints all tensors, but can be configured by creating a `base_callback_data` instance with -// non-empty filter_patterns. See examples/debug.ccp for possible usage patterns -// The template parameter determines whether an error should be thrown whenever a NaN is encountered +// by default prints all tensors, but can be configured by creating a `common_debug_cb_user_data` instance with +// non-empty filter_patterns. See examples/debug.cpp for possible usage patterns +// `common_debug_cb_user_data` contains `abort_on_nan` flag that determines whether an error should be thrown whenever a NaN is encountered // in a tensor (useful for stopping debug sessions on first erroneous tensor) // The callback data will be passed as the third parameter (user_data) -template bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); -struct base_callback_data { - std::vector data; - std::vector tensor_filters; +bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); - base_callback_data() = default; +struct common_debug_cb_user_data { + struct impl; + std::unique_ptr pimpl; - base_callback_data(common_params & params, const std::vector & filter_patterns) { - for (const auto & pattern : filter_patterns) { - try { - std::string anchored_pattern = "^" + pattern; - tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); - } catch (const std::regex_error & e) { - throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); - } - } - params.cb_eval = common_debug_cb_eval; - params.cb_eval_user_data = this; - } + common_debug_cb_user_data(); + ~common_debug_cb_user_data(); + + common_debug_cb_user_data(const common_debug_cb_user_data &) = delete; + common_debug_cb_user_data & operator=(const common_debug_cb_user_data &) = delete; + + common_debug_cb_user_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan = false); }; diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp index 7ba63b4ff..761e7a2db 100644 --- a/examples/debug/debug.cpp +++ b/examples/debug/debug.cpp @@ -202,10 +202,14 @@ static bool run(llama_context * ctx, const common_params & params) { print_tokenized_prompt(ctx, tokens, params.prompt); if (params.save_logits) { - output_data output {ctx, model, params}; - std::filesystem::path model_path{params.model.path}; - std::string model_name{model_path.stem().string()}; - save_output_data(output, model_name, params.logits_output_dir); + try { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } catch (const std::exception & e) { + LOG_ERR("%s : error saving logits: %s\n", __func__, e.what()); + } } return true; @@ -223,7 +227,7 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - std::optional cb_data; + std::optional cb_data; if (!params.save_logits) { cb_data.emplace(params, params.tensor_filter); } diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 883246845..4ce8d600b 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -3,7 +3,6 @@ #include "debug.h" #include "log.h" #include "llama.h" -#include "llama-cpp.h" #include #include @@ -38,7 +37,7 @@ static bool run(llama_context * ctx, const common_params & params) { int main(int argc, char ** argv) { std::setlocale(LC_NUMERIC, "C"); - base_callback_data cb_data; + common_debug_cb_user_data cb_data; common_params params; @@ -53,7 +52,7 @@ int main(int argc, char ** argv) { // pass the callback to the backend scheduler // it will be executed for each node during the graph computation - params.cb_eval = common_debug_cb_eval; + params.cb_eval = common_debug_cb_eval; params.cb_eval_user_data = &cb_data; params.warmup = false; diff --git a/tools/mtmd/debug/mtmd-debug.cpp b/tools/mtmd/debug/mtmd-debug.cpp index 6e32b283a..1e41ef793 100644 --- a/tools/mtmd/debug/mtmd-debug.cpp +++ b/tools/mtmd/debug/mtmd-debug.cpp @@ -72,7 +72,7 @@ int main(int argc, char ** argv) { mtmd::context_ptr ctx_mtmd; common_init_result_ptr llama_init; - base_callback_data cb_data; + common_debug_cb_user_data cb_data; llama_init = common_init_from_params(params); { @@ -89,7 +89,7 @@ int main(int argc, char ** argv) { { // always enable debug callback mparams.cb_eval_user_data = &cb_data; - mparams.cb_eval = common_debug_cb_eval; + mparams.cb_eval = common_debug_cb_eval; } ctx_mtmd.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_mtmd.get()) { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index dd72dfb17..be958bd17 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -90,7 +90,7 @@ struct mtmd_cli_context { int n_threads = 1; llama_pos n_past = 0; - base_callback_data cb_data; + common_debug_cb_user_data cb_data; mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) { model = llama_init->model(); @@ -145,7 +145,7 @@ struct mtmd_cli_context { mparams.image_max_tokens = params.image_max_tokens; if (std::getenv("MTMD_DEBUG_GRAPH") != nullptr) { mparams.cb_eval_user_data = &cb_data; - mparams.cb_eval = common_debug_cb_eval; + mparams.cb_eval = common_debug_cb_eval; } ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) {