diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 71cc4b31a995c..dad26cc828959 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1373,13 +1373,6 @@ if(USE_ROCM) set(ROCM_SOURCE_DIR "/opt/rocm") endif() message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}") - target_include_directories(torch_hip PRIVATE - ${ROCM_SOURCE_DIR}/include - ${ROCM_SOURCE_DIR}/hcc/include - ${ROCM_SOURCE_DIR}/rocblas/include - ${ROCM_SOURCE_DIR}/hipsparse/include - ${ROCM_SOURCE_DIR}/include/rccl/ - ) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION @@ -1713,7 +1706,7 @@ if(USE_ROCM) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) # Since PyTorch files contain HIP headers, this is also needed to capture the includes. - target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE}) + target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS}) target_include_directories(torch_hip INTERFACE $) endif() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 28d15a5ea1b73..c933da049b4cd 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -26,10 +26,10 @@ else() endif() endif() -if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS}) - set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include) +if(NOT DEFINED ENV{ROCM_INCLUDE_DIR}) + set(ROCM_INCLUDE_DIR ${ROCM_PATH}/include) else() - set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS}) + set(ROCM_INCLUDE_DIR $ENV{ROCM_INCLUDE_DIR}) endif() # MAGMA_HOME @@ -72,6 +72,7 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) macro(find_package_and_print_version PACKAGE_NAME) find_package("${PACKAGE_NAME}" ${ARGN}) message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR}) endmacro() # Find the HIP Package @@ -82,9 +83,16 @@ find_package_and_print_version(HIP 1.0 MODULE) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) + find_package_and_print_version(hip REQUIRED CONFIG) # Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h - find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h hip_version.h - HINTS ${ROCM_INCLUDE_DIRS}/rocm-core ${ROCM_INCLUDE_DIRS}/hip /usr/include) + if(UNIX) + find_package_and_print_version(rocm-core REQUIRED CONFIG) + find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h + HINTS ${rocm_core_INCLUDE_DIR}/rocm-core /usr/include) + else() # Win32 + find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h + HINTS ${hip_INCLUDE_DIR}/hip /usr/include) + endif() get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME) if(EXISTS ${ROCM_VERSION_HEADER_PATH}) @@ -141,7 +149,6 @@ if(HIP_FOUND) # Find ROCM components using Config mode # These components will be searced for recursively in ${ROCM_PATH} message("\n***** Library versions from cmake find_package *****\n") - find_package_and_print_version(hip REQUIRED CONFIG) find_package_and_print_version(amd_comgr REQUIRED) find_package_and_print_version(rocrand REQUIRED) find_package_and_print_version(hiprand REQUIRED) @@ -168,7 +175,11 @@ if(HIP_FOUND) if(UNIX) find_package_and_print_version(rccl) find_package_and_print_version(hsa-runtime64 REQUIRED) + endif() + list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS) + + if(UNIX) # roctx is part of roctracer find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)