Skip to content

Adjust torch.compile() best practices #3336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ def forward(self, x):
return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))
mod.compile()
print(mod(t))
## or:
# opt_mod = torch.compile(mod)
# print(opt_mod(t))

######################################################################
# torch.compile and Nested Calls
Expand Down Expand Up @@ -135,8 +138,8 @@ def forward(self, x):
return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))
outer_mod.compile()
print(outer_mod(t))

######################################################################
# We can also disable some functions from being compiled by using
Expand Down Expand Up @@ -197,6 +200,12 @@ def outer_function():
# 4. **Compile Leaf Functions First:** In complex models with multiple nested
# functions and modules, start by compiling the leaf functions or modules first.
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
#
# 5. **Prefer `mod.compile()` over `torch.compile(mod)`:** Avoids `_orig_` prefix issues in `state_dict`.
#
# 6. **Use `fullgraph=True` to catch graph breaks:** Helps ensure end-to-end compilation, maximizing speedup
# and compatibility with `torch.export`.


######################################################################
# Demonstrating Speedups
Expand Down