Skip to content

Commit 4279851

Browse files
committed
feat: add a macro to directly visualize the generated mlir
1 parent 93283e5 commit 4279851

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

docs/src/api/api.md

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ within_compile
2727
@code_hlo
2828
@code_mhlo
2929
@code_xla
30+
@mlir_visualize
3031
```
3132

3233
## Profile XLA

src/Compiler.jl

+45
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,51 @@ macro jit(args...)
20682068
#! format: on
20692069
end
20702070

2071+
"""
2072+
@mlir_visualize [optimize = ...] [no_nan = <true/false>] f(args...)
2073+
2074+
Runs `@code_hlo` and visualizes the MLIR module using `model-explorer`. This expects the
2075+
`model-explorer` executable to be in your `PATH`. Installation instructions can be found
2076+
[here](https://github.com/google-ai-edge/model-explorer).
2077+
"""
2078+
macro mlir_visualize(args...)
2079+
default_options = Dict{Symbol,Any}(
2080+
:optimize => true,
2081+
:no_nan => false,
2082+
:client => nothing,
2083+
:raise => false,
2084+
:raise_first => false,
2085+
:shardy_passes => :(:to_mhlo_shardings),
2086+
:assert_nonallocating => false,
2087+
:donated_args => :(:auto),
2088+
:transpose_propagate => :(:up),
2089+
:reshape_propagate => :(:up),
2090+
:optimize_then_pad => true,
2091+
:optimize_communications => true,
2092+
:cudnn_hlo_optimize => false,
2093+
)
2094+
compile_expr, (; compiled) = compile_call_expr(
2095+
__module__, compile_mlir, default_options, args...
2096+
)
2097+
#! format: off
2098+
return esc(
2099+
:(
2100+
if Sys.which("model-explorer") === nothing
2101+
error("model-explorer is not in your PATH. Please install it from \
2102+
https://github.com/google-ai-edge/model-explorer")
2103+
end;
2104+
$(compile_expr);
2105+
mlir_mod = $(first)($(compiled));
2106+
tmpfile = tempname() * ".mlir";
2107+
open(tmpfile, "w") do io
2108+
print(io, mlir_mod)
2109+
end;
2110+
run(`model-explorer $(tmpfile)`)
2111+
)
2112+
)
2113+
#! format: on
2114+
end
2115+
20712116
function compile_call_expr(mod, compiler, options::Dict, args...)
20722117
while length(args) > 1
20732118
option, args = args[1], args[2:end]

src/Reactant.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,15 @@ function Enzyme.make_zero(
201201
return res
202202
end
203203

204-
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
204+
using .Compiler:
205+
@compile,
206+
@code_hlo,
207+
@code_mhlo,
208+
@jit,
209+
@code_xla,
210+
@mlir_visualize,
211+
traced_getfield,
212+
compile
205213
export ConcreteRArray,
206214
ConcreteRNumber,
207215
ConcretePJRTArray,
@@ -214,6 +222,7 @@ export ConcreteRArray,
214222
@code_xla,
215223
@jit,
216224
@trace,
225+
@mlir_visualize,
217226
within_compile
218227

219228
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

0 commit comments

Comments
 (0)