@@ -418,6 +418,61 @@ def test_select_container_for_mlflow_model_no_dlc_detected(
418
418
)
419
419
420
420
421
+ @patch ("sagemaker.image_uris.retrieve" )
422
+ @patch ("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version" )
423
+ @patch ("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements" )
424
+ @patch (
425
+ "sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file"
426
+ )
427
+ @patch ("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata" )
428
+ @patch ("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path" )
429
+ def test_select_container_for_mlflow_model_no_framework_version_detected (
430
+ mock_generate_mlflow_artifact_path ,
431
+ mock_get_all_flavor_metadata ,
432
+ mock_get_python_version_from_parsed_mlflow_model_file ,
433
+ mock_get_framework_version_from_requirements ,
434
+ mock_cast_to_compatible_version ,
435
+ mock_image_uris_retrieve ,
436
+ ):
437
+ mlflow_model_src_path = "/path/to/mlflow_model"
438
+ deployment_flavor = "pytorch"
439
+ region = "us-west-2"
440
+ instance_type = "ml.m5.xlarge"
441
+
442
+ mock_requirements_path = "/path/to/requirements.txt"
443
+ mock_metadata_path = "/path/to/mlmodel"
444
+ mock_flavor_metadata = {"pytorch" : {"some_key" : "some_value" }}
445
+ mock_python_version = "3.8.6"
446
+
447
+ mock_generate_mlflow_artifact_path .side_effect = lambda path , artifact : (
448
+ mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path
449
+ )
450
+ mock_get_all_flavor_metadata .return_value = mock_flavor_metadata
451
+ mock_get_python_version_from_parsed_mlflow_model_file .return_value = mock_python_version
452
+ mock_get_framework_version_from_requirements .return_value = None
453
+
454
+ with pytest .raises (
455
+ ValueError ,
456
+ match = "Unable to auto detect framework version. Please provide framework "
457
+ "pytorch as part of the requirements.txt file for deployment flavor "
458
+ "pytorch" ,
459
+ ):
460
+ _select_container_for_mlflow_model (
461
+ mlflow_model_src_path , deployment_flavor , region , instance_type
462
+ )
463
+
464
+ mock_generate_mlflow_artifact_path .assert_any_call (
465
+ mlflow_model_src_path , "requirements.txt"
466
+ )
467
+ mock_generate_mlflow_artifact_path .assert_any_call (mlflow_model_src_path , "MLmodel" )
468
+ mock_get_all_flavor_metadata .assert_called_once_with (mock_metadata_path )
469
+ mock_get_framework_version_from_requirements .assert_called_once_with (
470
+ deployment_flavor , mock_requirements_path
471
+ )
472
+ mock_cast_to_compatible_version .assert_not_called ()
473
+ mock_image_uris_retrieve .assert_not_called ()
474
+
475
+
421
476
def test_validate_input_for_mlflow ():
422
477
_validate_input_for_mlflow (ModelServer .TORCHSERVE , "pytorch" )
423
478
0 commit comments