Skip to content

Commit 090cb65

Browse files
committed
Added CPU offloading
1 parent dc36709 commit 090cb65

File tree

4 files changed

+15
-11
lines changed

4 files changed

+15
-11
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def compile(
421421
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422422
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
423423
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
424+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
424425
**kwargs: Any,
425426
) -> torch.fx.GraphModule:
426427
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -550,15 +551,6 @@ def compile(
550551
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
551552
)
552553

553-
if (
554-
not immutable_weights
555-
and not refit_identical_engine_weights
556-
and enable_weight_streaming
557-
):
558-
raise ValueError(
559-
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
560-
)
561-
562554
if (
563555
"enable_cross_compile_for_windows" in kwargs.keys()
564556
and kwargs["enable_cross_compile_for_windows"]
@@ -674,6 +666,7 @@ def compile(
674666
"enable_weight_streaming": enable_weight_streaming,
675667
"tiling_optimization_level": tiling_optimization_level,
676668
"l2_limit_for_tiling": l2_limit_for_tiling,
669+
"offload_module_to_cpu": offload_module_to_cpu,
677670
}
678671

679672
settings = CompilationSettings(**compilation_options)
@@ -684,6 +677,9 @@ def compile(
684677
)
685678

686679
gm = exported_program.module()
680+
# TODO: Memory control prototyping. Under discussion
681+
if offload_module_to_cpu:
682+
exported_program.module().to("cpu")
687683
logger.debug("Input graph: " + str(gm.graph))
688684

689685
# Apply lowering on the graph module
@@ -820,6 +816,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
820816
trt_modules = {}
821817
# Iterate over all components that can be accelerated
822818
# Generate the corresponding TRT Module for those
819+
823820
for name, _ in partitioned_module.named_children():
824821
submodule = getattr(partitioned_module, name)
825822
# filter on the GraphModule

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = False
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _save_weight_mapping(self) -> None:
492492
"""
493493
_LOGGER.info("Building weight name mapping...")
494494
# Stage 1: Name mapping
495-
torch_device = to_torch_device(self.compilation_settings.device)
495+
torch_device = to_torch_device(self.compilation_settings.device) # If the model was originally placed on the CPU, move it to the GPU
496496
sd = {
497497
k: v.reshape(-1).to(torch_device)
498498
for k, v in self.module.state_dict().items()
@@ -736,7 +736,11 @@ def run(
736736
self._create_timing_cache(
737737
builder_config, self.compilation_settings.timing_cache_path
738738
)
739-
739+
# TODO: Memory control prototyping. Under discussion
740+
if self.compilation_settings.offload_module_to_cpu:
741+
del self.module
742+
gc.collect()
743+
torch.cuda.empty_cache()
740744
serialized_engine = self.builder.build_serialized_network(
741745
self.ctx.net, builder_config
742746
)

0 commit comments

Comments (0)