ggml: backend-agnostic tensor parallelism (experimental) (#19378)

* ggml: backend-agnostic tensor parallelism

* support for GPT-OSS, Qwen 3 MoE

* partial Vulkan fix

* add support for 4/8 GPUs

* unconditional peer access

* re-use buffers + ggml contexts

* fix output pattern

* NCCL support

* GGML: HIP: add RCCL support

* Remove shfl and AllReduce from backend interface

* move allocation workaround out of ggml-alloc.c

* 2d tensor set/get support

* Fix the seg fault without NCCL

* Apply suggestion from JohannesGaessler

* support for tensor dims % n_devs != 0

* fix view_offs scaling

* arbitrary num. of GPUs/tensor split

* fix compilation

* better granularity estimate

* Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA.

Fix compilation errors.

* partial Qwen 3 Next support

* Fix qwen3 30b (#8)

* Fix crash with Qwen-30B-A3B Q4_0

Qwen-30B-A3B Q4_0 has an intermediate dimension of 768, which is only three blocks at a granularity of 256. Using a granularity of 256 therefore forces an uneven split between GPUs, which is not supported by the current implementation.

* Decide block size based on tensor quantization type

* Fix crashes due to KV cache serialization (#9)

KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.

* metal : fix build (#7)

* static memory allocations, fix usage count

* fix tensor granularity

* more even memory distribution

* use BF16 for allreduce

* rebase fixup

* better error message for unsupported architectures

* Fix device mismatch during scatter of allReduce. (#11)

There is a mismatch between the dst buffer device and the backend device, causing the use of synchronous copies.

* Enable the previous allreduce implementation. It is better in both perf and stability (#12)

* delay AllReduce for MoE for less I/O

* build : clean-up compile warnings

* backend : move most of the meta backend API to ggml-backend-impl.h

* cont : hide unused public API in the implementation

* llama : use llama_device + remove ggml_backend_dev_is_meta()

* ggml-backend : remove unused alloc include

* minor : remove regex include

* ggml : introduce ggml-ext.h for staging new APIs

* rebase fixup

* fix tests

* llama : more robust logic for determining Meta devices (#16)

* llama : more robust logic for determining Meta devices

* cont : fix devs size check

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* cont : fix log type

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* disable roundtrip for meta backend

* fix arch selection

* Qwen 3.5 support

* fix Gemma 4 MoE

* fix OpenVINO, SYCL

* fix test-llama-archs for CPU-only builds

* Fix Qwen 3.5 MoE

* disable meta backend tests for WebGPU

* tests : filter CPU-based devices from the Meta backend tests (#17)

* meta : formatting, naming, indentation (#18)

* formatting : llama-model.cpp

* formatting : ggml-ext.h

* formatting : ggml-backend-meta.cpp

* meta : add TODO

* add documentation

* better error messages

* fix GPT-OSS

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Johannes Gäßler
Date: 2026-04-09 16:42:19 +02:00
committed by GitHub
parent 009a113326
commit d6f3030047
48 changed files with 3198 additions and 342 deletions
+4
@@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 11)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
if(GIT_EXE)
# Get current git commit hash
@@ -204,12 +206,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON)
set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
"ggml: cuda link binary compression mode; requires cuda 12.8+")
set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")
option(GGML_HIP "ggml: use HIP" OFF)
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
+36
@@ -0,0 +1,36 @@
# cmake/FindNCCL.cmake
# NVIDIA does not distribute CMake files with NCCL, so use this module to find it instead.
find_path(NCCL_INCLUDE_DIR
NAMES nccl.h
HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
PATH_SUFFIXES include
)
find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
PATH_SUFFIXES lib lib64
)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL
DEFAULT_MSG
NCCL_LIBRARY NCCL_INCLUDE_DIR
)
if(NCCL_FOUND)
set(NCCL_LIBRARIES ${NCCL_LIBRARY})
set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR})
if(NOT TARGET NCCL::NCCL)
add_library(NCCL::NCCL UNKNOWN IMPORTED)
set_target_properties(NCCL::NCCL PROPERTIES
IMPORTED_LOCATION "${NCCL_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}"
)
endif()
endif()
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
+17 -9
@@ -68,7 +68,7 @@ extern "C" {
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst);
//
// Backend (stream)
@@ -83,13 +83,17 @@ extern "C" {
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
// "offset" refers to the offset in tensor->data for setting/getting data
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
@@ -109,7 +113,7 @@ extern "C" {
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
@@ -135,7 +139,9 @@ extern "C" {
// integrated GPU device using host memory
GGML_BACKEND_DEVICE_TYPE_IGPU,
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
GGML_BACKEND_DEVICE_TYPE_ACCEL
GGML_BACKEND_DEVICE_TYPE_ACCEL,
// "meta" device wrapping multiple other devices for tensor parallelism
GGML_BACKEND_DEVICE_TYPE_META,
};
// functionality supported by the device
@@ -196,7 +202,9 @@ extern "C" {
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
// Split buffer type for tensor parallelism
// AllReduce operation for tensor parallelism (meta backend)
typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
// Split buffer type for tensor parallelism (old)
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
// Set the number of threads for the backend
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
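
A minimal sketch of the new 2d calling convention (the helper and its sizes below are hypothetical; only the ggml API is from this commit). One call transfers n_copies strided chunks; backends without a native 2d path fall back to a loop of plain 1d copies:

#include "ggml-backend.h"

// Upload n_copies chunks of `size` bytes from a padded host buffer into
// tensor `t`: chunk i is read from data + i*stride_data and written at byte
// offset i*stride_tensor inside t->data.
static void upload_padded_rows(struct ggml_tensor * t, const void * data,
                               size_t size, size_t n_copies,
                               size_t stride_tensor, size_t stride_data) {
    ggml_backend_tensor_set_2d(t, data, /*offset =*/ 0, size, n_copies,
                               stride_tensor, stride_data);
}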
+3
@@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
// device buffer
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// conduct allreduce operation between devices
GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
+1
@@ -200,6 +200,7 @@ add_library(ggml-base
ggml.cpp
ggml-alloc.c
ggml-backend.cpp
ggml-backend-meta.cpp
ggml-opt.cpp
ggml-threading.cpp
ggml-threading.h
+3
@@ -1236,6 +1236,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx,
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
size_t nbytes_total = 0;
if (ggml_backend_buft_is_meta(buft)) {
return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft);
}
return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false);
}
+22 -2
@@ -49,6 +49,10 @@ extern "C" {
void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) 2d data copies
void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
// (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
// clear the entire buffer
@@ -80,6 +84,20 @@ extern "C" {
GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
//
// Backend (meta)
//
GGML_API bool ggml_backend_is_meta (ggml_backend_t backend);
GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf);
GGML_API bool ggml_backend_buft_is_meta (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_meta_n_backends (ggml_backend_t meta_backend);
GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index);
// temporary workaround to statically allocate tensors from a context in a deduplicated way:
GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
//
// Backend (stream)
//
@@ -90,8 +108,10 @@ extern "C" {
void (*free)(ggml_backend_t backend);
// (optional) asynchronous tensor data access
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
// (optional) complete all pending operations (required if the backend supports async operations)
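
A short sketch of how the meta-backend introspection calls are meant to compose (assumes `meta` was created on a meta device; apart from the existing public ggml_backend_name, only functions declared above are used):

#include <stdio.h>
#include "ggml-backend-impl.h"

// List the simple backends wrapped by a meta backend.
static void print_meta_members(ggml_backend_t meta) {
    if (!ggml_backend_is_meta(meta)) {
        return;
    }
    const size_t n = ggml_backend_meta_n_backends(meta);
    for (size_t i = 0; i < n; ++i) {
        ggml_backend_t simple = ggml_backend_meta_simple_backend(meta, i);
        printf("member %zu: %s\n", i, ggml_backend_name(simple));
    }
}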
File diff suppressed because it is too large
+102 -8
@@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
// get_base is optional if the buffer is zero-sized
if (buffer->size == 0) {
if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) {
return NULL;
}
@@ -279,15 +279,57 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
}
}
void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size,
size_t n_copies, size_t stride_tensor, size_t stride_data) {
GGML_ASSERT(backend);
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) {
for (size_t i = 0; i < n_copies; i++) {
ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size);
}
return;
}
if (size == 0) {
return;
}
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data);
}
void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size,
size_t n_copies, size_t stride_tensor, size_t stride_data) {
GGML_ASSERT(backend);
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) {
for (size_t i = 0; i < n_copies; i++) {
ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size);
}
return;
}
if (size == 0) {
return;
}
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data);
}
void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf != NULL && "tensor buffer not set");
if (size == 0) {
return;
}
GGML_ASSERT(buf != NULL && "tensor buffer not set");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
@@ -297,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf != NULL && "tensor buffer not set");
if (size == 0) {
return;
}
GGML_ASSERT(buf != NULL && "tensor buffer not set");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
buf->iface.get_tensor(buf, tensor, data, offset, size);
}
void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size,
size_t n_copies, size_t stride_tensor, size_t stride_data) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf != NULL && "tensor buffer not set");
if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) {
for (size_t i = 0; i < n_copies; i++) {
ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size);
}
return;
}
if (size == 0) {
return;
}
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data);
}
void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size,
size_t n_copies, size_t stride_tensor, size_t stride_data) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf != NULL && "tensor buffer not set");
if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) {
for (size_t i = 0; i < n_copies; i++) {
ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size);
}
return;
}
if (size == 0) {
return;
}
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data);
}
void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
@@ -388,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
// backend copy
void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
if (src == dst) {
@@ -402,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
} else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
#endif
#endif // NDEBUG
size_t nbytes = ggml_nbytes(src);
void * data = malloc(nbytes);
ggml_backend_tensor_get(src, data, 0, nbytes);
@@ -411,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
}
}
void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
if (src == dst) {
@@ -500,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
}
void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
GGML_ASSERT(device);
memset(props, 0, sizeof(*props));
device->iface.get_props(device, props);
}
@@ -610,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ NULL,
/* .get_tensor = */ NULL,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL,
/* .clear = */ ggml_backend_multi_buffer_clear,
/* .reset = */ NULL,
@@ -1899,8 +1988,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct
GGML_ASSERT(tensor->data == NULL);
GGML_ASSERT(tensor->view_src == NULL);
GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
(char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) ||
(char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
(char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
tensor->buffer = buffer;
tensor->data = addr;
@@ -2174,6 +2264,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
/* .clear = */ ggml_backend_cpu_buffer_clear,
/* .reset = */ NULL,
@@ -2186,6 +2278,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
/* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
/* .clear = */ ggml_backend_cpu_buffer_clear,
/* .reset = */ NULL,
+2
@@ -262,6 +262,8 @@ static struct ggml_backend_i blas_backend_i = {
/* .get_name = */ ggml_backend_blas_get_name,
/* .free = */ ggml_backend_blas_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
+4
@@ -1457,6 +1457,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
/* .clear = */ ggml_backend_cann_buffer_clear,
/* .reset = */ NULL,
@@ -2698,6 +2700,8 @@ static const ggml_backend_i ggml_backend_cann_interface = {
/* .free = */ ggml_backend_cann_free,
/* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
/* .synchronize = */ ggml_backend_cann_synchronize,
/* .graph_plan_create = */ NULL,
+2
@@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
/* .get_tensor = */ nullptr,
/* .set_tensor_2d = */ nullptr,
/* .get_tensor_2d = */ nullptr,
/* .cpy_tensor = */ nullptr,
/* .clear = */ ggml_backend_amx_buffer_clear,
/* .reset = */ nullptr,
+2
@@ -195,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .free = */ ggml_backend_cpu_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
+10
@@ -181,6 +181,16 @@ if (CUDAToolkit_FOUND)
target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
endif()
if (GGML_CUDA_NCCL)
find_package(NCCL)
if (NCCL_FOUND)
add_compile_definitions(GGML_USE_NCCL)
target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL)
else()
message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal")
endif()
endif()
set(CUDA_CXX_FLAGS "")
set(CUDA_FLAGS -use_fast_math -extended-lambda)
+8
@@ -186,6 +186,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#ifdef GGML_USE_NCCL
#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString)
#endif // GGML_USE_NCCL
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
@@ -1086,6 +1090,10 @@ struct ggml_cuda_device_info {
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
#ifdef GGML_USE_NCCL
ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
#endif // GGML_USE_NCCL
};
const ggml_cuda_device_info & ggml_cuda_info();
+166 -79
@@ -324,6 +324,28 @@ static ggml_cuda_device_info ggml_cuda_init() {
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
for (int id = 0; id < info.device_count; ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < info.device_count; ++id_other) {
if (id == id_other) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
}
}
}
#ifdef GGML_USE_NCCL
int dev_ids[GGML_CUDA_MAX_DEVICES];
for (int id = 0; id < info.device_count; ++id) {
dev_ids[id] = id;
}
NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
#endif // GGML_USE_NCCL
return info;
}
@@ -632,26 +654,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
}
static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data,
size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemcpy2DAsync(
(char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data,
size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
CUDA_CHECK(cudaMemcpy2DAsync(
data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
@@ -691,6 +733,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
/* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
/* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d,
/* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d,
/* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
/* .clear = */ ggml_backend_cuda_buffer_clear,
/* .reset = */ NULL,
@@ -1003,6 +1047,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL,
/* .clear = */ ggml_backend_cuda_split_buffer_clear,
/* .reset = */ NULL,
@@ -1079,6 +1125,83 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
/* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
};
bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
#ifdef GGML_USE_NCCL
const int64_t ne = ggml_nelements(tensors[0]);
// FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
// This then causes a crash in this function
if (ne == 0) {
return true;
}
for (size_t i = 0; i < n_backends; ++i) {
GGML_ASSERT(tensors[i] != nullptr);
GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}
const ggml_cuda_device_info info = ggml_cuda_info();
// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
return true;
}
// For large tensors it's faster to compress them to BF16 for the reduction:
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
tmp[i].pool = &cuda_ctx->pool();
tmp[i].alloc(ne);
ggml_cuda_set_device(i);
to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
ggml_cuda_set_device(i);
to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
return true;
#else
// If NCCL is installed it is used by default for optimal performance.
// However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
// RCCL is disabled by default, users are explicitly opting in.
// Therefore print no warning for RCCL.
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
static bool warning_printed = false;
if (!warning_printed) {
GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__);
warning_printed = true;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
GGML_UNUSED_VARS(backends, tensors, n_backends);
return false;
#endif // GGML_USE_NCCL
}
ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
@@ -1425,64 +1548,6 @@ static void ggml_cuda_op_mul_mat_cublas(
GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
}
static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
static bool peer_access_enabled = false;
const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
if (peer_access_enabled == enable_peer_access) {
return;
}
#ifdef NDEBUG
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
ggml_cuda_set_device(id);
CUDA_CHECK(cudaDeviceSynchronize());
}
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
if (id == id_other) {
continue;
}
if (id != main_device && id_other != main_device) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
if (enable_peer_access) {
cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
if (err != cudaErrorPeerAccessAlreadyEnabled) {
CUDA_CHECK(err);
} else {
// reset the error
(void)cudaGetLastError();
}
} else {
cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
if (err != cudaErrorPeerAccessNotEnabled) {
CUDA_CHECK(err);
} else {
// reset the error
(void)cudaGetLastError();
}
}
}
}
}
ggml_cuda_set_device(main_device);
#endif // NDEBUG
peer_access_enabled = enable_peer_access;
GGML_UNUSED(main_device);
}
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
@@ -2483,11 +2548,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
// why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
}
switch (dst->op) {
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
@@ -2845,21 +2905,43 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) {
}
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
}
static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
}
static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data,
size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
CUDA_CHECK(cudaMemcpy2DAsync(
(char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream()));
}
static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data,
size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
CUDA_CHECK(cudaMemcpy2DAsync(
data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
}
static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
@@ -2870,21 +2952,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
return false;
}
if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
return false;
}
// device -> device copy
ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context;
ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context;
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif
#endif // NDEBUG
return false;
}
@@ -2897,7 +2979,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
return false;
#else
CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
#endif
#endif // GGML_CUDA_NO_PEER_COPY
}
// record event on src stream after the copy
@@ -4343,6 +4425,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .free = */ ggml_backend_cuda_free,
/* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
/* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
/* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
/* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
/* .synchronize = */ ggml_backend_cuda_synchronize,
/* .graph_plan_create = */ NULL,
@@ -5130,6 +5214,9 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
GGML_UNUSED(reg);
if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) {
return (void *)ggml_backend_cuda_allreduce_tensor;
}
if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
return (void *)ggml_backend_cuda_split_buffer_type;
}
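
The allreduce hook is exported through the existing proc-address mechanism, so callers can resolve it at runtime without linking against the CUDA backend directly. A sketch of the lookup, assuming `reg` was obtained from e.g. ggml_backend_reg_by_name:

#include "ggml-backend.h"

// Resolve the allreduce entry point at runtime; returns NULL when the
// backend does not export it (e.g. CPU-only builds).
static ggml_backend_allreduce_tensor_t get_allreduce(ggml_backend_reg_t reg) {
    return (ggml_backend_allreduce_tensor_t)
        ggml_backend_reg_get_proc_address(reg, "ggml_backend_allreduce_tensor");
}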
+4
@@ -6,6 +6,10 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef GGML_USE_NCCL
#include <nccl.h>
#endif // GGML_USE_NCCL
#if CUDART_VERSION >= 11080
#include <cuda_fp8.h>
#define FP8_AVAILABLE
+6
@@ -10,6 +10,11 @@
#include <rocwmma/rocwmma-version.hpp>
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
#ifdef GGML_USE_NCCL
#include <rccl/rccl.h>
#endif // GGML_USE_NCCL
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
@@ -28,6 +33,7 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+56
@@ -0,0 +1,56 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
// This is a "staging" header for new ggml API
// It is not publicly available and it should not be used by 3rd party projects
//
// When the API matures enough, it will be moved to the official public API
//
// Meta backend
//
#define GGML_BACKEND_META_MAX_DEVICES 16
enum ggml_backend_meta_split_axis {
// tensor split by tensor dimensions:
GGML_BACKEND_SPLIT_AXIS_0 = 0,
GGML_BACKEND_SPLIT_AXIS_1 = 1,
GGML_BACKEND_SPLIT_AXIS_2 = 2,
GGML_BACKEND_SPLIT_AXIS_3 = 3,
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
// for internal bookkeeping only:
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
};
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
struct ggml_backend_meta_split_state {
enum ggml_backend_meta_split_axis axis;
// for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
// - each device has a slice of the tensor along the split axis
// - most tensors have n_segments == 1 and a contiguous slice of the tensor data
// - some tensors have an inhomogeneous data layout along the split axis,
// those tensors are divided into segments which are each individually split across devices
// - ne has one entry per segment and device; the entries add up to ggml_tensor::ne for that axis,
// the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1]
// - for example, a transformer may have a fused QKV matrix rather than 3 matrices; those would be 3 separate segments
// that each need to be split individually across devices so that each device gets a slice of Q, K, and V
int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES];
uint32_t n_segments;
};
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
// create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
// TODO: this looks a bit strange - a backend API creates a device. I think we should try
// to express this as a backend registry functionality instead
GGML_API ggml_backend_dev_t ggml_backend_meta_device(
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
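
As a rough illustration of the intended call flow (not part of this commit, and the exact ne semantics for mirrored tensors are an assumption): a trivial split-state callback that mirrors every weight, used to build a meta device over two simple devices.

#include "ggml-ext.h"

// Hypothetical policy: replicate every statically allocated tensor on all
// devices instead of splitting it.
static struct ggml_backend_meta_split_state mirror_all(const struct ggml_tensor * tensor, void * userdata) {
    (void) tensor;
    (void) userdata;
    struct ggml_backend_meta_split_state state = {0};
    state.axis       = GGML_BACKEND_SPLIT_AXIS_MIRRORED;
    state.n_segments = 1;
    return state;
}

// devs[0] and devs[1] are assumed to be two simple (e.g. CUDA) devices:
// ggml_backend_dev_t meta_dev = ggml_backend_meta_device(devs, 2, mirror_all, NULL);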
+4
@@ -1491,6 +1491,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor,
/* .clear = */ ggml_backend_hexagon_buffer_clear,
/* .reset = */ NULL,
@@ -3002,6 +3004,8 @@ static struct ggml_backend_i hexagon_backend_i = {
/* .free = */ ggml_backend_hexagon_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ ggml_backend_hexagon_synchronize,
/* .graph_plan_create = */ NULL,
+12
@@ -47,6 +47,10 @@ find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
if (GGML_HIP_RCCL)
find_package(rccl REQUIRED)
endif()
if (${hip_VERSION} VERSION_LESS 6.1)
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
endif()
@@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA)
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
endif()
if (GGML_HIP_RCCL)
add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL.
endif()
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
@@ -142,4 +150,8 @@ if (GGML_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()
if (GGML_HIP_RCCL)
target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl)
endif()
target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)
+15 -9
@@ -90,6 +90,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
/* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_shared_clear,
/* .reset = */ NULL,
@@ -158,15 +160,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer
}
static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_private_clear,
/* .reset = */ NULL,
/* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_private_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
/* .clear = */ ggml_backend_metal_buffer_private_clear,
/* .reset = */ NULL,
};
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
@@ -563,6 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = {
/* .free = */ ggml_backend_metal_free,
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
/* .synchronize = */ ggml_backend_metal_synchronize,
/* .graph_plan_create = */ NULL,
+4
@@ -4063,6 +4063,8 @@ static ggml_backend_i ggml_backend_opencl_i = {
/* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */
/* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */
/* .synchronize = */ ggml_backend_opencl_synchronize,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
@@ -5778,6 +5780,8 @@ static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL,
/* .clear = */ ggml_backend_opencl_buffer_clear,
/* .reset = */ ggml_backend_opencl_buffer_reset,
+4
@@ -412,6 +412,8 @@ static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = {
/* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor,
/* .clear = */ ggml_backend_openvino_buffer_clear,
/* .reset = */ NULL,
@@ -617,6 +619,8 @@ static const ggml_backend_i ggml_backend_openvino_interface = {
/* .free = */ ggml_backend_openvino_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
+4
@@ -706,6 +706,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
/* .clear = */ ggml_backend_rpc_buffer_clear,
/* .reset = */ NULL,
@@ -894,6 +896,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ ggml_backend_rpc_synchronize,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
+6
@@ -638,6 +638,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
/* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
/* .clear = */ ggml_backend_sycl_buffer_clear,
/* .reset = */ ggml_backend_sycl_buffer_reset,
@@ -1084,6 +1086,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL,
/* .clear = */ ggml_backend_sycl_split_buffer_clear,
/* .reset = */ NULL,
@@ -4553,6 +4557,8 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .free = */ ggml_backend_sycl_free,
/* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, TODO: update for the new interface
@@ -101,6 +101,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor,
/* .clear = */ ggml_backend_remoting_buffer_clear,
/* .reset = */ NULL,
@@ -113,6 +115,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = {
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr,
/* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor,
/* .clear = */ ggml_backend_remoting_buffer_clear,
/* .reset = */ NULL,
+2
@@ -34,6 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = {
/* .free = */ ggml_backend_remoting_free,
/* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async,
/* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async,
/* .synchronize = */ NULL, // ggml_backend_remoting_synchronize,
/* .graph_plan_create = */ NULL,
+4
@@ -13521,6 +13521,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
/* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
/* .clear = */ ggml_backend_vk_buffer_clear,
/* .reset = */ NULL,
@@ -14979,6 +14981,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .free = */ ggml_backend_vk_free,
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
/* .synchronize = */ ggml_backend_vk_synchronize,
/* .graph_plan_create = */ NULL,
+4
@@ -3013,6 +3013,8 @@ static ggml_backend_i ggml_backend_webgpu_i = {
/* .free = */ ggml_backend_webgpu_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
@@ -3170,6 +3172,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
/* .clear = */ ggml_backend_webgpu_buffer_clear,
/* .reset = */ NULL, // TODO: optional, think it coordinates with
+18 -14
@@ -313,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = {
/* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
/* .get_tensor_2d = */ NULL,
/* .cpy_tensor = */ NULL,
/* .clear = */ ggml_backend_zdnn_buffer_clear,
/* .reset = */ NULL,
@@ -417,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend,
}
static ggml_backend_i ggml_backend_zdnn_i = {
/* .get_name = */ ggml_backend_zdnn_name,
/* .free = */ ggml_backend_zdnn_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .graph_optimize = */ NULL,
/* .get_name = */ ggml_backend_zdnn_name,
/* .free = */ ggml_backend_zdnn_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .graph_optimize = */ NULL,
};
static ggml_guid_t ggml_backend_zdnn_guid(void) {
+2
@@ -407,6 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = {
/* .free = */ ggml_backend_zendnn_free,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .set_tensor_2d_async = */ NULL,
/* .get_tensor_2d_async = */ NULL,
/* .cpy_tensor_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,