diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index d0fabf9993..6c32db5f91 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -173,7 +173,13 @@ jobs: cd tests/py python -m pip install -r requirements.txt cd dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dtype_support.xml --ir dynamo models/test_dtype_support.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py popd tests-py-dynamo-serde: @@ -206,6 +212,7 @@ jobs: cd dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py popd tests-py-torch-compile-be: diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 7be7e0f16c..f217383f5c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) -@needs_refit +@needs_refit # type: ignore def construct_refit_mapping( module: torch.fx.GraphModule, inputs: Sequence[Input], @@ -81,7 +81,7 @@ def construct_refit_mapping( return interpreter.ctx.mapping -@needs_refit +@needs_refit # type: ignore def construct_refit_mapping_from_weight_name_map( weight_name_map: dict[Any, Any], state_dict: dict[Any, Any], @@ -111,7 +111,7 @@ def construct_refit_mapping_from_weight_name_map( return engine_weight_map -@needs_refit +@needs_refit # type: ignore def _refit_single_trt_engine_with_gm( new_gm: torch.fx.GraphModule, old_engine: trt.ICudaEngine, @@ -192,7 +192,7 @@ def _refit_single_trt_engine_with_gm( raise AssertionError("Refitting failed.") -@needs_refit +@needs_refit # type: ignore def refit_module_weights( compiled_module: torch.fx.GraphModule | ExportedProgram, new_weight_module: ExportedProgram, diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index a3444c025d..da5f3b36c9 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -893,7 +893,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: else: constant_tensor = frozen_attr - return to_torch(constant_tensor) + return to_torch(constant_tensor) def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py 
b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index b7fea33088..3541f57f1a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -358,31 +358,27 @@ def create_constant( shape = trt.Dims() else: shape = list(torch_value.shape) - if torch_value is not None: - if torch_value.dtype == torch.bfloat16: - torch_value_fp32 = torch_value.to(torch.float32) - numpy_value = torch_value_fp32.numpy() - else: - numpy_value = torch_value.numpy() - ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) - constant = ctx.net.add_constant( - shape, - numpy_value, - ) - constant.name = name - if torch_value.dtype == torch.bfloat16: - return cast_trt_tensor( - ctx, - constant.get_output(0), - trt.DataType.BF16, - name + "_bf16_cast", - ) - return constant.get_output(0) + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() else: - raise ValueError( - f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." + numpy_value = torch_value.numpy() + + ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) + constant = ctx.net.add_constant( + shape, + numpy_value, + ) + constant.name = name + if torch_value.dtype == torch.bfloat16: + return cast_trt_tensor( + ctx, + constant.get_output(0), + trt.DataType.BF16, + name + "_bf16_cast", ) + return constant.get_output(0) def get_trt_tensor( @@ -423,53 +419,6 @@ def get_trt_tensor( raise AssertionError(f"Cannot convert {input_val} to TRT constant") -def to_torch( - value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], - dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, -) -> Optional[torch.Tensor]: - """ - Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU - Args: - value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): - A PyTorch tensor, Numpy array, int, float, or bool - dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): - If a dtype is given, we will convert the type of the given `value` to this dtype. - Returns: - A PyTorch tensor or None, if the input was None. - """ - - cpu_device = torch.device("cpu") - torch_dtype = ( - _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None - ) - - with unset_fake_temporarily(): - if value is None: - return None - - elif isinstance(value, torch.Tensor): - output = value.to(cpu_device).contiguous() - - elif isinstance(value, np.ndarray): - output = torch.from_numpy(value).to(cpu_device).contiguous() - - elif isinstance(value, int): - output = torch.tensor([value], device=cpu_device, dtype=torch.int32) - - elif isinstance(value, float): - output = torch.tensor([value], device=cpu_device, dtype=torch.float32) - - elif isinstance(value, bool): - output = torch.tensor([value], device=cpu_device, dtype=torch.bool) - - else: - raise AssertionError( - f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" - ) - - return output.to(torch_dtype) if torch_dtype else output - - @overload def get_positive_dim(dim: int, dim_size: int) -> int: ... @@ -633,42 +582,92 @@ def to_numpy( Returns: A Numpy array or None, if the input was None. 
""" - output = None + with unset_fake_temporarily(): + output = None - if value is None or isinstance(value, np.ndarray): - output = value + if value is None or isinstance(value, np.ndarray): + output = value - elif isinstance(value, torch.Tensor): - if value.is_quantized: - value = value.dequantize() - elif value.dtype == torch.bfloat16: - # TODO: Remove when numpy has a BF16 type - _LOGGER.warning( - "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", + elif isinstance(value, torch.Tensor): + if value.is_quantized: + value = value.dequantize() + elif value.dtype == torch.bfloat16: + # TODO: Remove when numpy has a BF16 type + _LOGGER.warning( + "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", + ) + value = value.to(torch.float) + + output = value.cpu().detach().contiguous().numpy() + + elif isinstance(value, int): + output = np.array([value], dtype=np.int32) + + elif isinstance(value, float): + output = np.array([value], dtype=np.float32) + + elif isinstance(value, bool): + output = np.array([value], dtype=np.bool_) + + if isinstance(output, np.ndarray) or output is None: + return ( + output + if (dtype is None or output is None) + else output.astype( + _enums.dtype._from(dtype).to(np.dtype, use_default=True) + ) + ) + else: + raise AssertionError( + f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" ) - value = value.to(torch.float) - output = value.cpu().detach().contiguous().numpy() - elif isinstance(value, int): - output = np.array([value], dtype=np.int32) +def to_torch( + value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, +) -> Optional[torch.Tensor]: + """ + Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU + Args: + value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): + A PyTorch tensor, Numpy array, int, float, or bool + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A PyTorch tensor or None, if the input was None. 
+ """ - elif isinstance(value, float): - output = np.array([value], dtype=np.float32) + cpu_device = torch.device("cpu") + torch_dtype = ( + _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None + ) - elif isinstance(value, bool): - output = np.array([value], dtype=np.bool_) + with unset_fake_temporarily(): + if value is None: + return None - if isinstance(output, np.ndarray) or output is None: - return ( - output - if (dtype is None or output is None) - else output.astype(_enums.dtype._from(dtype).to(np.dtype, use_default=True)) - ) - else: - raise AssertionError( - f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" - ) + elif isinstance(value, torch.Tensor): + output = value.to(cpu_device).contiguous() + + elif isinstance(value, np.ndarray): + output = torch.from_numpy(value).to(cpu_device).contiguous() + + elif isinstance(value, int): + output = torch.tensor([value], device=cpu_device, dtype=torch.int32) + + elif isinstance(value, float): + output = torch.tensor([value], device=cpu_device, dtype=torch.float32) + + elif isinstance(value, bool): + output = torch.tensor([value], device=cpu_device, dtype=torch.bool) + + else: + raise AssertionError( + f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" + ) + + return output.to(torch_dtype) if torch_dtype else output def flatten_dims( diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py index cf803c5ffa..fd5ed390ff 100644 --- a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py +++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py @@ -1,8 +1,6 @@ -import pytest - -flashinfer = pytest.importorskip("flashinfer") import unittest +import pytest import torch import torch.nn as nn import torch_tensorrt @@ -12,25 +10,29 @@ from ..conversion.harness import DispatchTestCase +# Toggle this flag to enable/disable flashinfer-based overrides +enable_flashinfer: bool = False +if enable_flashinfer: + import flashinfer -@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] -def flashinfer_rmsnorm( - input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 -) -> torch.Tensor: - return flashinfer.norm.rmsnorm(input, weight) + @torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc] + def flashinfer_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 + ) -> torch.Tensor: + return flashinfer.norm.rmsnorm(input, weight) + @torch.library.register_fake("flashinfer::rmsnorm") + def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: + return input -@torch.library.register_fake("flashinfer::rmsnorm") -def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor: - return input + torch_tensorrt.dynamo.conversion.plugins.custom_op( + "flashinfer::rmsnorm", supports_dynamic_shapes=True + ) -torch_tensorrt.dynamo.conversion.plugins.custom_op( - "flashinfer::rmsnorm", supports_dynamic_shapes=True +@unittest.skip( + "Flashinfer RMSNorm test is disabled due to error: SM75 support not available" ) - - -@unittest.skip("Not Available") class TestAutomaticPlugin(DispatchTestCase): @parameterized.expand( [ diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 4c65800f05..6369d3805c 100644 --- 
a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -2,11 +2,10 @@ from copy import deepcopy import torch +import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition -import torch_tensorrt - from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -51,7 +50,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -132,7 +130,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = model(*inputs).detach().cpu() @@ -177,7 +174,6 @@ def forward(self, x, y): optimization_level=4, version_compatible=True, max_aux_streams=5, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -225,7 +221,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=True, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -298,7 +293,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=False, - debug=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, ) optimized_model_results = optimized_model(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 6ff45507a0..aa22a74fc0 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -415,7 +415,6 @@ def run_test( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) @@ -507,7 +506,6 @@ def run_test_compare_tensor_attributes_only( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 146f7fdb7d..37b40574a1 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -297,7 +297,6 @@ def forward(self, x): ir="torch_compile", inputs=inputs, enabled_precisions={torch.bfloat16}, - debug=True, min_block_size=1, device=device, cache_built_engines=False, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index d71091b04e..b170bcc47d 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -815,7 +815,6 @@ def forward(self, x): exp_program, tuple(inputs), enabled_precisions={torch.float}, - debug=True, min_block_size=1, immutable_weights=False, ) diff --git a/tests/py/dynamo/models/test_modelopt_models.py b/tests/py/dynamo/models/test_modelopt_models.py new file mode 100644 index 0000000000..c2cd719bf9 --- /dev/null +++ b/tests/py/dynamo/models/test_modelopt_models.py @@ -0,0 +1,117 @@ +# type: ignore +import importlib +import platform +import unittest +from importlib 
import metadata + +import pytest +import torch +import torch_tensorrt as torchtrt + +from packaging.version import Version + +assertions = unittest.TestCase() + + +@unittest.skipIf( + torch.cuda.get_device_capability() < (8, 9), + "FP8 quantization requires compute capability 8.9 or later", +) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp8(): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.FP8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has FP8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.float8_e4m3fn}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) + + +@unittest.skipIf( + platform.system() != "Linux" + or not importlib.util.find_spec("modelopt") + or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"), + "modelopt 0.27.0 or later is required; Int8 quantization is only supported on Linux with modelopt 0.27.0 or later", +) +@pytest.mark.unit +def test_base_int8(): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = SimpleNetwork().eval().cuda() + + quant_cfg = mtq.INT8_DEFAULT_CFG + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + # model has INT8 qdq nodes at this point + output_pyt = model(input_tensor) + + with torchtrt.logging.debug(), torch.no_grad(): + with export_torch_mode(): + exp_program = torch.export.export(model, (input_tensor,), strict=False) + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=[input_tensor], + enabled_precisions={torch.int8}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + truncate_double=True, + debug=True, + ) + outputs_trt = trt_model(input_tensor) + assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py index 0a4629644d..0c9b8bc13f 100644 --- a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py
+++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py @@ -61,7 +61,6 @@ def forward(self, x): min_block_size=1, pass_through_build_failures=True, use_python_runtime=True, - debug=True, ) result_samples = [] diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 4f3c4e083b..94db519d28 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -9,5 +9,5 @@ pytest-xdist>=3.6.1 pyyaml timm>=1.0.3 transformers==4.49.0 -nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13" +nvidia-modelopt[all]~=0.27.0; python_version < "3.13" --extra-index-url https://pypi.nvidia.com
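Note on the converter_utils.py hunks above: the patch relocates to_torch to sit next to to_numpy and wraps to_numpy in unset_fake_temporarily; the practical difference between the two helpers shows up for bfloat16 weights, which NumPy cannot represent. Below is a minimal illustrative sketch of that difference. It is not part of the patch and assumes both helpers remain importable from torch_tensorrt.dynamo.conversion.converter_utils.

```python
# Illustrative sketch only -- not part of the patch. Assumes to_numpy and
# to_torch stay importable from torch_tensorrt.dynamo.conversion.converter_utils.
import numpy as np
import torch

from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy, to_torch

bf16_weight = torch.randn(4, 4, dtype=torch.bfloat16)

# to_torch only moves the tensor to CPU (contiguous), so bfloat16 survives.
assert to_torch(bf16_weight).dtype == torch.bfloat16

# to_numpy must upcast to float32 (NumPy has no bfloat16 dtype) and logs a
# warning recommending to_torch for better data representation.
assert to_numpy(bf16_weight).dtype == np.float32
```

This is also why create_constant converts to NumPy only at the point TensorRT needs a constant, upcasting bfloat16 to float32 there and then casting the resulting TRT tensor back to BF16 via cast_trt_tensor.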