* Adjust torch.compile() best practices
1. Add a best practice to prefer `mod.compile()` over `torch.compile(mod)`, which avoids the `_orig_mod.` key-renaming problem.
Repro steps:
- opt_mod = torch.compile(mod)
- train opt_mod
- save checkpoint
In another script, potentially on a machine that does NOT support `torch.compile`: load checkpoint.
This fails with an error, because the checkpoint saved from `opt_mod` had its parameters renamed (prefixed with `_orig_mod.`) by `torch.compile`:
```
RuntimeError: Error(s) in loading state_dict for VQVAE:
Missing key(s) in state_dict: "embedding.weight", "encoder.encoder.net.0.weight", "encoder.encoder.net.0.bias", ...
Unexpected key(s) in state_dict: "_orig_mod.embedding.weight", "_orig_mod.encoder.encoder.net.0.weight", "_orig_mod.encoder.encoder.net.0.bias", ...
```
- Add a best practice to use, or at least try, `fullgraph=True`. It doesn't always work, but we should encourage it because it surfaces graph breaks immediately instead of silently falling back to eager execution.
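A short sketch of opting into fullgraph mode (the function `fn` here is a hypothetical example):

```python
import torch

def fn(x):
    return torch.sin(x) + torch.cos(x)

# With fullgraph=True, the first call to compiled_fn will raise an error
# if TorchDynamo cannot capture fn as a single graph, rather than
# silently splitting the graph and falling back to eager for the gaps.
compiled_fn = torch.compile(fn, fullgraph=True)
```

If the call fails, the error message points at the construct causing the graph break, which can then be rewritten or left uncompiled deliberately.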
---------
Co-authored-by: Svetlana Karslioglu <[email protected]>