diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py index a5c1b345e9..de31af04dc 100644 --- a/intermediate_source/torch_compile_tutorial.py +++ b/intermediate_source/torch_compile_tutorial.py @@ -101,8 +101,11 @@ def forward(self, x): return torch.nn.functional.relu(self.lin(x)) mod = MyModule() -opt_mod = torch.compile(mod) -print(opt_mod(t)) +mod.compile() +print(mod(t)) +## or: +# opt_mod = torch.compile(mod) +# print(opt_mod(t)) ###################################################################### # torch.compile and Nested Calls @@ -135,8 +138,8 @@ def forward(self, x): return torch.nn.functional.relu(self.outer_lin(x)) outer_mod = OuterModule() -opt_outer_mod = torch.compile(outer_mod) -print(opt_outer_mod(t)) +outer_mod.compile() +print(outer_mod(t)) ###################################################################### # We can also disable some functions from being compiled by using @@ -197,6 +200,12 @@ def outer_function(): # 4. **Compile Leaf Functions First:** In complex models with multiple nested # functions and modules, start by compiling the leaf functions or modules first. # For more information see `TorchDynamo APIs for fine-grained tracing `__. +# +# 5. **Prefer ``mod.compile()`` over ``torch.compile(mod)``:** Avoids ``_orig_`` prefix issues in ``state_dict``. +# +# 6. **Use ``fullgraph=True`` to catch graph breaks:** Helps ensure end-to-end compilation, maximizing speedup +# and compatibility with ``torch.export``. + ###################################################################### # Demonstrating Speedups