Skip to content

Commit 7fc877b

Browse files
punkeelsvekars
andauthored
Adjust torch.compile() best practices (#3336)
* Adjust torch.compile() best practices 1. Add best practice to prefer `mod.compile` over `torch.compile(mod)`, which avoids `_orig_` naming problems. Repro steps: - opt_mod = torch.compile(mod) - train opt_mod - save checkpoint In another script, potentially on a machine that does NOT support `torch.compile`: load checkpoint. This fails with an error, because the checkpoint on `opt_mod` got its params renamed by `torch.compile`: ``` RuntimeError: Error(s) in loading state_dict for VQVAE: Missing key(s) in state_dict: "embedding.weight", "encoder.encoder.net.0.weight", "encoder.encoder.net.0.bias", ... Unexpected key(s) in state_dict: "_orig_mod.embedding.weight", "_orig_mod.encoder.encoder.net.0.weight", "_orig_mod.encoder.encoder.net.0.bias", ... ``` - Add best practice to use, or at least try, `fullgraph=True`. This doesn't always work, but we should encourage it. --------- Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent a5632da commit 7fc877b

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

intermediate_source/torch_compile_tutorial.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,11 @@ def forward(self, x):
101101
return torch.nn.functional.relu(self.lin(x))
102102

103103
mod = MyModule()
104-
opt_mod = torch.compile(mod)
105-
print(opt_mod(t))
104+
mod.compile()
105+
print(mod(t))
106+
## or:
107+
# opt_mod = torch.compile(mod)
108+
# print(opt_mod(t))
106109

107110
######################################################################
108111
# torch.compile and Nested Calls
@@ -135,8 +138,8 @@ def forward(self, x):
135138
return torch.nn.functional.relu(self.outer_lin(x))
136139

137140
outer_mod = OuterModule()
138-
opt_outer_mod = torch.compile(outer_mod)
139-
print(opt_outer_mod(t))
141+
outer_mod.compile()
142+
print(outer_mod(t))
140143

141144
######################################################################
142145
# We can also disable some functions from being compiled by using
@@ -197,6 +200,12 @@ def outer_function():
197200
# 4. **Compile Leaf Functions First:** In complex models with multiple nested
198201
# functions and modules, start by compiling the leaf functions or modules first.
199202
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
203+
#
204+
# 5. **Prefer ``mod.compile()`` over ``torch.compile(mod)``:** Avoids ``_orig_`` prefix issues in ``state_dict``.
205+
#
206+
# 6. **Use ``fullgraph=True`` to catch graph breaks:** Helps ensure end-to-end compilation, maximizing speedup
207+
# and compatibility with ``torch.export``.
208+
200209

201210
######################################################################
202211
# Demonstrating Speedups

0 commit comments

Comments
 (0)