Skip to content

Commit ef8288a

Browse files
committed
Added back the control flag and fixed the CI
1 parent 63d6552 commit ef8288a

File tree

4 files changed: +12 −7 lines changed

py/torch_tensorrt/dynamo/_compiler.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def compile(
422422
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
423423
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
424424
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
425+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
425426
**kwargs: Any,
426427
) -> torch.fx.GraphModule:
427428
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,6 +667,7 @@ def compile(
666667
"enable_weight_streaming": enable_weight_streaming,
667668
"tiling_optimization_level": tiling_optimization_level,
668669
"l2_limit_for_tiling": l2_limit_for_tiling,
670+
"offload_module_to_cpu": offload_module_to_cpu,
669671
}
670672

671673
settings = CompilationSettings(**compilation_options)
@@ -677,16 +679,16 @@ def compile(
677679

678680
gm = exported_program.module()
679681
# Move the weights in the state_dict to CPU
680-
logger.info(
681-
"The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
682-
)
683682
logger.debug("Input graph: " + str(gm.graph))
684683

685684
# Apply lowering on the graph module
686685
gm = post_lowering(gm, settings)
687686
logger.debug("Lowered Input graph: " + str(gm.graph))
688-
689-
exported_program.module().to(CPU_DEVICE)
687+
if offload_module_to_cpu:
688+
exported_program.module().to(CPU_DEVICE)
689+
logger.info(
690+
"The model is offloaded to CPU during compilation. If you want to keep the model on GPU, set offload_module_to_cpu=False."
691+
)
690692
trt_gm = compile_module(
691693
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
692694
)

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = False
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,8 @@ def run(
736736
self._create_timing_cache(
737737
builder_config, self.compilation_settings.timing_cache_path
738738
)
739-
740-
delete_module(self.module)
739+
if self.compilation_settings.offload_module_to_cpu:
740+
delete_module(self.module)
741741
serialized_engine = self.builder.build_serialized_network(
742742
self.ctx.net, builder_config
743743
)

0 commit comments

Comments (0)