[release/2.5] aten::copy optimization (revised) #2032

Status: Open · wants to merge 5 commits into release/2.5
Showing changes from all commits.

8 changes: 4 additions & 4 deletions aten/src/ATen/Context.cpp
@@ -328,8 +328,8 @@ at::BlasBackend Context::blasPreferredBackend() {
       "gfx1100", "gfx1101", "gfx1200", "gfx1201"
 #endif
   };
-  for (auto index: c10::irange(getNumGPUs())) {
-    if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+  for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
+    if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
       TORCH_WARN_ONCE(
           "Attempting to use hipBLASLt on an unsupported architecture! "
           "Overriding blas backend to hipblas");
@@ -381,8 +381,8 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
   static const std::vector<std::string> archs = {
       "gfx90a", "gfx942"
   };
-  for (auto index: c10::irange(getNumGPUs())) {
-    if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+  for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
+    if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
       TORCH_WARN_ONCE(
           "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
       return true;
4 changes: 1 addition & 3 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -124,9 +124,7 @@ size_t parseChosenWorkspaceSize() {
     val = getenv("ROCBLAS_WORKSPACE_CONFIG");
   }
   /* 32MiB default, 128MiB for MI300 */
-  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
-  std::string device_arch = properties->gcnArchName;
-  const bool gfx94 = device_arch.find("gfx94") != std::string::npos;
+  const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"});
   const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
 #else
   /* :4096:2:16:8 default, 32MiB for Hopper */
10 changes: 8 additions & 2 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -441,8 +441,14 @@ int CUDAHooks::getNumGPUs() const {
 }
 
 #ifdef USE_ROCM
-bool CUDAHooks::isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device_index);
+bool CUDAHooks::isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index) const {
+  hipDeviceProp_t* prop;
+  if (device_index == -1){
+    prop = at::cuda::getCurrentDeviceProperties();
+  } else {
+    prop = at::cuda::getDeviceProperties(device_index);
+  }
+
   std::string device_arch = prop->gcnArchName;
   for (std::string arch : archs) {
     size_t substring = device_arch.find(arch);
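Note: the hunk is truncated above, inside the matching loop. Based on the identical loops removed from Blas.cpp and int4mm.cu further down in this diff, the helper presumably continues as follows (reconstruction for readability, not part of the diff):

    // Reconstructed continuation: substring-match gcnArchName
    // (e.g. "gfx942:sramecc+:xnack-") against each requested arch.
    if (substring != std::string::npos) {
      return true;
    }
  }
  return false;
}
#endif
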
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -50,7 +50,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   void cuFFTClearPlanCache(DeviceIndex device_index) const override;
   int getNumGPUs() const override;
 #ifdef USE_ROCM
-  bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
+  bool isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index = -1) const override;
 #endif
   void deviceSynchronize(DeviceIndex device_index) const override;
 };
2 changes: 1 addition & 1 deletion aten/src/ATen/detail/CUDAHooksInterface.h
@@ -187,7 +187,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
   }
 
 #ifdef USE_ROCM
-  virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector<std::string>& /*archs*/) const {
+  virtual bool isGPUArch(const std::vector<std::string>& /*archs*/, DeviceIndex = -1 /*device_index*/) const {
     TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP);
   }
 #endif
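The default argument lets call sites omit the device index to query the current device. Both forms as they appear later in this PR, for quick reference:

  // Current device (device_index defaults to -1), as in CublasHandlePool.cpp:
  const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"});
  // Explicit device index, as in Blas.cpp and int4mm.cu:
  const bool supported = at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx940", "gfx941", "gfx942"}, /*device_index=*/0);
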
23 changes: 3 additions & 20 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -188,16 +188,7 @@ static bool getDisableAddmmCudaLt() {
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-  std::string device_arch = prop->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx940", "gfx941", "gfx942"}, index);
 }
 #endif
 
@@ -846,18 +837,10 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
 }
 
 static bool _scaled_mm_allowed_device() {
-  auto dprops = at::cuda::getCurrentDeviceProperties();
 #ifdef USE_ROCM
-  std::string device_arch = dprops->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx940", "gfx941", "gfx942"});
 #else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
   return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
 #endif
 }
172 changes: 172 additions & 0 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -149,6 +149,176 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   }
 }
 
+// This API is for detecting whether the permute parameter of a three-dimensional tensor
+// in the Copy operation from src to dst is from [0, 1, 2] to [0, 2, 1].
+bool is_permute_021(TensorIteratorBase &iter) {
+  const auto& input = iter.tensor(1);
+  const auto& output = iter.tensor(0);
+  bool is_permute = false;
+  if (input.dim() == 3) {
+    is_permute = true;
+    is_permute &= input.dim() == output.dim();
+
+    is_permute &= input.stride(0) == input.size(2) * input.stride(2);
+    is_permute &= input.stride(1) == 1;
+    is_permute &= input.stride(2) >= input.size(1);
+    is_permute &= output.is_contiguous();
+  }
+  return is_permute;
+}
+
+template<typename T, int BLOCK_SIZE, int BIG_TILE_SIZE_N, int BIG_TILE_SIZE_K, int M_SWIZZLE>
+__global__ void transpose_tile_big_kernel(const void* __restrict a, void* __restrict c, const int N, const int K, const int STRIDE_N_ELE)
+{
+  constexpr uint32_t elements_in_128b = 16 / sizeof(T);
+  union BLOCK_16B
+  {
+    T e[elements_in_128b];
+    __uint128_t ow;
+  };
+
+  constexpr int LDS_PAD = (4 / sizeof(T));
+  // Round up processing to next full tile
+  const uint32_t n_tiles = (N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N;
+  const uint32_t k_tiles = (K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K;
+  const uint32_t nk_tiles = n_tiles * k_tiles;
+  const uint32_t m_tiles = gridDim.x / nk_tiles;
+  const uint32_t m_tile_swizzle = blockIdx.x / nk_tiles / M_SWIZZLE * M_SWIZZLE;
+  /// do m_swizzle when there are enough m_tiles
+  const bool swizzle_m = m_tile_swizzle + M_SWIZZLE <= m_tiles;
+  const uint32_t current_m = swizzle_m ? m_tile_swizzle + blockIdx.x % M_SWIZZLE : blockIdx.x / nk_tiles;
+  const uint64_t stride_n = STRIDE_N_ELE * sizeof(T);
+  const uint64_t stride_k = N * sizeof(T);
+  const uint64_t out_stride_nk = N * K * sizeof(T);
+  const uint64_t in_stride_nk = N * STRIDE_N_ELE * sizeof(T);
+
+  const uint32_t current_nk = swizzle_m ? blockIdx.x / M_SWIZZLE % nk_tiles : blockIdx.x % nk_tiles;
+  const uint32_t ti = current_nk / k_tiles;
+  const uint32_t tj = current_nk % k_tiles;
+
+  __shared__ T smem[BIG_TILE_SIZE_N][BIG_TILE_SIZE_K + LDS_PAD];
+
+  // Detect partial tiles
+  const uint32_t current_n_size = (ti == (n_tiles - 1) && (N % BIG_TILE_SIZE_N) != 0) ? (N % BIG_TILE_SIZE_N) : BIG_TILE_SIZE_N;
+  const uint32_t current_k_size = (tj == (k_tiles - 1) && (K % BIG_TILE_SIZE_K) != 0) ? (K % BIG_TILE_SIZE_K) : BIG_TILE_SIZE_K;
+  //use 128bit load&store whenever possible
+  if (current_n_size % 8 == 0 && current_k_size % 8 == 0)
+  {
+    // Copy full tile with large loads
+    constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
+    constexpr uint32_t ld_per_row = row_bytes / sizeof(__uint128_t);
+    constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
+    constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
+    // Make sure WG isn't too large
+    static_assert(vmem_per_thread >= 1);
+
+    const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
+#pragma unroll
+    for (uint32_t t = 0; t < vmem_per_thread; t++)
+    {
+      uint32_t col = threadIdx.x % ld_per_row;
+      uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
+      uint64_t offset = (col * 8 < current_k_size && row < current_n_size) ?
+          row * stride_n + col * sizeof(__uint128_t) : 0;
+      const __uint128_t* pfa = (const __uint128_t*)(pat + offset);
+      BLOCK_16B d;
+      d.ow = *pfa;
+#pragma unroll
+      for (uint32_t i = 0; i < elements_in_128b; i++)
+      {
+        smem[row][col * elements_in_128b + i] = d.e[i];
+      }
+    }
+    __syncthreads();
+    // Copy full tile with large loads
+    constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
+    constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
+    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+    // Make sure WG isn't too large
+    static_assert(wr_per_row >= 1);
+    const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
+#pragma unroll
+    for (uint32_t t = 0; t < wr_per_row; t++)
+    {
+      uint32_t col = threadIdx.x % vmem_per_row_wr;
+      uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
+      if (col * 8 < current_n_size && row < current_k_size)
+      {
+        uint64_t offset = row * stride_k + col * sizeof(__uint128_t);
+        BLOCK_16B d;
+        // Transpose tile on read from LDS
+#pragma unroll
+        for (uint32_t i = 0; i < elements_in_128b; i++)
+        {
+          d.e[i] = smem[col * elements_in_128b + i][row];
+        }
+        __uint128_t* pfc = (__uint128_t*)(pc + offset);
+        *pfc = d.ow;
+      }
+    }
+  }
+  else
+  {
+    // Copy partial tiles with element accesses
+    constexpr uint32_t row_bytes = BIG_TILE_SIZE_K * sizeof(T);
+    constexpr uint32_t ld_per_row = BIG_TILE_SIZE_K;
+    constexpr uint32_t rows_per_wg = BLOCK_SIZE / ld_per_row;
+    constexpr uint32_t vmem_per_thread = BIG_TILE_SIZE_N / rows_per_wg;
+    // Make sure WG isn't too large
+    static_assert(vmem_per_thread >= 1);
+
+    const uint8_t* pat = (const uint8_t*)a + tj * row_bytes + ti * BIG_TILE_SIZE_N * stride_n + current_m * in_stride_nk;
+#pragma unroll
+    for (uint32_t t = 0; t < vmem_per_thread; t++)
+    {
+      uint32_t col = threadIdx.x % ld_per_row;
+      uint32_t row = threadIdx.x / ld_per_row + t * rows_per_wg;
+      uint64_t offset = (col < current_k_size && row < current_n_size) ? row * stride_n + col * 2 : 0;
+      const uint16_t* pfa = (const uint16_t*)(pat + offset);
+      smem[row][col] = *pfa;
+    }
+    __syncthreads();
+    // Copy full tile with large loads
+    constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
+    constexpr uint32_t vmem_per_row_wr = BIG_TILE_SIZE_N;
+    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+    const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
+#pragma unroll
+    for (uint32_t t = 0; t < wr_per_row; t++)
+    {
+      uint32_t col = threadIdx.x % vmem_per_row_wr;
+      uint32_t row = threadIdx.x / vmem_per_row_wr + t * rows_per_wg_wr;
+      if (col < current_n_size && row < current_k_size)
+      {
+        uint64_t offset = row * stride_k + col * 2;
+        uint16_t* pfc = (uint16_t*)(pc + offset);
+        *pfc = smem[col][row];
+      }
+    }
+  }
+}
+
+void transpose_last2dim(TensorIteratorBase &iter) {
+  void* dst = iter.data_ptr(0);
+  void* src = iter.data_ptr(1);
+  const auto& input = iter.tensor(1);
+
+  int M = input.size(0);
+  int N = input.size(1);
+  int K = input.size(2);
+
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  constexpr uint32_t BIG_TILE_SIZE_N = 64;
+  constexpr uint32_t BIG_TILE_SIZE_K = 64;
+  constexpr uint32_t M_SWIZZLE = 8;
+  const int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K);
+  const dim3 grid_dim(grid_x, 1, 1);
+  const dim3 block_dim(256, 1, 1);
+  transpose_tile_big_kernel<uint16_t, 256, BIG_TILE_SIZE_N, BIG_TILE_SIZE_K, M_SWIZZLE><<<grid_dim, block_dim, 0, stream>>>(src, dst, K, N, input.stride(2));
+}
+
 // TODO: We probably can use the opaque type trick to avoid creating duplicate
 // kernels for equivalent bit lengths
 void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
@@ -171,6 +341,8 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
+  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf) && at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx942", "gfx950"})) {
+    transpose_last2dim(iter);
   } else {
     AT_DISPATCH_V2(
         dtype, "copy_", AT_WRAP([&] {
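For reviewers who want to exercise the new branch: a minimal repro sketch (not part of the PR; assumes a ROCm libtorch build on a gfx94x/gfx950 device) of a copy that satisfies is_permute_021 and should route through transpose_last2dim:

#include <torch/torch.h>

int main() {
  // Contiguous (8, 256, 512) half tensor on the GPU; strides (131072, 512, 1).
  auto x = torch::randn({8, 256, 512},
      torch::dtype(torch::kHalf).device(torch::kCUDA));
  // permute(0, 2, 1) yields a view with sizes (8, 512, 256) and strides
  // (131072, 1, 512); copying it into a contiguous output is exactly the
  // [0, 1, 2] -> [0, 2, 1] pattern that is_permute_021() detects, so
  // contiguous() should dispatch to the tiled transpose kernel instead of
  // the generic elementwise copy (dtype and arch checks permitting).
  auto y = x.permute({0, 2, 1}).contiguous();
  return 0;
}
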
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -546,7 +546,7 @@ inline int can_vectorize_up_to(array_t pointers) {
 #else
   c10::DeviceIndex curDevice = -1;
   AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
-  if (at::detail::getCUDAHooks().isGPUArch(curDevice, {"gfx942"})) {
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice)) {
     if (pointers[0] == pointers[1])
       return 4;
   }
11 changes: 1 addition & 10 deletions aten/src/ATen/native/cuda/int4mm.cu
@@ -135,16 +135,7 @@ template<typename T, uint32_t Rank>
 using VecT = T __attribute__((ext_vector_type(Rank)));
 
 static bool isCDNA2orLater(int index) {
-  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
-  std::string device_arch = prop->gcnArchName;
-  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
-  for (std::string arch : archs) {
-    size_t substring = device_arch.find(arch);
-    if (substring != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
+  return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx940", "gfx941", "gfx942"}, index);
 }
 
 #else