We are ByteDance Seed team.
Triton-distributed is a distributed compiler for computation-communication overlapping, built on top of OpenAI Triton.
With Triton-distributed, programmers can develop efficient kernels whose performance is comparable to that of highly optimized libraries, including Distributed-GEMM and FLUX. Triton-distributed currently targets Nvidia and AMD GPUs, and it can also be ported to other hardware platforms. Feel free to contact us if you want to use Triton-distributed on your own hardware.
Triton-distributed provides a set of easy-to-use primitives for developing distributed compute-communication overlapping kernels. The primitives are divided into low-level and high-level primitives. Currently, we have released the low-level primitives, and we plan to release the high-level primitives in the future.
Using these primitives, users can program communication kernels easily. For example, a low-latency AllToAll (with better latency than DeepEP for inference) is shown below. On 32 H800 GPUs, this example runs in 137 us (128 tokens per rank, topk=8, hidden_size=7168, dtype=fp8), while DeepEP takes 182 us (note: DeepEP does not use NVLink for inference).
```python
@triton.jit
def all_to_all_kernel(
    data_src,
    data_dst,
    splits_src,
    splits_dst,
    signal,
    splits_cumsum,
    scale_src,
    scale_dst,
    rank: int,
    call_count: int,
    WITH_SCALE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    EXPERTS_PER_RANK: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
    SCALE_ELEMENT_SIZE: tl.constexpr = 4,
):
    pid = tl.program_id(0)  # one program per destination rank
    threadidx = tid(axis=0)

    # Rows destined for the experts hosted on rank `pid`.
    exp_st = pid * EXPERTS_PER_RANK
    exp_ed = exp_st + EXPERTS_PER_RANK
    m_st = tl.load(splits_cumsum + exp_st)
    m_ed = tl.load(splits_cumsum + exp_ed)
    num_rows_cur_block = m_ed - m_st

    src_off = m_st
    dst_off = rank * MAX_M

    # Per-expert split sizes for this destination rank.
    split_src_ptr = splits_src + exp_st
    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(splits_cumsum + off0)
    cumsum_eds = tl.load(splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0, EXPERTS_PER_RANK), cumsum_eds - cumsum_sts)

    act_pos = call_count % 2  # double-buffer slot, flipped every call
    data_dst_ptr = data_dst + act_pos * WORLD_SIZE * MAX_M * HIDDEN + dst_off * HIDDEN
    split_dst_ptr = splits_dst + act_pos * NUM_TOT_EXPERTS + rank * EXPERTS_PER_RANK
    signal_ptr = signal + act_pos * WORLD_SIZE + rank

    # Push this rank's tokens for rank `pid`'s experts.
    libshmem_device.putmem_nbi_block(
        data_dst_ptr,
        data_src + src_off * HIDDEN,
        num_rows_cur_block * HIDDEN * ELEMENT_SIZE,
        pid,
    )
    # Push the per-expert split sizes.
    libshmem_device.putmem_nbi_block(
        split_dst_ptr,
        split_src_ptr,
        EXPERTS_PER_RANK * 4,  # now we use `int32` for splits
        pid,
    )
    if WITH_SCALE:
        # Push the scales and set the remote signal in a single fused call.
        scale_dst_ptr = scale_dst + act_pos * WORLD_SIZE * MAX_M + dst_off
        libshmem_device.putmem_signal_nbi_block(
            scale_dst_ptr,
            scale_src + src_off,
            num_rows_cur_block * SCALE_ELEMENT_SIZE,
            signal_ptr,
            call_count,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            pid,
        )

    libshmem_device.fence()
    if threadidx == 0:
        if not WITH_SCALE:
            # No scales to send: set the remote signal explicitly.
            libshmem_device.signal_op(
                signal_ptr,
                call_count,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                pid,
            )
        # Wait until rank `pid` has delivered its data to this rank.
        libshmem_device.signal_wait_until(
            signal + act_pos * WORLD_SIZE + pid,
            libshmem_device.NVSHMEM_CMP_EQ,
            call_count,
        )
```
Users can also combine the communication part with the computation part to design overlapping kernels; a minimal consumer-side sketch follows below. We provide example implementations in third_party/distributed/distributed/kernels.
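As a small illustration of the overlap pattern (this sketch is not from the repository), a compute kernel can consume the received tokens by waiting on the same per-rank signals set by `all_to_all_kernel`: each program blocks only on the rank whose slot it processes, so computation on early arrivals proceeds while data from other ranks is still in flight. The element-wise compute and the `BLOCK` tiling below are placeholders.

```python
@triton.jit
def consume_all_to_all_kernel(
    data_dst,    # symmetric buffer written by remote ranks (same layout as above)
    out,         # local output buffer
    signal,      # same signal array used by all_to_all_kernel
    call_count: int,
    WORLD_SIZE: tl.constexpr,
    MAX_M: tl.constexpr,
    HIDDEN: tl.constexpr,
    BLOCK: tl.constexpr,  # assumed to divide MAX_M * HIDDEN
):
    pid = tl.program_id(0)    # one program per source rank
    act_pos = call_count % 2  # same double-buffer slot as the producer
    # Block until rank `pid` has delivered its tokens for this round; every
    # thread spins on the flag here, which keeps the sketch simple.
    libshmem_device.signal_wait_until(
        signal + act_pos * WORLD_SIZE + pid,
        libshmem_device.NVSHMEM_CMP_EQ,
        call_count,
    )
    # Placeholder compute over rank `pid`'s slot; a real kernel would read the
    # received splits and process only the valid rows.
    base = act_pos * WORLD_SIZE * MAX_M * HIDDEN + pid * MAX_M * HIDDEN
    for i in range(0, MAX_M * HIDDEN, BLOCK):
        offs = i + tl.arange(0, BLOCK)
        x = tl.load(data_dst + base + offs)
        tl.store(out + pid * MAX_M * HIDDEN + offs, x * 2.0)
```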
Triton-distributed can achieve performance comparable to or better than hand-tuned libraries. For the decoding benchmarks, the batch size is 1 (a single query).
- Release low-level primitives
- Release high-level primitives
- Tutorials
- Pre-built binary
- Release single-node GEMM TP overlapping kernels
- Release single-node MoE TP overlapping kernels
- Release single-node distributed Flash-Decoding kernels
- Release single-node MoE EP overlapping kernels
- Release cross-node GEMM TP overlapping kernels
- Release cross-node MoE TP overlapping kernels
- Release cross-node distributed Flash-Decoding kernels
- Release cross-node EP all-to-all kernels (similar to DeepEP)
- Provide tutorials for kernel implementation
Computation
- Nvidia SM90a support
- Nvidia SM80 support
- Nvidia SM89 support
- AMD CDNA3 support
Communication
- NVLink
- IB
- PCIe
- Performance report
The Triton-distributed project is under the MIT license. Part of our code is under the Apache-2.0 License:
third_party/distributed/distributed/kernels/flash_decode.py
Triton's original code is partially under the Apache-2.0 License; these files include:
include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
python/triton/_C/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
utils/generate-test-checks.py
If you use Triton-distributed in a scientific publication, we encourage you to cite the following related paper:
@misc{zheng2025tilelink,
      title={TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives},
      author={Size Zheng and Jin Fang and Xuegui Zheng and Qi Hou and Wenlei Bao and Ningxin Zheng and Ziheng Jiang and Dongyang Wang and Jianxi Ye and Haibin Lin and Li-Wen Chang and Xin Liu},
      year={2025},
}
About ByteDance Seed Team
Founded in 2023, ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society.
Please use issues or pull requests for discussion and contribution (see CONTRIBUTING.md).