llama.android : Rewrite Android binding (w/o cpu_features dep) (#17413)
* UI: implement basic UI components
* util: implement performance monitor; wrap it with a viewmodel
* util: implement user preferences utility
* UI: implement core flow's screens
* UI: add a new MainActivity; update manifest
* [WIP] DI: implement simple local vm factory provider
* UI: disable triggering drawer via gesture; enable alert dialog on back navigation inside conversation and benchmark
* UI: allow drawer's gesture control only on Home and Settings screens; enable alert dialog on back navigation inside conversation and benchmark
* UI: split a nested parent settings screen into separate child settings screens
* UI: polish system prompt setup UI
* Deps: bump Kotlin plugin; introduce KSP; apply in :app subproject
* DB: setup Room database
* data: introduce repo for System Prompt; flow data from Room to VM
* bugfix: properly handle user's quitting conversation screen while tokens in generation
* UI: rename `ModeSelection` to `ModelLoading` for better clarity
* UI: update app name to be more Arm
* UI: polish conversation screen
* data: code polish
* UI: code polish
* bugfix: handle user quitting on model loading
* UI: locks user in alert dialog when model is unloading
* vm: replace token metrics stubs with actual implementation
* UI: refactor top app bars
* nit: combine temperatureMetrics and useFahrenheit
* DI: introduce Hilt plugin + processor + lib dependencies
* DI: make app Hilt injectable
* DI: make viewmodels Hilt injectable
* DI: replace manual DI with Hilt DI
* UI: optimize AppContent's composing
* bugfix: wait for model to load before navigating to benchmark screen; use NavigationActions instead of raw navController
* UI: navigation with more natural animated transitions
* DI: Optimize AppModule
* Feature: Introduce ModelRepository and ModelsManagementViewModel; update AppModule
* UI: polish UI for ModelsManagementScreen; inject ModelsManagementVieModel
* DI: abstract the protocol of SystemPromptRepository; update AppModule
* data: [WIP] prepare for ModelRepository refactor & impl
* data: introduce Model entity and DAO; update DI module
* UI: replace Models Management screen's stubbing with instrumentation
* UI: polish sort order menu
* data: import local model with file picker
* bugfix: use List instead of Collection for ModelDao's deletion
* data: add a util file for extracting file name & size and model metadata
* UI: enrich ModelManagementState; extract filename to show correct importing UI
* UI: implement multiple models deletion; update Models Management screen
* UI: handle back navigation when user is in multi-selection mode
* util: extract file size formatting into ModelUtils
* UI: add a confirmation step when user picks a file; refactor model import overlay into AlertDialog
* UI: extract a shared ModelCard component
* UI: replace model selection screen's data stubbing; add empty view
* nit: tidy SystemPromptViewModel
* Util: split FileUtils from ModelUtils; extract copy methods into FileUtils
* data: pass through getModelById from ModelDao into ModelRepository
* core: extract conversation and benchmark logics into InferenceManager; add logs and missing state updates in stub InferenceEngine
* vm: split mono MainViewModel into separate individual ViewModels
* vm: merge SystemPromptViewModel into ModelLoadingViewModel
* core: break down InferenceManager due to Interface Segregation Principle
* UI: show model card in Model Loading screen
* UI: show model card in Conversation screen
* UI: unify Model Card components
* core: swap in LLamaAndroid and mark stub engine for testing only
* data: allow canceling the ongoing model import
* UI: update UI ongoing model import's cancellation
* LLama: update engine state after handling the cancellation of sendUserPrompt
* VM: handle the cancellation of ongoing token generation
* LLama: refactor loadModel by splitting the system prompt setting into a separate method
* feature: check for available space before copying local model
* UI: centralize the AppScaffold and modularize its configs
* UI: refactor BottomBarConfig.ModelsManagement APIs
* UI: combine TopBarConfig and BottomBarConfig into each route's ScaffoldConfig
* UI: replace ugly optional as casts in AppScaffold with extension functions
* UI: fix the typo `totalGb` in `StorageMetrics`
* UI: remove code duplication in sort menu
* LLama: add ModelUnloadingState to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages
* UI: refactor back handling by removing centralized BackHandlerSetup and UnloadModelConfirmationDialog from AppContent
* UI: implement BenchmarkScreen's individual back handling
* LLama: add a new Initializing state; ; add two extension properties; rename LibraryLoaded state to Initialized
* UI: Introduce an abstract ViewModel to handle additional model unloading logics
* UI: expose a single facade ModelUnloadDialogHandler; move UnloadModelState into ModelUnloadingViewModel.kt
* UI: migrate ModelLoadingScreen onto ModelLoadingViewModel; update & refine ModelLoadingScreen
* UI: migrate ConversationViewModel onto ModelLoadingViewModel; update & refine ConversationScreen
* nit: extract app name into a constant value; remove unused onBackPressed callbacks
* UI: update AppContent to pass in correct navigation callbacks
* nit: polish ModelLoadingScreen UI
* core: throw Exception instead of returning null if model fails to load
* navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens
* gguf: add GGUF metadata data holder and its corresponding extractor implementation
* DB: introduce Kotlin serialization extension's library and plugin; add Room runtime library
* GGUF: make GgufMetadata serializable in order to be compatible with Room
* nit: refactor data.local package structure
* nit: rename lastUsed field to dateLastUsed; add dateAdded field
* UI: refactor ModelCard UI to show GGUF metadata
* UI: update ModelSelectionScreen with a preselect mechanism
* UI: polish model card
* nit: allow deselect model on Model Selection screen
* nit: revert accidental committing of debug code
* UI: polish ModelLoading screen
* util: extract formatting helper functions from FileUtils into a new FormatUtils
* UI: polish model cards on Benchmark and Conversation screens to show model loading metrics
* UI: show a Snack bar to warn user that system prompt is not always supported
* UI: handle back press on Model Selection screen
* UI: finally support theme modes; remove hardcoded color schemes, default to dynamic color scheme implementation
* feature: support searching on Model Selection screen
* nit: move scaffold related UI components into a separate package
* UI: extract InfoView out into a separate file for reusability
* data: move Model related actions (query, filter, sort) into ModelInfo file
* UI: animate FAB on model preselection states
* feature: support filtering in Model Management screen
* ui: show empty models info in Model Management screen
* ui: add filter off icon to "Clear filters" menu item
* [WIP] ui: polish Benchmark screen; implement its bottom app bar
* ui: polish Benchmark screen; implement its bottom app bar's rerun and share
* nit: disable mode selection's radio buttons when loading model
* feature: implement Conversation screen's bottom app bar
* pkg: restructure BottomAppBars into separate files in a child package
* pkg: restructure TopBarApps into separate files in a child package
* pkg: restructure system metrics into a separate file
* UI: polish Conversation screen
* data: update system prompt presets
* UI: allow hide or show model card on Conversation & Benchmark screens; fix message arrangement
* data: update & enhance system prompt presets
* deps: introduce Retrofit2
* data: implement HuggingFace data model, data source with Retrofit API
* data: update Model data repository to support fetching HuggingFace models
* [WIP] UI: replace the HuggingFace stub in Model Management screen with actual API call
* UI: map language codes into country Emojis
* ui: add "clear results" action to Benchmark screen
* nit: print current pp & tg in llama-bench
* UI: disable landscape mode; prevent duplicated benchmark running
* llama: migrate C/CXX flags into CMakeList
* [WIP] llama: ABI split builds five .so artifacts.
However, all .so are performing on SVE level
* [WIP] llama: ABI split where five tiers are built sequentially.
* [WIP] llama: disable OpenMP in ABI split since most SoCs are big.LITTLE
* [WIP] llama: enable KleidiAI and disable tier 4 due to `+sve+sve2` bug caused by `ggml_add_cpu_backend_variant_impl` as explained below
```CMake
if (NOT SME_ENABLED MATCHES -1)
...
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
...
```
* core: add Google's cpu_features as a submodule
* core: implement cpu_detector native lib
* core: swap out hardcoded LlamaAndroid library loading
* core: add back OpenMP due to huge perf loss on TG128
* misc: reorg the pkg structure
* misc: rename LlamaAndroid related class to InferenceEngine prefixes
* [WIP] lib: move GgufMetadata into the lib submodule
* lib: expose GgufMetadataReader as interface only
* lib: replace the naive & plain SharedPreferences with DataStore implementation
* lib: hide the internal implementations, only expose a facade and interfaces
* lib: expose Arm features
* di: add a stub TierDetection; provide both actual impl and stub in AppModule
* UI: add visualizer UI for Arm features
* misc: UI polish
* lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier
* UI: support `NONE` Llama Tier in general settings
* lib: optimize engine loader; always perform a fresh detection when cache is null
* remote: add HuggingFaceModelDetails data class
* remote: refine HuggingFaceModel data class
* nit: remove `trendingScore` field from HuggingFace model entities, weird...
* remote: refactor HuggingFaceApiService; implement download feature in HuggingFaceRemoteDataSource
* remote: fix the incorrect parse of HuggingFace's inconsistent & weird JSON response
* UI: scaffold Models Management screen and view model
* UI: implement a dialog UI to show fetched HuggingFace models.
* UI: use a broadcast receiver to listen for download complete events and show local import dialog.
* data: handle network exceptions elegantly
* pkg: restructure `data`'s packages
* data: extract local file info, copy and cleanup logics into LocalFileDataSource
* nit: minor UI patch; add missing comments
* bugfix: tapping "Home" in navigation drawer should simply close it without any navigation action.
* UI: improve autoscroll during token generation
* lib: tested on JFrog Artifactory for Maven publishing
* UI: show RAM warning if model too large
* UI: polish model management screen's error dialog
* util: add more items into the mapping table of ISO 639-1 language code to ISO 3166-1 country code
* llm: properly propagate error to UI upon failing to load selected model
* UI: avoid duplicated calculation of token metrics
* lib: read & validate the magic number from the picked source file before executing the import
* UI: add "Learn More" hyperlinks to Error dialog upon model import failures
* lib: refactor the GgufMetadataReader to take InputStream instead of absolute path as argument
* lib: fix the `SIMD` typo in Tier description
* core: verify model file path is readable
* lib: add UnsupportedArchitectureException for triaged error message
* util: split FormatUtils into multiple utils for better readability
* UI: change benchmark screen from raw markdown to table view
* bugfix: reset preselection upon running the preselected model
* misc: linter issue
* bugfix: fix the malfunctioning monitoring switch
* UI: update Arm features indicator; fix the broken hyperlinks
* UI: add quick action buttons to benchmark screen's result card
* UI: hide share fab after clearing all benchmark results
* UI: fix the model unload dialog message; elevate the model card and hide it by default on Conversation screen;
* UI: hide the stubbing actions in Conversation screen
* UI: add show/hide stats control to conversation screen's assistant message bubble; fix placeholder
* UI: add a info button to explain token metrics
* misc: remove the redundant `Companion` added due to refactoring
* UI: show corresponding system metrics detailed info upon tapping RAM / storage / temperature indicator
* UI: add info button to System Prompt switch; expand the model card by default
* UI: disable tag & language chips; add section headers to explain what they are
* misc: replace top bar indicator's spacer with padding
* UI: merge the Model Selection and Model Management into a unified Models screen
* UI: split the ModelsManagementViewModel from a unified ModelsViewModel due to huge complexity
* UI: add model loading in progress view; polish the empty model info view
* UI: polish the bottom bars and info view when no models found; show loading in progress while fetching models
* build: [BREAKING] bump the versions of libraries and plugins
* UI: fix the breaking build
* UI: add Tooltip on Import FAB for user onboarding
* UI: adds AppPreferences to track user onboarding status
* UI: tracks user's first success on importing a model
* data: add hand crafted rules to filter the models fetched from HuggingFace API
* UI: update app name & about; polish top bars' indicators & buttons
* UI: polish Hugging Face download dialog UI
* UX: implement onboarding tooltips for model import and onboarding
* misc: use sentence case for CTA button labels
* [WIP] UI: add Arm color palette from Philip.Watson3
* UI: address Rojin's UX feedbacks
* UI: address Rojin's UX feedbacks - part 2
* UI: update Arm color palette from Philip.Watson3
* data: make sure fetch preselected models in the same order of their IDs
* UI: fix UI issues in the generic settings screen and navigation drawer
* nit: address Rojin's feedbacks on model import message again
* nit: append `®` to all `Arm` labels
* UI: extract a reusable InfoAlertDialog
* core: support GGML_CPU_ALL_VARIANTS on Android!
* core: restructure Kleidi-Llama library
* core: organizing cmake arguments
* data: sort preselected models according to device's available RAM
* app: update adaptive + themed + legacy icons and app name
* UI: fix the font size auto scaling for ArmFeaturesVisualizer
* core: further improve the performance on native methods
* UI: minor color palette changes; emphasize the bottom bar FABs; fix Settings Screen menu item label
* UI: make more room for assistant message bubble's width
* UI: better usage of tertiary colors to highlight model cards but not for warnings
* UI: fix the layout issue on large font sizes
* lib: support x86-64 by dynamically set Arm related definitions
* lib: replace the factory pattern for deprecated tiered lib loading with single instance pattern
* llama: update the library name in JNI and CMake project
* llama: update the library's package name and namespace
* llama: update the app's package name and namespace
* app: bump ksp version
* app: remove deprecated SystemUIController from accompanist by migrating to EdgeToEdge
* app: extract AppContent from MainActivity to a separate file in ui package
* lib: add File version for GGUF Magic number verification
* lib: perform engine state check inclusively instead of exclusively
* lib: change `LlamaTier` to `ArmCpuTier`
* lib: remove kleidi-llama related namings
* cleanup: remove Arm AI Chat/Playground app source code; replace with the basic sample app from https://github.com/hanyin-arm/Arm-AI-Chat-Sample
Note: the full Google Play version of AI Chat app will be open will be open sourced in another repo soon, therefore didn't go through the trouble of pruning the history using `git filter-repo` here.
* [WIP] doc: update main and Android README docs; add self to code owners
* lib: revert System.load back to System.loadLibrary
* jni: introduce a logging util to filter different logging levels on different build types
* lib: enable app optimization
* doc: replace stub Google Play app URL with the actual link add screenshots; add my GitHub ID to maintainer list
* Remove cpu_features
* Fix linters issues in editorconfig-checker job
https://github.com/ggml-org/llama.cpp/actions/runs/19548770247/job/55974800633?pr=17413
* Remove unnecessary Android CMake flag
* purge include/cpu_features directory
---------
Co-authored-by: Han Yin <han.yin@arm.com>
This commit is contained in:
@@ -0,0 +1,565 @@
|
||||
#include <android/log.h>
|
||||
#include <jni.h>
|
||||
#include <iomanip>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <sampling.h>
|
||||
|
||||
#include "logging.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
template<class T>
|
||||
static std::string join(const std::vector<T> &values, const std::string &delim) {
|
||||
std::ostringstream str;
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
str << values[i];
|
||||
if (i < values.size() - 1) { str << delim; }
|
||||
}
|
||||
return str.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* LLama resources: context, model, batch and sampler
|
||||
*/
|
||||
constexpr int N_THREADS_MIN = 2;
|
||||
constexpr int N_THREADS_MAX = 4;
|
||||
constexpr int N_THREADS_HEADROOM = 2;
|
||||
|
||||
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
|
||||
constexpr int OVERFLOW_HEADROOM = 4;
|
||||
constexpr int BATCH_SIZE = 512;
|
||||
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
|
||||
|
||||
static llama_model * g_model;
|
||||
static llama_context * g_context;
|
||||
static llama_batch g_batch;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
static common_sampler * g_sampler;
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
|
||||
// Set llama log handler to Android
|
||||
llama_log_set(aichat_android_log_callback, nullptr);
|
||||
|
||||
// Loading all CPU backend variants
|
||||
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
|
||||
LOGi("Loading backends from %s", path_to_backend);
|
||||
ggml_backend_load_all_from_path(path_to_backend);
|
||||
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
|
||||
|
||||
// Initialize backends
|
||||
llama_backend_init();
|
||||
LOGi("Backend initiated; Log handler set.");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
|
||||
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
|
||||
|
||||
auto *model = llama_model_load_from_file(model_path, model_params);
|
||||
env->ReleaseStringUTFChars(jmodel_path, model_path);
|
||||
if (!model) {
|
||||
return 1;
|
||||
}
|
||||
g_model = model;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
|
||||
if (!model) {
|
||||
LOGe("%s: model cannot be null", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Multi-threading setup
|
||||
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
||||
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
||||
N_THREADS_HEADROOM));
|
||||
LOGi("%s: Using %d threads", __func__, n_threads);
|
||||
|
||||
// Context parameters setup
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
const int trained_context_size = llama_model_n_ctx_train(model);
|
||||
if (n_ctx > trained_context_size) {
|
||||
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
|
||||
__func__, trained_context_size, n_ctx);
|
||||
}
|
||||
ctx_params.n_ctx = n_ctx;
|
||||
ctx_params.n_batch = BATCH_SIZE;
|
||||
ctx_params.n_ubatch = BATCH_SIZE;
|
||||
ctx_params.n_threads = n_threads;
|
||||
ctx_params.n_threads_batch = n_threads;
|
||||
auto *context = llama_init_from_model(g_model, ctx_params);
|
||||
if (context == nullptr) {
|
||||
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
static common_sampler *new_sampler(float temp) {
|
||||
common_params_sampling sparams;
|
||||
sparams.temp = temp;
|
||||
return common_sampler_init(g_model, sparams);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
auto *context = init_context(g_model);
|
||||
if (!context) { return 1; }
|
||||
g_context = context;
|
||||
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
|
||||
g_chat_templates = common_chat_templates_init(g_model, "");
|
||||
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string get_backend() {
|
||||
std::vector<std::string> backends;
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto *reg = ggml_backend_reg_get(i);
|
||||
std::string name = ggml_backend_reg_name(reg);
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
}
|
||||
}
|
||||
return backends.empty() ? "CPU" : join(backends, ",");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
|
||||
return env->NewStringUTF(llama_print_system_info());
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
jint pl, jint nr) {
|
||||
auto *context = init_context(g_model, pp);
|
||||
if (!context) {
|
||||
const auto *const err_msg = "Fail to init_context! Bench aborted.";
|
||||
LOGe(err_msg);
|
||||
return env->NewStringUTF(err_msg);
|
||||
}
|
||||
|
||||
auto pp_avg = 0.0;
|
||||
auto tg_avg = 0.0;
|
||||
auto pp_std = 0.0;
|
||||
auto tg_std = 0.0;
|
||||
|
||||
const uint32_t n_ctx = llama_n_ctx(context);
|
||||
LOGi("n_ctx = %d", n_ctx);
|
||||
|
||||
int i, j;
|
||||
int nri;
|
||||
for (nri = 0; nri < nr; nri++) {
|
||||
LOGi("Benchmark prompt processing (pp = %d)", pp);
|
||||
|
||||
common_batch_clear(g_batch);
|
||||
|
||||
const int n_tokens = pp;
|
||||
for (i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(g_batch, 0, i, {0}, false);
|
||||
}
|
||||
|
||||
g_batch.logits[g_batch.n_tokens - 1] = true;
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during prompt processing");
|
||||
}
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
|
||||
// bench text generation
|
||||
|
||||
LOGi("Benchmark text generation (tg = %d)", tg);
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
for (i = 0; i < tg; i++) {
|
||||
common_batch_clear(g_batch);
|
||||
for (j = 0; j < pl; j++) {
|
||||
common_batch_add(g_batch, 0, i, {j}, true);
|
||||
}
|
||||
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during text generation");
|
||||
}
|
||||
}
|
||||
const auto t_tg_end = ggml_time_us();
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||
|
||||
const auto speed_pp = double(pp) / t_pp;
|
||||
const auto speed_tg = double(pl * tg) / t_tg;
|
||||
|
||||
pp_avg += speed_pp;
|
||||
tg_avg += speed_tg;
|
||||
|
||||
pp_std += speed_pp * speed_pp;
|
||||
tg_std += speed_tg * speed_tg;
|
||||
|
||||
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
|
||||
}
|
||||
|
||||
llama_free(context);
|
||||
|
||||
pp_avg /= double(nr);
|
||||
tg_avg /= double(nr);
|
||||
|
||||
if (nr > 1) {
|
||||
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
|
||||
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
|
||||
} else {
|
||||
pp_std = 0;
|
||||
tg_std = 0;
|
||||
}
|
||||
|
||||
char model_desc[128];
|
||||
llama_model_desc(g_model, model_desc, sizeof(model_desc));
|
||||
|
||||
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
|
||||
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
|
||||
|
||||
const auto backend = get_backend();
|
||||
std::stringstream result;
|
||||
result << std::setprecision(3);
|
||||
result << "| model | size | params | backend | test | t/s |\n";
|
||||
result << "| --- | --- | --- | --- | --- | --- |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
|
||||
return env->NewStringUTF(result.str().c_str());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Completion loop's long-term states:
|
||||
* - chat management
|
||||
* - position tracking
|
||||
*/
|
||||
constexpr const char *ROLE_SYSTEM = "system";
|
||||
constexpr const char *ROLE_USER = "user";
|
||||
constexpr const char *ROLE_ASSISTANT = "assistant";
|
||||
|
||||
static std::vector<common_chat_msg> chat_msgs;
|
||||
static llama_pos system_prompt_position;
|
||||
static llama_pos current_position;
|
||||
|
||||
static void reset_long_term_states(const bool clear_kv_cache = true) {
|
||||
chat_msgs.clear();
|
||||
system_prompt_position = 0;
|
||||
current_position = 0;
|
||||
|
||||
if (clear_kv_cache)
|
||||
llama_memory_clear(llama_get_memory(g_context), false);
|
||||
}
|
||||
|
||||
/**
|
||||
* TODO-hyin: implement sliding-window version as a better alternative
|
||||
*
|
||||
* Context shifting by discarding the older half of the tokens appended after system prompt:
|
||||
* - take the [system_prompt_position] first tokens from the original prompt
|
||||
* - take half of the last (system_prompt_position - system_prompt_position) tokens
|
||||
* - recompute the logits in batches
|
||||
*/
|
||||
static void shift_context() {
|
||||
const int n_discard = (current_position - system_prompt_position) / 2;
|
||||
LOGi("%s: Discarding %d tokens", __func__, n_discard);
|
||||
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
|
||||
current_position -= n_discard;
|
||||
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
|
||||
}
|
||||
|
||||
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
|
||||
common_chat_msg new_msg;
|
||||
new_msg.role = role;
|
||||
new_msg.content = content;
|
||||
auto formatted = common_chat_format_single(
|
||||
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
|
||||
chat_msgs.push_back(new_msg);
|
||||
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
|
||||
return formatted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Completion loop's short-term states:
|
||||
* - stop generation position
|
||||
* - token chars caching
|
||||
* - current assistant message being generated
|
||||
*/
|
||||
static llama_pos stop_generation_position;
|
||||
static std::string cached_token_chars;
|
||||
static std::ostringstream assistant_ss;
|
||||
|
||||
static void reset_short_term_states() {
|
||||
stop_generation_position = 0;
|
||||
cached_token_chars.clear();
|
||||
assistant_ss.str("");
|
||||
}
|
||||
|
||||
static int decode_tokens_in_batches(
|
||||
llama_context *context,
|
||||
llama_batch &batch,
|
||||
const llama_tokens &tokens,
|
||||
const llama_pos start_pos,
|
||||
const bool compute_last_logit = false) {
|
||||
// Process tokens in batches using the global batch
|
||||
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
|
||||
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
|
||||
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
|
||||
common_batch_clear(batch);
|
||||
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
|
||||
|
||||
// Shift context if current batch cannot fit into the context
|
||||
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Add tokens to the batch with proper positions
|
||||
for (int j = 0; j < cur_batch_size; j++) {
|
||||
const llama_token token_id = tokens[i + j];
|
||||
const llama_pos position = start_pos + i + j;
|
||||
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
|
||||
common_batch_add(batch, token_id, position, {0}, want_logit);
|
||||
}
|
||||
|
||||
// Decode this batch
|
||||
const int decode_result = llama_decode(context, batch);
|
||||
if (decode_result) {
|
||||
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring jsystem_prompt
|
||||
) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain system prompt from JEnv
|
||||
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
|
||||
std::string formatted_system_prompt(system_prompt);
|
||||
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
|
||||
|
||||
// Format system prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
|
||||
}
|
||||
|
||||
// Tokenize system prompt
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
|
||||
has_chat_template, has_chat_template);
|
||||
for (auto id: system_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Handle context overflow
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if ((int) system_tokens.size() > max_batch_size) {
|
||||
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
|
||||
__func__, (int) system_tokens.size(), max_batch_size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Decode system tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
system_prompt_position = current_position = (int) system_tokens.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring juser_prompt,
|
||||
jint n_predict
|
||||
) {
|
||||
// Reset short-term states
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain and tokenize user prompt
|
||||
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
|
||||
std::string formatted_user_prompt(user_prompt);
|
||||
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
|
||||
|
||||
// Format user prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
|
||||
}
|
||||
|
||||
// Decode formatted user prompts
|
||||
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id: user_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
|
||||
const int user_prompt_size = (int) user_tokens.size();
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if (user_prompt_size > max_batch_size) {
|
||||
const int skipped_tokens = user_prompt_size - max_batch_size;
|
||||
user_tokens.resize(max_batch_size);
|
||||
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
|
||||
}
|
||||
|
||||
// Decode user tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position += user_prompt_size;
|
||||
stop_generation_position = current_position + user_prompt_size + n_predict;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool is_valid_utf8(const char *string) {
|
||||
if (!string) { return true; }
|
||||
|
||||
const auto *bytes = (const unsigned char *) string;
|
||||
int num;
|
||||
|
||||
while (*bytes != 0x00) {
|
||||
if ((*bytes & 0x80) == 0x00) {
|
||||
// U+0000 to U+007F
|
||||
num = 1;
|
||||
} else if ((*bytes & 0xE0) == 0xC0) {
|
||||
// U+0080 to U+07FF
|
||||
num = 2;
|
||||
} else if ((*bytes & 0xF0) == 0xE0) {
|
||||
// U+0800 to U+FFFF
|
||||
num = 3;
|
||||
} else if ((*bytes & 0xF8) == 0xF0) {
|
||||
// U+10000 to U+10FFFF
|
||||
num = 4;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
bytes += 1;
|
||||
for (int i = 1; i < num; ++i) {
|
||||
if ((*bytes & 0xC0) != 0x80) {
|
||||
return false;
|
||||
}
|
||||
bytes += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/
|
||||
) {
|
||||
// Infinite text generation via context shifting
|
||||
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Context full! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Stop if reaching the marked position
|
||||
if (current_position >= stop_generation_position) {
|
||||
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Sample next token
|
||||
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||
common_sampler_accept(g_sampler, new_token_id, true);
|
||||
|
||||
// Populate the batch with new token, then decode
|
||||
common_batch_clear(g_batch);
|
||||
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
|
||||
if (llama_decode(g_context, g_batch) != 0) {
|
||||
LOGe("%s: llama_decode() failed for generated token", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position++;
|
||||
|
||||
// Stop if next token is EOG
|
||||
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If not EOG, convert to text
|
||||
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||
cached_token_chars += new_token_chars;
|
||||
|
||||
// Create and return a valid UTF-8 Java string
|
||||
jstring result = nullptr;
|
||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||
|
||||
assistant_ss << cached_token_chars;
|
||||
cached_token_chars.clear();
|
||||
} else {
|
||||
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||
result = env->NewStringUTF("");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Free up resources
|
||||
common_sampler_free(g_sampler);
|
||||
g_chat_templates.reset();
|
||||
llama_batch_free(g_batch);
|
||||
llama_free(g_context);
|
||||
llama_model_free(g_model);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
|
||||
llama_backend_free();
|
||||
}
|
||||
Reference in New Issue
Block a user