diff --git a/test.py b/test.py
index ed8413839..4fe81b564 100644
--- a/test.py
+++ b/test.py
@@ -36,14 +36,17 @@ def tearDown(self):
         gc.collect()
 
 
-def _create_example_model_instance(task: ModelTask, device: str):
+def _create_example_model_instance(task: ModelTask, device: str, mode: str):
     skip = False
+    extra_args = ["--accuracy"]
+    if mode == "inductor":
+        extra_args.append("--inductor")
     try:
-        task.make_model_instance(test="eval", device=device, extra_args=["--accuracy"])
+        task.make_model_instance(test="eval", device=device, extra_args=extra_args)
     except NotImplementedError:
         try:
             task.make_model_instance(
-                test="train", device=device, extra_args=["--accuracy"]
+                test="train", device=device, extra_args=extra_args
             )
         except NotImplementedError:
             skip = True
@@ -54,7 +57,7 @@ def _create_example_model_instance(task: ModelTask, device: str):
             )
 
 
-def _load_test(path, device):
+def _load_test(path, device, mode):
     model_name = os.path.basename(path)
 
     def _skip_cuda_memory_check_p(metadata):
@@ -70,7 +73,7 @@ def example_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
-                _create_example_model_instance(task, device)
+                _create_example_model_instance(task, device, mode)
                 accuracy = task.get_model_attribute("accuracy")
                 assert (
                     accuracy == "pass"
@@ -96,7 +99,7 @@ def train_fn(self):
         ):
             try:
                 task.make_model_instance(
-                    test="train", device=device, batch_size=batch_size
+                    test="train", device=device, batch_size=batch_size, extra_args=["--inductor"] if mode == "inductor" else []
                 )
                 task.invoke()
                 task.check_details_train(device=device, md=metadata)
@@ -119,7 +122,7 @@ def eval_fn(self):
         ):
             try:
                 task.make_model_instance(
-                    test="eval", device=device, batch_size=batch_size
+                    test="eval", device=device, batch_size=batch_size, extra_args=["--inductor"] if mode == "inductor" else []
                 )
                 task.invoke()
                 task.check_details_eval(device=device, md=metadata)
@@ -136,7 +139,7 @@ def check_device_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
-                task.make_model_instance(test="eval", device=device)
+                task.make_model_instance(test="eval", device=device, extra_args=["--inductor"] if mode == "inductor" else [])
                 task.check_device()
                 task.del_model_instance()
             except NotImplementedError as e:
@@ -152,9 +155,10 @@ def check_device_fn(self):
         # set exclude list based on metadata
         setattr(
             TestBenchmark,
-            f"test_{model_name}_{fn_name}_{device}",
+            f"test_{model_name}_{fn_name}_{device}_{mode}",
             (
                 unittest.skipIf(
+                    # This assumes models are never skipped based on the backend alone, only on whether their eval or train functions are implemented
                     skip_by_metadata(
                         test=fn_name, device=device, extra_args=[], metadata=metadata
                     ),
@@ -165,6 +169,7 @@ def check_device_fn(self):
 
 
 def _load_tests():
+    modes = ["eager", "inductor"]
     devices = ["cpu"]
     if torch.cuda.is_available():
         devices.append("cuda")
@@ -181,7 +186,8 @@ def _load_tests():
         if "quantized" in path:
             continue
         for device in devices:
-            _load_test(path, device)
+            for mode in modes:
+                _load_test(path, device, mode)
 
 
 _load_tests()
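
Not part of the diff: a rough sketch of the test matrix this change produces. The model name "resnet50" is assumed for illustration (real names come from _list_model_paths()); the snippet only mirrors the f-string used in the setattr() call above.

# Hypothetical sketch of the test names generated after this change, for an
# assumed model "resnet50" on a machine where CUDA is available.
modes = ["eager", "inductor"]
devices = ["cpu", "cuda"]
fn_names = ["example", "train", "eval", "check_device"]
model_name = "resnet50"  # assumption; real names come from _list_model_paths()

for device in devices:
    for mode in modes:
        for fn_name in fn_names:
            # Mirrors f"test_{model_name}_{fn_name}_{device}_{mode}" from setattr() above.
            print(f"test_{model_name}_{fn_name}_{device}_{mode}")

Assuming test.py still hands control to unittest when run directly, a single backend can then be selected with unittest's -k filter, e.g. python test.py -k inductor.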