fix: cherry pick PR of 3445 #3457

Merged
15 commits merged on Apr 29, 2025
9 changes: 8 additions & 1 deletion .github/workflows/build-test-linux.yml
@@ -173,7 +173,13 @@ jobs:
cd tests/py
python -m pip install -r requirements.txt
cd dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models.xml --ir dynamo models/test_models.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_models_dynamic.xml --ir dynamo models/test_dyn_models.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/engine_cache.xml --ir dynamo models/test_engine_cache.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dtype_support.xml --ir dynamo models/test_dtype_support.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/model_refit.xml --ir dynamo models/test_model_refit.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/modelopt_models.xml --ir dynamo models/test_modelopt_models.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/weight_stripped_engine.xml --ir dynamo models/test_weight_stripped_engine.py
popd
tests-py-dynamo-serde:
@@ -206,6 +212,7 @@ jobs:
cd dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_kwargs_serde_test_results.xml --ir dynamo models/test_export_kwargs_serde.py
popd
tests-py-torch-compile-be:
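
The workflow change above replaces the single dynamo models invocation with one pytest run per test module, each writing its own JUnit XML report. A minimal local sketch of the same pattern; the results directory default here is an assumption (in CI, RUNNER_TEST_RESULTS_DIR is provided by the runner), and the --ir option comes from the repository's test conftest, not from pytest itself:

import os
import pytest

# Assumed local stand-in for the CI results directory.
results_dir = os.environ.get("RUNNER_TEST_RESULTS_DIR", "./test-results")
os.makedirs(results_dir, exist_ok=True)

# One JUnit XML report per test module, mirroring the split CI steps above.
modules = {
    "dynamo_models.xml": "models/test_models.py",
    "dynamo_models_dynamic.xml": "models/test_dyn_models.py",
    "engine_cache.xml": "models/test_engine_cache.py",
}
for report, path in modules.items():
    pytest.main(["-ra", f"--junitxml={os.path.join(results_dir, report)}", "--ir", "dynamo", path])
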
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -48,7 +48,7 @@
logger = logging.getLogger(__name__)


@needs_refit
@needs_refit # type: ignore
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
@@ -81,7 +81,7 @@ def construct_refit_mapping(
return interpreter.ctx.mapping


@needs_refit
@needs_refit # type: ignore
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
@@ -111,7 +111,7 @@ def construct_refit_mapping_from_weight_name_map(
return engine_weight_map


@needs_refit
@needs_refit # type: ignore
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
@@ -192,7 +192,7 @@ def _refit_single_trt_engine_with_gm(
raise AssertionError("Refitting failed.")


@needs_refit
@needs_refit # type: ignore
def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
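
The _refit.py changes only add "# type: ignore" to each needs_refit usage, silencing mypy where the decorator is untyped. Purely for illustration, a hypothetical guard decorator in the same spirit (the real needs_refit in torch_tensorrt may be implemented differently) could look like:

from functools import wraps
from typing import Any, Callable

ENGINE_REFIT_AVAILABLE = True  # assumption: stand-in for the library's real capability check

def needs_refit(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Raise a clear error when refit support is unavailable, otherwise call through."""
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not ENGINE_REFIT_AVAILABLE:
            raise RuntimeError("Engine refitting is not supported in this build of TensorRT.")
        return fn(*args, **kwargs)
    return wrapper
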
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -893,7 +893,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
else:
constant_tensor = frozen_attr

return to_torch(constant_tensor)
return to_torch(constant_tensor)

def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
assert isinstance(target, str)
193 changes: 96 additions & 97 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -358,31 +358,27 @@ def create_constant(
shape = trt.Dims()
else:
shape = list(torch_value.shape)
if torch_value is not None:
if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
numpy_value,
)
constant.name = name
if torch_value.dtype == torch.bfloat16:
return cast_trt_tensor(
ctx,
constant.get_output(0),
trt.DataType.BF16,
name + "_bf16_cast",
)
return constant.get_output(0)
if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
raise ValueError(
f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
numpy_value,
)
constant.name = name
if torch_value.dtype == torch.bfloat16:
return cast_trt_tensor(
ctx,
constant.get_output(0),
trt.DataType.BF16,
name + "_bf16_cast",
)
return constant.get_output(0)


def get_trt_tensor(
@@ -423,53 +419,6 @@ def get_trt_tensor(
raise AssertionError(f"Cannot convert {input_val} to TRT constant")


def to_torch(
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
) -> Optional[torch.Tensor]:
"""
Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
Args:
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
A PyTorch tensor, Numpy array, int, float, or bool
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
If a dtype is given, we will convert the type of the given `value` to this dtype.
Returns:
A PyTorch tensor or None, if the input was None.
"""

cpu_device = torch.device("cpu")
torch_dtype = (
_enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
)

with unset_fake_temporarily():
if value is None:
return None

elif isinstance(value, torch.Tensor):
output = value.to(cpu_device).contiguous()

elif isinstance(value, np.ndarray):
output = torch.from_numpy(value).to(cpu_device).contiguous()

elif isinstance(value, int):
output = torch.tensor([value], device=cpu_device, dtype=torch.int32)

elif isinstance(value, float):
output = torch.tensor([value], device=cpu_device, dtype=torch.float32)

elif isinstance(value, bool):
output = torch.tensor([value], device=cpu_device, dtype=torch.bool)

else:
raise AssertionError(
f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
)

return output.to(torch_dtype) if torch_dtype else output


@overload
def get_positive_dim(dim: int, dim_size: int) -> int: ...

@@ -633,42 +582,92 @@ def to_numpy(
Returns:
A Numpy array or None, if the input was None.
"""
output = None
with unset_fake_temporarily():
output = None

if value is None or isinstance(value, np.ndarray):
output = value
if value is None or isinstance(value, np.ndarray):
output = value

elif isinstance(value, torch.Tensor):
if value.is_quantized:
value = value.dequantize()
elif value.dtype == torch.bfloat16:
# TODO: Remove when numpy has a BF16 type
_LOGGER.warning(
"Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
elif isinstance(value, torch.Tensor):
if value.is_quantized:
value = value.dequantize()
elif value.dtype == torch.bfloat16:
# TODO: Remove when numpy has a BF16 type
_LOGGER.warning(
"Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
)
value = value.to(torch.float)

output = value.cpu().detach().contiguous().numpy()

elif isinstance(value, int):
output = np.array([value], dtype=np.int32)

elif isinstance(value, float):
output = np.array([value], dtype=np.float32)

elif isinstance(value, bool):
output = np.array([value], dtype=np.bool_)

if isinstance(output, np.ndarray) or output is None:
return (
output
if (dtype is None or output is None)
else output.astype(
_enums.dtype._from(dtype).to(np.dtype, use_default=True)
)
)
else:
raise AssertionError(
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
)
value = value.to(torch.float)

output = value.cpu().detach().contiguous().numpy()

elif isinstance(value, int):
output = np.array([value], dtype=np.int32)
def to_torch(
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
) -> Optional[torch.Tensor]:
"""
Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
Args:
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
A PyTorch tensor, Numpy array, int, float, or bool
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
If a dtype is given, we will convert the type of the given `value` to this dtype.
Returns:
A PyTorch tensor or None, if the input was None.
"""

elif isinstance(value, float):
output = np.array([value], dtype=np.float32)
cpu_device = torch.device("cpu")
torch_dtype = (
_enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
)

elif isinstance(value, bool):
output = np.array([value], dtype=np.bool_)
with unset_fake_temporarily():
if value is None:
return None

if isinstance(output, np.ndarray) or output is None:
return (
output
if (dtype is None or output is None)
else output.astype(_enums.dtype._from(dtype).to(np.dtype, use_default=True))
)
else:
raise AssertionError(
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
)
elif isinstance(value, torch.Tensor):
output = value.to(cpu_device).contiguous()

elif isinstance(value, np.ndarray):
output = torch.from_numpy(value).to(cpu_device).contiguous()

elif isinstance(value, int):
output = torch.tensor([value], device=cpu_device, dtype=torch.int32)

elif isinstance(value, float):
output = torch.tensor([value], device=cpu_device, dtype=torch.float32)

elif isinstance(value, bool):
output = torch.tensor([value], device=cpu_device, dtype=torch.bool)

else:
raise AssertionError(
f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
)

return output.to(torch_dtype) if torch_dtype else output


def flatten_dims(
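
For context on the create_constant / to_numpy / to_torch reshuffle above: NumPy has no bfloat16 dtype, so bf16 tensors are widened to float32 before the NumPy conversion, and create_constant then casts the resulting TensorRT constant back to BF16 with cast_trt_tensor. A standalone sketch of that widening step (an illustrative helper, not the library code):

import numpy as np
import torch

def numpy_from_torch_constant(t: torch.Tensor) -> np.ndarray:
    # NumPy cannot represent bfloat16, so widen to float32 first; a consumer
    # such as a TensorRT cast layer can narrow the constant back to BF16 later.
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return t.detach().cpu().contiguous().numpy()

weight = torch.randn(4, 4, dtype=torch.bfloat16)
print(numpy_from_torch_constant(weight).dtype)  # float32
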
34 changes: 18 additions & 16 deletions tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -1,8 +1,6 @@
import pytest

flashinfer = pytest.importorskip("flashinfer")
import unittest

import pytest
import torch
import torch.nn as nn
import torch_tensorrt
@@ -12,25 +10,29 @@

from ..conversion.harness import DispatchTestCase

# Toggle this flag to enable/disable flashinfer-based overrides
enable_flashinfer: bool = False
if enable_flashinfer:
import flashinfer

@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
def flashinfer_rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
return flashinfer.norm.rmsnorm(input, weight)
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
def flashinfer_rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
return flashinfer.norm.rmsnorm(input, weight)

@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
return input

@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
return input
torch_tensorrt.dynamo.conversion.plugins.custom_op(
"flashinfer::rmsnorm", supports_dynamic_shapes=True
)


torch_tensorrt.dynamo.conversion.plugins.custom_op(
"flashinfer::rmsnorm", supports_dynamic_shapes=True
@unittest.skip(
"Flashinfer RMSNorm test is disabled due to error: SM75 support not available"
)


@unittest.skip("Not Available")
class TestAutomaticPlugin(DispatchTestCase):
@parameterized.expand(
[
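
The test module above now imports flashinfer and registers the custom op only when a module-level flag is enabled, rather than via pytest.importorskip, and the whole test class is skipped with an explicit SM75 message. A minimal sketch of that gating pattern; the register_fake kernel exists only to give the op a shape/dtype ("meta") implementation for tracing, and the empty_like body below is an illustrative stand-in rather than the test's exact code:

import torch

enable_flashinfer = False  # toggle: only register the op when the optional dependency is wanted

if enable_flashinfer:
    import flashinfer

    @torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())
    def flashinfer_rmsnorm(
        input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
    ) -> torch.Tensor:
        return flashinfer.norm.rmsnorm(input, weight)

    @torch.library.register_fake("flashinfer::rmsnorm")
    def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Shape/dtype-only implementation used during tracing; no real kernel runs here.
        return torch.empty_like(input)
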
8 changes: 1 addition & 7 deletions tests/py/dynamo/backend/test_backend_compiler.py
@@ -2,11 +2,10 @@
from copy import deepcopy

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.partitioning import fast_partition

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


@@ -51,7 +50,6 @@ def forward(self, x, y):
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
use_python_runtime=False,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -132,7 +130,6 @@ def forward(self, x, y):
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
use_python_runtime=False,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = model(*inputs).detach().cpu()
@@ -177,7 +174,6 @@ def forward(self, x, y):
optimization_level=4,
version_compatible=True,
max_aux_streams=5,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -225,7 +221,6 @@ def forward(self, x, y):
min_block_size=1,
pass_through_build_failures=True,
truncate_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()
@@ -298,7 +293,6 @@ def forward(self, x, y):
min_block_size=1,
pass_through_build_failures=True,
truncate_double=False,
debug=True,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
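
The test_backend_compiler.py changes drop the debug=True argument from the compile calls and move the torch_tensorrt import up with the other third-party imports. If verbose build output is still wanted when running these tests locally, one hedged alternative (an assumption, not something introduced by this PR) is to raise the library's Python log level instead of passing a per-compile flag:

import logging

# Assumption: torch_tensorrt routes its diagnostics through the standard logging module,
# so raising the "torch_tensorrt" logger to DEBUG yields verbose output without a compile flag.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("torch_tensorrt").setLevel(logging.DEBUG)
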