diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 40b4ce8ae1..a8a46d24c2 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -22,12 +22,6 @@ using namespace at; namespace detail { -void throw_nccl_error(ncclResult_t status) { - std::ostringstream err; - err << "NCCL Error " << status << ": " << ncclGetErrorString(status); - throw std::runtime_error(err.str()); -} - struct NcclCommList { std::unique_ptr comms; int ndevices; diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 9f276f76fa..e636f2e224 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -10,6 +10,8 @@ #include #include +#include +#include namespace torch { namespace cuda { @@ -19,7 +21,11 @@ namespace nccl { // Don't use them outside of these files. namespace detail { -TORCH_CUDA_API void throw_nccl_error(ncclResult_t status); +TORCH_CUDA_API inline void throw_nccl_error(ncclResult_t status) { + std::ostringstream err; + err << "NCCL Error " << status << ": " << ncclGetErrorString(status); + throw std::runtime_error(err.str()); +} static inline void NCCL_CHECK(ncclResult_t status) { if (status != ncclSuccess) {