diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index d0fabf9993..6c32db5f91 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -173,7 +173,13 @@ jobs: cd tests/py python -m pip install -r requirements.txt cd dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dtype_support.xml --ir dynamo models/test_dtype_support.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py popd tests-py-dynamo-serde: @@ -206,6 +212,7 @@ jobs: cd dynamo python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py popd tests-py-torch-compile-be: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 908ceaec41..b3dfe46a31 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -892,7 +892,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: else: constant_tensor = frozen_attr - return to_torch(constant_tensor) + return to_torch(constant_tensor) def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 0808db2ed8..c0bf3ededb 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -344,10 +344,6 @@ def create_constant( with unset_fake_temporarily(): torch_value = to_torch(value, dtype) - if torch_value is None: - raise ValueError( - f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." - ) if torch_value.dtype == torch.float64: raise ValueError( "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model." @@ -590,42 +586,45 @@ def to_numpy( Returns: A Numpy array or None, if the input was None. 
""" - output = None - - if value is None or isinstance(value, np.ndarray): - output = value + with unset_fake_temporarily(): + output = None - elif isinstance(value, torch.Tensor): - if value.is_quantized: - value = value.dequantize() - elif value.dtype == torch.bfloat16: - # TODO: Remove when numpy has a BF16 type - _LOGGER.warning( - "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", - ) - value = value.to(torch.float) + if value is None or isinstance(value, np.ndarray): + output = value - output = value.cpu().detach().contiguous().numpy() + elif isinstance(value, torch.Tensor): + if value.is_quantized: + value = value.dequantize() + elif value.dtype == torch.bfloat16: + # TODO: Remove when numpy has a BF16 type + _LOGGER.warning( + "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", + ) + value = value.to(torch.float) - elif isinstance(value, int): - output = np.array([value], dtype=np.int32) + output = value.cpu().detach().contiguous().numpy() - elif isinstance(value, float): - output = np.array([value], dtype=np.float32) + elif isinstance(value, int): + output = np.array([value], dtype=np.int32) - elif isinstance(value, bool): - output = np.array([value], dtype=np.bool_) + elif isinstance(value, float): + output = np.array([value], dtype=np.float32) - if isinstance(output, np.ndarray) or output is None: - return ( - output - if (dtype is None or output is None) - else output.astype(_enums.dtype._from(dtype).to(np.dtype, use_default=True)) - ) - else: - raise AssertionError( - f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" - ) + elif isinstance(value, bool): + output = np.array([value], dtype=np.bool_) + + if isinstance(output, np.ndarray) or output is None: + return ( + output + if (dtype is None or output is None) + else output.astype( + _enums.dtype._from(dtype).to(np.dtype, use_default=True) + ) + ) + else: + raise AssertionError( + f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}" + ) def to_torch( diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 4c65800f05..6369d3805c 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -2,11 +2,10 @@ from copy import deepcopy import torch +import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition -import torch_tensorrt - from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -51,7 +50,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -132,7 +130,6 @@ def forward(self, x, y): pass_through_build_failures=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, use_python_runtime=False, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = model(*inputs).detach().cpu() @@ -177,7 +174,6 @@ def forward(self, x, y): optimization_level=4, 
version_compatible=True, max_aux_streams=5, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -225,7 +221,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=True, - debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -298,7 +293,6 @@ def forward(self, x, y): min_block_size=1, pass_through_build_failures=True, truncate_double=False, - debug=True, torch_executed_ops={"torch.ops.aten.add.Tensor"}, ) optimized_model_results = optimized_model(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 6ff45507a0..aa22a74fc0 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -415,7 +415,6 @@ def run_test( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) @@ -507,7 +506,6 @@ def run_test_compare_tensor_attributes_only( compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, truncate_double=True, - debug=True, immutable_weights=immutable_weights, ) diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 146f7fdb7d..37b40574a1 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -297,7 +297,6 @@ def forward(self, x): ir="torch_compile", inputs=inputs, enabled_precisions={torch.bfloat16}, - debug=True, min_block_size=1, device=device, cache_built_engines=False, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index d71091b04e..b170bcc47d 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -815,7 +815,6 @@ def forward(self, x): exp_program, tuple(inputs), enabled_precisions={torch.float}, - debug=True, min_block_size=1, immutable_weights=False, ) diff --git a/tests/py/dynamo/models/test_modelopt_models.py b/tests/py/dynamo/models/test_modelopt_models.py new file mode 100644 index 0000000000..c2cd719bf9 --- /dev/null +++ b/tests/py/dynamo/models/test_modelopt_models.py @@ -0,0 +1,117 @@ +# type: ignore +import importlib +import platform +import unittest +from importlib import metadata + +import pytest +import torch +import torch_tensorrt as torchtrt + +from packaging.version import Version + +assertions = unittest.TestCase() + + +@unittest.skipIf( + torch.cuda.get_device_capability() < (8, 9), + "FP8 quantization requires compute capability 8.9 or later", +) +@unittest.skipIf( + not importlib.util.find_spec("modelopt"), + "ModelOpt is required to run this test", +) +@pytest.mark.unit +def test_base_fp8(): + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.utils import export_torch_mode + + class SimpleNetwork(torch.nn.Module): + def __init__(self): + super(SimpleNetwork, self).__init__() + self.linear1 = torch.nn.Linear(in_features=10, out_features=5) + self.linear2 = torch.nn.Linear(in_features=5, out_features=1) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.ReLU()(x) + x = self.linear2(x) + return x + + def calibrate_loop(model): + """Simple calibration function for testing.""" + model(input_tensor) + + input_tensor = torch.randn(1, 10).cuda() + model = 
SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.FP8_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has FP8 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,), strict=False)
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.float8_e4m3fn},
+                min_block_size=1,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
+
+
+@unittest.skipIf(
+    platform.system() != "Linux"
+    or not importlib.util.find_spec("modelopt")
+    or Version(metadata.version("nvidia-modelopt")) < Version("0.27.0"),
+    "modelopt 0.27.0 or later is required; Int8 quantization has been supported in modelopt on Linux since 0.17.0",
+)
+@pytest.mark.unit
+def test_base_int8():
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
+            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = torch.nn.ReLU()(x)
+            x = self.linear2(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(1, 10).cuda()
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has INT8 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torchtrt.logging.debug(), torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,), strict=False)
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.int8},
+                min_block_size=1,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                truncate_double=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
deleted file mode 100644
index 19fdeaa9ab..0000000000
--- a/tests/py/dynamo/models/test_models_export.py
+++ /dev/null
@@ -1,305 +0,0 @@
-# type: ignore
-import importlib
-import platform
-import unittest
-from importlib import metadata
-
-import pytest
-import timm
-import torch
-import torch_tensorrt as torchtrt
-import torchvision.models as models
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
-from packaging.version import Version
-
-assertions = unittest.TestCase()
-
-
-@pytest.mark.unit
-def test_resnet18(ir):
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
-
-    compile_spec = {
-        "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
-        ],
-        "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
-        "ir": ir,
-        "pass_through_build_failures": True,
-        "optimization_level": 1,
-        "min_block_size": 8,
-        "cache_built_engines": False,
-        "reuse_cached_engines": False,
-    }
-
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), 
trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -def test_mobilenet_v2(ir): - model = models.mobilenet_v2(pretrained=True).eval().to("cuda") - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -def test_efficientnet_b0(ir): - model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -@unittest.skipIf( - not importlib.util.find_spec("transformers"), - "transformers is required to run this test", -) -def test_bert_base_uncased(ir): - from transformers import BertModel - - model = ( - BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval() - ) - input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") - input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, - dtype=input.dtype, - format=torch.contiguous_format, - ), - torchtrt.Input( - input.shape, - dtype=input.dtype, - format=torch.contiguous_format, - ), - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.float}, - "truncate_double": True, - "ir": ir, - "min_block_size": 10, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - trt_mod = torchtrt.compile(model, **compile_spec) - model_outputs = model(input, input2) - trt_model_outputs = trt_mod(input, input2) - assertions.assertTrue( - len(model_outputs) == len(trt_model_outputs), - msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. 
Please check the compilation.", - ) - - for index in range(len(model_outputs)): - out, trt_out = model_outputs[index], trt_model_outputs[index] - cos_sim = cosine_similarity(out, trt_out) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@pytest.mark.unit -def test_resnet18_half(ir): - model = models.resnet18(pretrained=True).eval().to("cuda").half() - input = torch.randn((1, 3, 224, 224)).to("cuda").half() - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.half, format=torch.contiguous_format - ) - ], - "device": torchtrt.Device("cuda:0"), - "enabled_precisions": {torch.half}, - "ir": ir, - "pass_through_build_failures": True, - "optimization_level": 1, - "min_block_size": 8, - "cache_built_engines": False, - "reuse_cached_engines": False, - } - - trt_mod = torchtrt.compile(model, **compile_spec) - cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - # Clean up model env - torch._dynamo.reset() - - -@unittest.skipIf( - torch.cuda.get_device_capability() < (8, 9), - "FP8 quantization requires compute capability 8.9 or later", -) -@unittest.skipIf( - not importlib.util.find_spec("modelopt"), - "ModelOpt is required to run this test", -) -@pytest.mark.unit -def test_base_fp8(ir): - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.utils import export_torch_mode - - class SimpleNetwork(torch.nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=10, out_features=5) - self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.ReLU()(x) - x = self.linear2(x) - return x - - def calibrate_loop(model): - """Simple calibration function for testing.""" - model(input_tensor) - - input_tensor = torch.randn(1, 10).cuda() - model = SimpleNetwork().eval().cuda() - - quant_cfg = mtq.FP8_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has FP8 qdq nodes at this point - output_pyt = model(input_tensor) - - with torch.no_grad(): - with export_torch_mode(): - exp_program = torch.export.export(model, (input_tensor,), strict=False) - trt_model = torchtrt.dynamo.compile( - exp_program, - inputs=[input_tensor], - enabled_precisions={torch.float8_e4m3fn}, - min_block_size=1, - debug=True, - cache_built_engines=False, - reuse_cached_engines=False, - ) - outputs_trt = trt_model(input_tensor) - assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) - - -@unittest.skipIf( - platform.system() != "Linux" - or torch.cuda.get_device_capability() < (8, 9) - or not importlib.util.find_spec("modelopt") - or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), - "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", -) -@pytest.mark.unit -def test_base_int8(ir): - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.utils import export_torch_mode - - class SimpleNetwork(torch.nn.Module): - def __init__(self): - super(SimpleNetwork, self).__init__() - self.linear1 = torch.nn.Linear(in_features=10, 
out_features=5) - self.linear2 = torch.nn.Linear(in_features=5, out_features=1) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.ReLU()(x) - x = self.linear2(x) - return x - - def calibrate_loop(model): - """Simple calibration function for testing.""" - model(input_tensor) - - input_tensor = torch.randn(1, 10).cuda() - model = SimpleNetwork().eval().cuda() - - quant_cfg = mtq.INT8_DEFAULT_CFG - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - # model has INT8 qdq nodes at this point - output_pyt = model(input_tensor) - - with torch.no_grad(): - with export_torch_mode(): - exp_program = torch.export.export(model, (input_tensor,)) - trt_model = torchtrt.dynamo.compile( - exp_program, - inputs=[input_tensor], - enabled_precisions={torch.int8}, - min_block_size=1, - debug=True, - cache_built_engines=False, - reuse_cached_engines=False, - truncate_double=True, - ) - outputs_trt = trt_model(input_tensor) - assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2) diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py index 0a4629644d..0c9b8bc13f 100644 --- a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py +++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py @@ -61,7 +61,6 @@ def forward(self, x): min_block_size=1, pass_through_build_failures=True, use_python_runtime=True, - debug=True, ) result_samples = [] diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 011ed01e35..6321824458 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -10,5 +10,5 @@ pyyaml timm>=1.0.3 flashinfer-python; python_version < "3.13" transformers==4.49.0 -nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13" +nvidia-modelopt[all]; python_version < "3.13" --extra-index-url https://pypi.nvidia.com
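
Note on the pattern this patch adopts: with the per-call debug=True compilation setting dropped from the tests above, debug output is instead scoped with the torch_tensorrt.logging.debug() context manager, as test_base_int8 in the new test_modelopt_models.py shows. A minimal sketch of that usage, assuming a toy module and input (both illustrative, not part of this change):

import torch
import torch_tensorrt as torchtrt

# Toy module and input, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU()).eval().cuda()
inputs = (torch.randn(1, 10).cuda(),)
exp_program = torch.export.export(model, inputs)

# Debug logging is enabled for the duration of the context, rather than
# per compile call via the removed debug=True setting.
with torchtrt.logging.debug(), torch.no_grad():
    trt_model = torchtrt.dynamo.compile(
        exp_program,
        inputs=list(inputs),
        min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
    )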