đź‘‹ Hi, everyone!
We are ByteDance Seed team.

You can get to know us better through the following channels👇


Triton-distributed

Original Triton README | README in Chinese

Triton-distributed is a distributed compiler designed for computation-communication overlapping, built on top of OpenAI Triton.

Using Triton-distributed, programmers can develop kernels whose efficiency is comparable to highly optimized libraries such as Distributed-GEMM and FLUX. Triton-distributed currently targets mainly NVIDIA and AMD GPUs, but it can also be ported to other hardware platforms. Feel free to contact us if you want to use Triton-distributed on your own hardware.

Getting started

Install Triton-distributed from source

Build Guide

How to use Triton-distributed

Triton-distributed provides a set of easy-to-use primitives to support the development of distributed computation-communication overlapping kernels. The primitives are divided into low-level and high-level primitives. Currently, we have released the low-level primitives and plan to release the high-level primitives in the future.

Triton-distributed Primitives

Using these primitives, users can program communication kernels easily. For example, a low-latency AllToAll (with better latency than DeepEP for inference) is shown below. On 32 H800 GPUs, this example runs in 137 us (128 tokens per rank, topk=8, hidden_size=7168, dtype=fp8), while DeepEP takes 182 us (note: DeepEP doesn't use NVLink for inference).

# Requires `triton`, `triton.language as tl`, and Triton-distributed's device-side
# helpers (`libshmem_device`, `tid`); exact import paths depend on your installation.
@triton.jit
def all_to_all_kernel(
    data_src,
    data_dst,
    splits_src,
    splits_dst,
    signal,
    splits_cumsum,
    scale_src,
    scale_dst,
    rank: int,
    call_count: int,
    WITH_SCALE: tl.constexpr,
    WORLD_SIZE: tl.constexpr,
    HIDDEN: tl.constexpr,
    MAX_M: tl.constexpr,
    EXPERTS_PER_RANK: tl.constexpr,
    NUM_TOT_EXPERTS: tl.constexpr,
    ELEMENT_SIZE: tl.constexpr = 2,
    SCALE_ELEMENT_SIZE: tl.constexpr = 4,
):
    # One program per destination rank: `pid` is the peer rank this block sends to.
    pid = tl.program_id(0)
    threadidx = tid(axis=0)

    # Experts hosted on destination rank `pid`.
    exp_st = pid * EXPERTS_PER_RANK
    exp_ed = exp_st + EXPERTS_PER_RANK

    # Rows of `data_src` destined for those experts, from the cumulative split sizes.
    m_st = tl.load(splits_cumsum + exp_st)
    m_ed = tl.load(splits_cumsum + exp_ed)
    num_rows_cur_block = m_ed - m_st

    src_off = m_st
    dst_off = rank * MAX_M  # each source rank owns a fixed MAX_M-row slot in the peer's buffer

    # Recover per-expert split sizes from the cumulative sum; they are sent alongside the data.
    split_src_ptr = splits_src + exp_st
    off0 = exp_st + tl.arange(0, EXPERTS_PER_RANK)
    off1 = exp_st + tl.arange(0, EXPERTS_PER_RANK) + 1
    cumsum_sts = tl.load(splits_cumsum + off0)
    cumsum_eds = tl.load(splits_cumsum + off1)
    tl.store(split_src_ptr + tl.arange(0, EXPERTS_PER_RANK), cumsum_eds - cumsum_sts)

    # Double buffering: alternate between the two halves of the receive buffers and signals.
    act_pos = call_count % 2
    data_dst_ptr = data_dst + act_pos * WORLD_SIZE * MAX_M * HIDDEN + dst_off * HIDDEN
    split_dst_ptr = splits_dst + act_pos * NUM_TOT_EXPERTS + rank * EXPERTS_PER_RANK
    signal_ptr = signal + act_pos * WORLD_SIZE + rank

    # Non-blocking puts: push this rank's token rows and split sizes to rank `pid`.
    libshmem_device.putmem_nbi_block(
        data_dst_ptr,
        data_src + src_off * HIDDEN,
        num_rows_cur_block * HIDDEN * ELEMENT_SIZE,
        pid,
    )
    libshmem_device.putmem_nbi_block(
        split_dst_ptr,
        split_src_ptr,
        EXPERTS_PER_RANK * 4,  # splits are int32, i.e. 4 bytes each
        pid,
    )
    if WITH_SCALE:
        # Push the per-token scales and set the remote completion signal in one fused call.
        scale_dst_ptr = scale_dst + act_pos * WORLD_SIZE * MAX_M + dst_off
        libshmem_device.putmem_signal_nbi_block(
            scale_dst_ptr,
            scale_src + src_off,
            num_rows_cur_block * SCALE_ELEMENT_SIZE,
            signal_ptr,
            call_count,
            libshmem_device.NVSHMEM_SIGNAL_SET,
            pid,
        )

    # Order the non-blocking puts above before the completion signal is raised.
    libshmem_device.fence()
    if threadidx == 0:
        if not WITH_SCALE:
            # Without scales, raise the remote completion signal explicitly.
            libshmem_device.signal_op(
                signal_ptr,
                call_count,
                libshmem_device.NVSHMEM_SIGNAL_SET,
                pid,
            )
        # Wait until the data from peer rank `pid` for this round has arrived locally.
        libshmem_device.signal_wait_until(
            signal + act_pos * WORLD_SIZE + pid,
            libshmem_device.NVSHMEM_CMP_EQ,
            call_count,
        )
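
For orientation, a host-side launch might look like the sketch below. This is an illustration rather than code from the repository: it assumes `data_src`, `splits_src`, `splits_cumsum`, and `scale_src` come from the local MoE dispatch step; that the destination buffers (`data_dst`, `splits_dst`, `scale_dst`) and `signal` are double-buffered tensors allocated on the NVSHMEM symmetric heap; and that `rank`, `call_count`, `WORLD_SIZE`, `MAX_M`, `HIDDEN`, `EXPERTS_PER_RANK`, and `NUM_TOT_EXPERTS` are defined by the caller.

# Launch sketch (illustrative only). The kernel issues one program per peer rank,
# so the grid is (WORLD_SIZE,). All *_dst buffers and `signal` must live on the
# NVSHMEM symmetric heap; allocation is environment-specific and omitted here.
all_to_all_kernel[(WORLD_SIZE,)](
    data_src, data_dst,        # token rows: local send buffer, symmetric receive buffer
    splits_src, splits_dst,    # int32 per-expert split sizes (splits_src is filled by the kernel)
    signal,                    # 64-bit completion signals, shape (2, WORLD_SIZE)
    splits_cumsum,             # prefix sum of per-expert splits on this rank
    scale_src, scale_dst,      # per-token scales (used when WITH_SCALE=True)
    rank, call_count,
    WITH_SCALE=True,
    WORLD_SIZE=WORLD_SIZE,
    HIDDEN=HIDDEN,
    MAX_M=MAX_M,
    EXPERTS_PER_RANK=EXPERTS_PER_RANK,
    NUM_TOT_EXPERTS=NUM_TOT_EXPERTS,
)

Each of the WORLD_SIZE programs pushes this rank's tokens to one peer and then waits until the corresponding peer's data has arrived, so a single launch performs the full exchange.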

Users can also combine the communication part with the computation part to design overlapping kernels. We provide example implementations in third_party/distributed/distributed/kernels.
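
As a rough illustration of the overlapping pattern (a minimal sketch under assumed names, not one of the provided kernels), a consumer GEMM kernel can wait on the same kind of per-rank signal before reading the rows that a communication kernel delivers, so tiles whose data is already present start computing while the remaining transfers are still in flight:

@triton.jit
def ag_gemm_consumer_tile(  # hypothetical name; simplified single-K-step tile GEMM
    a_buf, b_ptr, c_ptr,     # a_buf: gathered A rows (symmetric buffer); C assumed fp32
    signal, call_count,
    M_PER_RANK: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # The source rank whose rows this tile consumes.
    src_rank = (pid_m * BLOCK_M) // M_PER_RANK

    # Block until the producer kernel has set signal[src_rank] = call_count.
    # (Every thread spins here; the AllToAll kernel above shows the single-thread variant.)
    libshmem_device.signal_wait_until(
        signal + src_rank,
        libshmem_device.NVSHMEM_CMP_EQ,
        call_count,
    )

    # One (BLOCK_M x BLOCK_N) tile of C = A @ B; K is loaded in a single step for brevity,
    # a real kernel would loop over K in blocks.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_buf + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], tl.dot(a, b))

Because the wait is per source rank rather than a global barrier, communication for later ranks overlaps with computation on tiles whose rows are already local.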

Performance

Triton-distributed can achieve comparable or better performance than hand-tuned libraries.

AllGather GEMM on single node of H800x8


GEMM ReduceScatter on single node of H800x8


AllGather GEMM on 2 nodes of H800x8


GEMM ReduceScatter on 2 nodes of H800x8


Scaling of Distributed Flash-Decode from 1 GPU to 32 GPUs

The batch size is 1 (one query) for decoding.

Performance on Other Platforms

AMD GPUs

Roadmaps

Functionalities

  • Release low-level primitives
  • Release high-level primitives
  • Tutorials
  • Pre-built binary

Kernels

  • Release single-node GEMM TP overlapping kernels
  • Release single-node MoE TP overlapping kernels
  • Release single-node distributed Flash-Decoding kernels
  • Release single-node MoE EP overlapping kernels
  • Release cross-node GEMM TP overlapping kernels
  • Release cross-node MoE TP overlapping kernels
  • Release cross-node distributed Flash-Decoding kernels
  • Release cross-node EP all-to-all kernels (similar to DeepEP)
  • Provide tutorials for kernel implementation

Backends

Computation

  • Nvidia SM90a support
  • Nvidia SM80 support
  • Nvidia SM89 support
  • AMD CDNA3 support

Communication

  • NVLink
  • IB
  • PCIe

Performance

  • Performance report

License

The Triton-distributed project is under the MIT license. Part of our code is under the Apache-2.0 license:

  • third_party/distributed/distributed/kernels/flash_decode.py

Triton's original code is partially under the Apache-2.0 license; these files include:

  • include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
  • lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
  • python/triton/_C/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
  • utils/generate-test-checks.py

Citation

If you use Triton-distributed in a scientific publication, we encourage you to cite the related paper:

@misc{zheng2025tilelink,
      title={TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives},
      author={Size Zheng and Jin Fang and Xuegui Zheng and Qi Hou and Wenlei Bao and Ningxin Zheng and Ziheng Jiang and Dongyang Wang and Jianxi Ye and Haibin Lin and Li-Wen Chang and Xin Liu},
      year={2025},
}

Founded in 2023, ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society.

Discussion and Contribution

Please use issues or pull requests for discussion and contribution (see CONTRIBUTING.md).
