From 4279851faa3ad679119d5a0a1bcb7c19777faa7a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 4 May 2025 18:34:40 -0400 Subject: [PATCH] feat: add a macro to directly visualize the generated mlir --- docs/src/api/api.md | 1 + src/Compiler.jl | 45 +++++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 11 ++++++++++- 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/docs/src/api/api.md b/docs/src/api/api.md index b2a32148a0..9bc3a3a820 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -27,6 +27,7 @@ within_compile @code_hlo @code_mhlo @code_xla +@mlir_visualize ``` ## Profile XLA diff --git a/src/Compiler.jl b/src/Compiler.jl index c724d10b47..b258f55b0b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -2068,6 +2068,51 @@ macro jit(args...) #! format: on end +""" + @mlir_visualize [optimize = ...] [no_nan = ] f(args...) + +Runs `@code_hlo` and visualizes the MLIR module using `model-explorer`. This expects the +`model-explorer` executable to be in your `PATH`. Installation instructions can be found +[here](https://github.com/google-ai-edge/model-explorer). +""" +macro mlir_visualize(args...) + default_options = Dict{Symbol,Any}( + :optimize => true, + :no_nan => false, + :client => nothing, + :raise => false, + :raise_first => false, + :shardy_passes => :(:to_mhlo_shardings), + :assert_nonallocating => false, + :donated_args => :(:auto), + :transpose_propagate => :(:up), + :reshape_propagate => :(:up), + :optimize_then_pad => true, + :optimize_communications => true, + :cudnn_hlo_optimize => false, + ) + compile_expr, (; compiled) = compile_call_expr( + __module__, compile_mlir, default_options, args... + ) + #! format: off + return esc( + :( + if Sys.which("model-explorer") === nothing + error("model-explorer is not in your PATH. Please install it from \ + https://github.com/google-ai-edge/model-explorer") + end; + $(compile_expr); + mlir_mod = $(first)($(compiled)); + tmpfile = tempname() * ".mlir"; + open(tmpfile, "w") do io + print(io, mlir_mod) + end; + run(`model-explorer $(tmpfile)`) + ) + ) + #! format: on +end + function compile_call_expr(mod, compiler, options::Dict, args...) while length(args) > 1 option, args = args[1], args[2:end] diff --git a/src/Reactant.jl b/src/Reactant.jl index b389d07237..1b698c3bcf 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -201,7 +201,15 @@ function Enzyme.make_zero( return res end -using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile +using .Compiler: + @compile, + @code_hlo, + @code_mhlo, + @jit, + @code_xla, + @mlir_visualize, + traced_getfield, + compile export ConcreteRArray, ConcreteRNumber, ConcretePJRTArray, @@ -214,6 +222,7 @@ export ConcreteRArray, @code_xla, @jit, @trace, + @mlir_visualize, within_compile const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()