CUDA: add stream-based concurrency (#16991)

* CUDA: add stream-based concurrency

* HIP: fix hipStreamWaitEvent define and nodiscard warnings

* ggml-cuda: fix fusion inside stream

* ggml-cuda: fix bug w.r.t first stream launch

* ggml-cuda: format

* ggml-cuda: improve assert message

* ggml-cuda: use lambda instead of duplicating code

* ggml-cuda: add some more comments

* ggml-cuda: add more detailed comments about concurrency

* ggml-cuda: rename + remove unused var

* ggml-cuda: fix condition for stream launch

* ggml-cuda: address review comments, add destructor

* common.cuh: add is_valid for concurrent events

* common.cuh: make comment better

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* common.cuh: fix lower_bound condition + remove join_node data from write_ranges

* ggml-cuda: fix overlap condition + shadowing parameter

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Aman Gupta
2025-11-30 08:17:55 +08:00
committed by GitHub
parent 00425e2ed1
commit c7af376c29
3 changed files with 469 additions and 14 deletions
+162 -8
View File
@@ -21,10 +21,12 @@
#include "ggml-common.h"
#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#if defined(GGML_USE_HIP)
@@ -980,6 +982,154 @@ struct ggml_cuda_graph {
#endif
};
struct ggml_cuda_concurrent_event {
std::vector<cudaEvent_t> join_events;
cudaEvent_t fork_event = nullptr;
int n_streams = 0;
std::unordered_map<const ggml_tensor *, int> stream_mapping;
const ggml_tensor * join_node;
ggml_cuda_concurrent_event() = default;
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
join_events.resize(n_streams);
for (size_t i = 0; i < join_events.size(); ++i) {
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
}
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
}
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
: join_events(std::move(other.join_events))
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
, join_node(other.join_node) {
other.fork_event = nullptr;
}
// 1. check if any branches write to overlapping memory ranges (except the join node)
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
// we assume all nodes have the same buffer
bool is_valid() const {
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
write_ranges.resize(n_streams);
// get join_node's memory range to exclude from overlap checking.
// multiple nodes can use join_node's buffer; we synchronize on the join node.
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
const int64_t join_start = (int64_t) join_t->data;
const int64_t join_end = join_start + ggml_nbytes(join_t);
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer.
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
// concurrent streams begin from 1
write_ranges[stream - 1].emplace_back(t_start, t_end);
}
for (int i = 0; i < n_streams; ++i) {
// sorts first by start then by end of write range
std::sort(write_ranges[i].begin(), write_ranges[i].end());
}
bool writes_overlap = false;
bool dependent_srcs = false;
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
// check if this buffer's write data overlaps with another stream's
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
for (int i = 0; i < n_streams; ++i) {
if (i == stream - 1) {
continue;
}
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
if (it != write_ranges[i].end()) {
const std::pair<int64_t, int64_t> & other = *it;
// std::lower_bound returns the first element where other >= data_range (lexicographically).
// This guarantees other.first >= data_range.first.
// Therefore, overlap occurs iff other.first < data_range.second
// (i.e., the other range starts before this range ends).
if (other.first < data_range.second) {
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
writes_overlap = true;
break;
}
}
}
//check if all srcs are either in branch or don't have a branch
for (int i = 0; i < GGML_MAX_SRC; ++i) {
if (!tensor->src[i]) {
continue;
}
auto it = stream_mapping.find(tensor->src[i]);
if (it == stream_mapping.end()) {
continue;
}
if (it->second != stream) {
dependent_srcs = true;
break;
}
}
if (dependent_srcs || writes_overlap) {
break;
}
}
return !writes_overlap && !dependent_srcs;
}
~ggml_cuda_concurrent_event() {
if (fork_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(fork_event));
}
for (cudaEvent_t e : join_events) {
if (e != nullptr) {
CUDA_CHECK(cudaEventDestroy(e));
}
}
}
};
struct ggml_cuda_stream_context {
std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
void reset() {
original_nodes.clear();
concurrent_events.clear();
}
};
struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
int curr_stream_no = 0;
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
ggml_cuda_stream_context concurrent_stream_context;
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
@@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
return streams[device][stream];
}
cudaStream_t stream() {
return stream(device, 0);
}
cudaStream_t stream() { return stream(device, curr_stream_no); }
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
@@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
}
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
if (pools[device][curr_stream_no] == nullptr) {
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
}
return *pools[device];
return *pools[device][curr_stream_no];
}
ggml_cuda_pool & pool() {