diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake index 5bd3ccd606..f92eda98f5 100644 --- a/cmake/Modules/FindNCCL.cmake +++ b/cmake/Modules/FindNCCL.cmake @@ -48,37 +48,39 @@ find_library(NCCL_LIBRARIES include(FindPackageHandleStandardArgs) find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) if(NCCL_FOUND) # obtaining NCCL version and some sanity checks set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) include(CheckCXXSymbolExists) - check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + set(NCCL_VERSION_CODE $ENV{NCCL_VER_CODE}) + set(NCCL_VERSION_DEFINED $ENV{NCCL_VER_CODE}) - if (NCCL_VERSION_DEFINED) + if (DEFINED NCCL_VERSION_DEFINED) set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") file(WRITE ${file} " #include #include int main() { std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; int x; ncclGetVersion(&x); return x == NCCL_VERSION_CODE; } ") try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + CMAKE_FLAGS -DINCLUDE_DIRECTORIES=/opt/cuda/include RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER LINK_LIBRARIES ${NCCL_LIBRARIES}) if (NOT NCCL_VERSION_MATCHED) message(FATAL_ERROR "Found NCCL header version and library version do not match! \ (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") endif() message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") else() message(STATUS "NCCL version < 2.3.5-5") endif ()