server : support multi-modal context checkpoints (#19849)
* Modify llama-memory-hybrid-iswa.cpp * Modify llama-memory-recurrent.cpp * Modify server-common.cpp * Modify server-common.h * Modify server-context.cpp * Modify server-task.h * Added comment to llama-memory-hybrid-iswa.cpp * Remove comment from server-context.cpp * Stylistic fix server-context.cpp * Fix an issue when seqrm isn't called in server-context.cpp * cont : alternative impl * cont : cleanup * cont : n_tokens -> int64_t --------- Co-authored-by: timkhronos <timkhronos@gmail.com>
This commit is contained in:
@@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
|
||||
server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
|
||||
}
|
||||
|
||||
llama_pos server_tokens::pos_next() const {
|
||||
llama_pos server_tokens::pos_next(int64_t n_tokens) const {
|
||||
if (!has_mtmd) {
|
||||
return tokens.size();
|
||||
if (n_tokens < 0) {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
return n_tokens;
|
||||
}
|
||||
|
||||
llama_pos res = tokens.size();
|
||||
if (n_tokens < 0) {
|
||||
llama_pos res = tokens.size();
|
||||
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
const auto & chunk = it->second;
|
||||
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
const auto & chunk = it->second;
|
||||
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
return res;
|
||||
int64_t idx = 0;
|
||||
llama_pos pos = 0;
|
||||
|
||||
GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
|
||||
|
||||
while (idx < n_tokens) {
|
||||
const auto media_it = map_idx_to_media.find(idx);
|
||||
if (media_it != map_idx_to_media.end()) {
|
||||
const auto & chunk = media_it->second;
|
||||
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
|
||||
pos += n_pos;
|
||||
idx += n_tok;
|
||||
} else {
|
||||
pos++;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
|
||||
if (!has_mtmd) {
|
||||
return std::min((size_t)(max_pos + 1), tokens.size());
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
llama_pos pos = 0;
|
||||
|
||||
while (idx < tokens.size()) {
|
||||
const auto media_it = map_idx_to_media.find(idx);
|
||||
if (media_it != map_idx_to_media.end()) {
|
||||
const auto & chunk = media_it->second;
|
||||
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
|
||||
pos += n_pos;
|
||||
idx += n_tok;
|
||||
} else {
|
||||
pos++;
|
||||
idx++;
|
||||
}
|
||||
|
||||
if (pos > max_pos) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
std::string server_tokens::str() const {
|
||||
|
||||
Reference in New Issue
Block a user