diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 302b1ab0d9deb9..ed7c5d089d80c9 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -328,8 +328,8 @@ at::BlasBackend Context::blasPreferredBackend() {
           "gfx1100", "gfx1101", "gfx1200", "gfx1201"
 #endif
       };
-      for (auto index: c10::irange(getNumGPUs())) {
-        if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+      for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
+        if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
           TORCH_WARN_ONCE(
               "Attempting to use hipBLASLt on an unsupported architecture! "
               "Overriding blas backend to hipblas");
@@ -381,8 +381,8 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
       static const std::vector<std::string> archs = {
           "gfx90a", "gfx942"
       };
-      for (auto index: c10::irange(getNumGPUs())) {
-        if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+      for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
+        if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
           TORCH_WARN_ONCE(
               "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
           return true;
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 9b183848503ef9..e88c0bd5dab2b4 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -124,9 +124,7 @@ size_t parseChosenWorkspaceSize() {
     val = getenv("ROCBLAS_WORKSPACE_CONFIG");
   }
   /* 32MiB default, 128MiB for MI300 */
-  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
-  std::string device_arch = properties->gcnArchName;
-  const bool gfx94 = device_arch.find("gfx94") != std::string::npos;
+  const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"});
   const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
 #else
   /* :4096:2:16:8 default, 32MiB for Hopper */
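
Both call patterns above rely on the reordered isGPUArch(archs, device_index) hook, whose
device_index parameter defaults to -1 ("use the current device") in the changes below. A minimal
usage sketch, not part of the patch, assuming a ROCm build with the ATen_cuda hooks registered
(and <c10/util/irange.h> for c10::irange):

    // Current device: device_index defaults to -1, which the new implementation
    // maps to getCurrentDeviceProperties().
    const bool is_mi300 = at::detail::getCUDAHooks().isGPUArch({"gfx94"});

    // Explicit device index, e.g. while iterating over all visible GPUs.
    for (const auto index : c10::irange(at::detail::getCUDAHooks().deviceCount())) {
      if (!at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index)) {
        TORCH_WARN_ONCE("Attempting to use this backend on an unsupported architecture!");
      }
    }
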
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index c6e83fad1a7f12..f21c508957bc77 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -441,8 +441,14 @@ int CUDAHooks::getNumGPUs() const {
 }
 
 #ifdef USE_ROCM
-bool CUDAHooks::isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device_index);
+bool CUDAHooks::isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index) const {
+  hipDeviceProp_t* prop;
+  if (device_index == -1){
+    prop = at::cuda::getCurrentDeviceProperties();
+  } else {
+    prop = at::cuda::getDeviceProperties(device_index);
+  }
+
   std::string device_arch = prop->gcnArchName;
   for (std::string arch : archs) {
     size_t substring = device_arch.find(arch);
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index c23998fda56b67..9ef5ed439fdda7 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -50,7 +50,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   void cuFFTClearPlanCache(DeviceIndex device_index) const override;
   int getNumGPUs() const override;
 #ifdef USE_ROCM
-  bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
+  bool isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index = -1) const override;
 #endif
   void deviceSynchronize(DeviceIndex device_index) const override;
 };
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index f9a3fa098508f6..42d79841a997f2 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -187,7 +187,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
   }
 
 #ifdef USE_ROCM
-  virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector<std::string>& /*archs*/) const {
+  virtual bool isGPUArch(const std::vector<std::string>& /*archs*/, DeviceIndex = -1 /*device_index*/) const {
     TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP);
   }
 #endif
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 1496806b917ad3..9f78a64296904e 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -188,16 +188,7 @@ static bool getDisableAddmmCudaLt() {
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-  std::string device_arch = prop->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx940", "gfx941", "gfx942"}, index);
 }
 #endif
 
@@ -846,18 +837,10 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
 }
 
 static bool _scaled_mm_allowed_device() {
-  auto dprops = at::cuda::getCurrentDeviceProperties();
 #ifdef USE_ROCM
-  std::string device_arch = dprops->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx940", "gfx941", "gfx942"});
 #else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
   return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
 #endif
 }
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 572a490242b170..01c9d00804059f 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -149,6 +149,176 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   }
 }
 
+// This API is for detecting whether the permute parameter of a three-dimensional tensor
+// in the Copy operation from src to dst is from [0, 1, 2] to [0, 2, 1].
+bool is_permute_021(TensorIteratorBase &iter) {
+  const auto& input = iter.tensor(1);
+  const auto& output = iter.tensor(0);
+  bool is_permute = false;
+  if (input.dim() == 3) {
+    is_permute = true;
+    is_permute &= input.dim() == output.dim();
+
+    is_permute &= input.stride(0) == input.size(2) * input.stride(2);
+    is_permute &= input.stride(1) == 1;
+    is_permute &= input.stride(2) >= input.size(1);
+    is_permute &= output.is_contiguous();
+  }
+  return is_permute;
+}
+
+template <typename T, uint32_t BIG_TILE_SIZE_N, uint32_t BIG_TILE_SIZE_K, uint32_t M_SWIZZLE, uint32_t BLOCK_SIZE>
+__global__ void transpose_tile_big_kernel(const void* __restrict a, void* __restrict c, const int N, const int K, const int STRIDE_N_ELE)
+{
+    constexpr uint32_t elements_in_128b = 16 / sizeof(T);
+    union BLOCK_16B
+    {
+        T e[elements_in_128b];
+        __uint128_t ow;
+    };
+
+    constexpr int LDS_PAD = (4 / sizeof(T));
+    // Round up processing to next full tile
+    const uint32_t n_tiles = (N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N;
+    const uint32_t k_tiles = (K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K;
+    const uint32_t nk_tiles = n_tiles * k_tiles;
+    const uint32_t m_tiles = gridDim.x / nk_tiles;
+    const uint32_t m_tile_swizzle = blockIdx.x / nk_tiles / M_SWIZZLE * M_SWIZZLE;
+    /// do m_swizzle when there are enough m_tiles
+    const bool swizzle_m = m_tile_swizzle + M_SWIZZLE <= m_tiles;
+    const uint32_t current_m = swizzle_m ? m_tile_swizzle + blockIdx.x % M_SWIZZLE : blockIdx.x / nk_tiles;
+    const uint64_t stride_n = STRIDE_N_ELE * sizeof(T);
+    const uint64_t stride_k = N * sizeof(T);
+    const uint64_t out_stride_nk = N * K * sizeof(T);
+    const uint64_t in_stride_nk = N * STRIDE_N_ELE * sizeof(T);
+
+    const uint32_t current_nk = swizzle_m ? blockIdx.x / M_SWIZZLE % nk_tiles : blockIdx.x % nk_tiles;
+    const uint32_t ti = current_nk / k_tiles;
+    const uint32_t tj = current_nk % k_tiles;
+
+    __shared__ T smem[BIG_TILE_SIZE_N][BIG_TILE_SIZE_K + LDS_PAD];
+
+    // Detect partial tiles
+    const uint32_t current_n_size = (ti == (n_tiles - 1) && (N % BIG_TILE_SIZE_N) != 0) ? (N % BIG_TILE_SIZE_N) : BIG_TILE_SIZE_N;
+    const uint32_t current_k_size = (tj == (k_tiles - 1) && (K % BIG_TILE_SIZE_K) != 0) ? (K % BIG_TILE_SIZE_K) : BIG_TILE_SIZE_K;
+    //use 128bit load&store whenever possible
+    if (current_n_size % 8 == 0 && current_k_size % 8 == 0)
+    {
+        // Copy full tile with large loads
+        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
+        constexpr uint32_t ld_per_row = row_bytes / sizeof(__uint128_t);
+        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
+        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
+        // Make sure WG isn't too large
+        static_assert(vmem_per_thread >= 1);
+
+        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
+        #pragma unroll
+        for (uint32_t t = 0; t < vmem_per_thread; t++)
+        {
+            uint32_t col = threadIdx.x % ld_per_row;
+            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
+            uint64_t offset = (col * 8 < current_k_size && row < current_n_size) ? row * stride_n + col * sizeof(__uint128_t) : 0;
+            const __uint128_t* pfa = (const __uint128_t*)(pat + offset);
+            BLOCK_16B d;
+            d.ow = *pfa;
+            #pragma unroll
+            for (uint32_t i = 0; i < elements_in_128b; i++)
+            {
+                smem[row][col * elements_in_128b + i] = d.e[i];
+            }
+        }
+        __syncthreads();
+        // Copy full tile with large loads
+        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
+        constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
+        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+        // Make sure WG isn't too large
+        static_assert(wr_per_row >= 1);
+        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
+        #pragma unroll
+        for (uint32_t t = 0; t < wr_per_row; t++)
+        {
+            uint32_t col = threadIdx.x % vmem_per_row_wr;
+            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
+            if (col * 8 < current_n_size && row < current_k_size)
+            {
+                uint64_t offset = row * stride_k + col * sizeof(__uint128_t);
+                BLOCK_16B d;
+                // Transpose tile on read from LDS
+                #pragma unroll
+                for (uint32_t i = 0; i < elements_in_128b; i++)
+                {
+                    d.e[i] = smem[col * elements_in_128b + i][row];
+                }
+                __uint128_t* pfc = (__uint128_t*)(pc + offset);
+                *pfc = d.ow;
+            }
+        }
+    }
+    else
+    {
+        // Copy partial tiles with element accesses
+        constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
+        constexpr uint32_t ld_per_row = BIG_TILE_SIZE_K;
+        constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
+        constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
+        // Make sure WG isn't too large
+        static_assert(vmem_per_thread >= 1);
+
+        const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
+        #pragma unroll
+        for (uint32_t t = 0; t < vmem_per_thread; t++)
+        {
+            uint32_t col = threadIdx.x % ld_per_row;
+            uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
+            uint64_t offset = (col < current_k_size && row < current_n_size) ? row * stride_n + col * 2 : 0;
+            const uint16_t* pfa = (const uint16_t*)(pat + offset);
+            smem[row][col] = *pfa;
+        }
+        __syncthreads();
+        // Copy full tile with large loads
+        constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
+        constexpr uint32_t vmem_per_row_wr = BIG_TILE_SIZE_N;
+        constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+        constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+        const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
+        #pragma unroll
+        for (uint32_t t = 0; t < wr_per_row; t++)
+        {
+            uint32_t col = threadIdx.x % vmem_per_row_wr;
+            uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
+            if (col < current_n_size && row < current_k_size)
+            {
+                uint64_t offset = row * stride_k + col * 2;
+                uint16_t* pfc = (uint16_t*)(pc + offset);
+                *pfc = smem[col][row];
+            }
+        }
+    }
+}
+
+void transpose_last2dim(TensorIteratorBase &iter) {
+  void* dst = iter.data_ptr(0);
+  void* src = iter.data_ptr(1);
+  const auto& input = iter.tensor(1);
+
+  int M = input.size(0);
+  int N = input.size(1);
+  int K = input.size(2);
+
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  constexpr uint32_t BIG_TILE_SIZE_N = 64;
+  constexpr uint32_t BIG_TILE_SIZE_K = 64;
+  constexpr uint32_t M_SWIZZLE = 8;
+  const int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K);
+  const dim3 grid_dim(grid_x, 1, 1);
+  const dim3 block_dim(256, 1, 1);
+  transpose_tile_big_kernel<uint16_t, BIG_TILE_SIZE_N, BIG_TILE_SIZE_K, M_SWIZZLE, 256><<<grid_dim, block_dim, 0, stream>>>(src, dst, K, N, input.stride(2));
+}
+
 // TODO: We probably can use the opaque type trick to avoid creating duplicate
 // kernels for equivalent bit lengths
 void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
@@ -171,6 +341,8 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
+  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf) && at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx942", "gfx950"})) {
+    transpose_last2dim(iter);
   } else {
     AT_DISPATCH_V2(
       dtype, "copy_", AT_WRAP([&] {
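
For reference, the stride pattern that is_permute_021 accepts is exactly what a [0, 2, 1]
permutation of a contiguous 3-D tensor produces. A small host-side sketch, not part of the patch,
with hypothetical sizes, assuming a ROCm build on one of the gfx9 architectures listed above and a
Half or BFloat16 dtype so the copy can take the transpose_last2dim path:

    // base is contiguous [8, 64, 128]; swapping dims 1 and 2 yields a [8, 128, 64]
    // view with strides {8192, 1, 128}. That satisfies is_permute_021:
    //   stride(0) == size(2) * stride(2)  (8192 == 64 * 128)
    //   stride(1) == 1
    //   stride(2) >= size(1)              (128 >= 128)
    at::Tensor base = at::randn({8, 64, 128}, at::device(at::kCUDA).dtype(at::kBFloat16));
    at::Tensor src = base.permute({0, 2, 1});  // non-contiguous [8, 128, 64] view
    at::Tensor dst = src.contiguous();         // copy_ -> direct_copy_kernel_cuda
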
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 46eb556f917832..e12aa5d309583f 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -546,7 +546,7 @@ inline int can_vectorize_up_to(array_t pointers) {
 #else
   c10::DeviceIndex curDevice = -1;
   AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
-  if (at::detail::getCUDAHooks().isGPUArch(curDevice, {"gfx942"})) {
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice)) {
     if (pointers[0] == pointers[1])
       return 4;
   }
diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu
index ff11ea8a96539f..d9f2c26063cf65 100644
--- a/aten/src/ATen/native/cuda/int4mm.cu
+++ b/aten/src/ATen/native/cuda/int4mm.cu
@@ -135,16 +135,7 @@ template <typename T, int Rank>
 using VecT = T __attribute__((ext_vector_type(Rank)));
 
 static bool isCDNA2orLater(int index) {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-  std::string device_arch = prop->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx940", "gfx941", "gfx942"}, index);
 }
 
 #else
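
Note that isGPUArch matches by substring of the device's gcnArchName rather than by exact name, so
a short prefix such as "gfx94" also covers gfx940/gfx941/gfx942 (ROCm typically reports an arch
string like "gfx942:sramecc+:xnack-"). A standalone sketch of that matching logic, not part of the
patch, using a hypothetical helper name:

    #include <string>
    #include <vector>

    // Same substring test as CUDAHooks::isGPUArch above: any entry of `archs` that
    // occurs anywhere in the device's gcnArchName counts as a match.
    bool matches_arch(const std::string& gcn_arch_name, const std::vector<std::string>& archs) {
      for (const std::string& arch : archs) {
        if (gcn_arch_name.find(arch) != std::string::npos) {
          return true;
        }
      }
      return false;
    }

    // matches_arch("gfx942:sramecc+:xnack-", {"gfx942"})  -> true
    // matches_arch("gfx942:sramecc+:xnack-", {"gfx94"})   -> also true (prefix match)
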