Skip to content

Commit 2f2be05

Browse files
authored
feat: Add telemetry support for mlflow models (#4674)
* Initial commit for telemetry support
* Fix style issues and add more logger messages
* Fix value error messages in unit tests
1 parent 1279620 commit 2f2be05

File tree

7 files changed

+131
-6
lines changed

7 files changed

+131
-6
lines changed

src/sagemaker/serve/builder/model_builder.py

+3
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,9 @@ def _initialize_for_mlflow(self) -> None:
669669
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
670670
if not _mlflow_input_is_local_path(mlflow_path):
671671
# TODO: extend to package ARN, run ID, etc.
672+
logger.info(
673+
"Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
674+
)
672675
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
673676
else:
674677
_copy_directory_contents(mlflow_path, self.model_path)

src/sagemaker/serve/model_format/mlflow/constants.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
"py39": "1.13.1",
2020
"py310": "2.2.0",
2121
}
22-
MODEL_PACAKGE_ARN_REGEX = (
23-
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/[" r"a-zA-Z0-9\-_\/\.]+$"
22+
MODEL_PACKAGE_ARN_REGEX = (
23+
r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
2424
)
2525
MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
2626
MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
27-
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+\/[a-zA-Z0-9\-_\/\.]*$"
27+
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
2828
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
2929
MLFLOW_METADATA_FILE = "MLmodel"
3030
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"

src/sagemaker/serve/model_format/mlflow/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non
278278
os.makedirs(local_file_dir, exist_ok=True)
279279

280280
# Download the file
281-
print(f"Downloading {key} to {local_file_path}")
281+
logger.info(f"Downloading {key} to {local_file_path}")
282282
s3.download_file(s3_bucket, key, local_file_path)
283283

284284

@@ -356,6 +356,15 @@ def _select_container_for_mlflow_model(
356356
logger.info("Auto-detected framework to use is %s", framework_to_use)
357357
logger.info("Auto-detected framework version is %s", framework_version)
358358

359+
if framework_version is None:
360+
raise ValueError(
361+
(
362+
"Unable to auto detect framework version. Please provide framework %s as part of the "
363+
"requirements.txt file for deployment flavor %s"
364+
)
365+
% (framework_to_use, deployment_flavor)
366+
)
367+
359368
casted_versions = (
360369
_cast_to_compatible_version(framework_to_use, framework_version)
361370
if framework_version

src/sagemaker/serve/utils/lineage_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.lineage.query import LineageSourceEnum
2929
from sagemaker.serve.model_format.mlflow.constants import (
3030
MLFLOW_RUN_ID_REGEX,
31-
MODEL_PACAKGE_ARN_REGEX,
31+
MODEL_PACKAGE_ARN_REGEX,
3232
S3_PATH_REGEX,
3333
MLFLOW_REGISTRY_PATH_REGEX,
3434
)
@@ -107,7 +107,7 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
107107
"""
108108
mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX
109109
mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
110-
sagemaker_arn_pattern = MODEL_PACAKGE_ARN_REGEX
110+
sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
111111
s3_pattern = S3_PATH_REGEX
112112

113113
if re.match(mlflow_rub_id_pattern, mlflow_model_path):

src/sagemaker/serve/utils/telemetry_logger.py

+22
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,16 @@
1919

2020
from sagemaker import Session, exceptions
2121
from sagemaker.serve.mode.function_pointers import Mode
22+
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
2223
from sagemaker.serve.utils.exceptions import ModelBuilderException
24+
from sagemaker.serve.utils.lineage_constants import (
25+
MLFLOW_LOCAL_PATH,
26+
MLFLOW_S3_PATH,
27+
MLFLOW_MODEL_PACKAGE_PATH,
28+
MLFLOW_RUN_ID,
29+
MLFLOW_REGISTRY_PATH,
30+
)
31+
from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type
2332
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
2433
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
2534
from sagemaker.user_agent import SDK_VERSION
@@ -51,6 +60,14 @@
5160
str(ModelServer.TGI): 6,
5261
}
5362

63+
MLFLOW_MODEL_PATH_CODE = {
64+
MLFLOW_LOCAL_PATH: 1,
65+
MLFLOW_S3_PATH: 2,
66+
MLFLOW_MODEL_PACKAGE_PATH: 3,
67+
MLFLOW_RUN_ID: 4,
68+
MLFLOW_REGISTRY_PATH: 5,
69+
}
70+
5471

5572
def _capture_telemetry(func_name: str):
5673
"""Placeholder docstring"""
@@ -78,6 +95,11 @@ def wrapper(self, *args, **kwargs):
7895
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
7996
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"
8097

98+
if getattr(self, "_is_mlflow_model", False):
99+
mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH]
100+
mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path)
101+
extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}"
102+
81103
start_timer = perf_counter()
82104
try:
83105
response = func(self, *args, **kwargs)

tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py

+55
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,61 @@ def test_select_container_for_mlflow_model_no_dlc_detected(
418418
)
419419

420420

421+
@patch("sagemaker.image_uris.retrieve")
422+
@patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version")
423+
@patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements")
424+
@patch(
425+
"sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file"
426+
)
427+
@patch("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata")
428+
@patch("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path")
429+
def test_select_container_for_mlflow_model_no_framework_version_detected(
430+
mock_generate_mlflow_artifact_path,
431+
mock_get_all_flavor_metadata,
432+
mock_get_python_version_from_parsed_mlflow_model_file,
433+
mock_get_framework_version_from_requirements,
434+
mock_cast_to_compatible_version,
435+
mock_image_uris_retrieve,
436+
):
437+
mlflow_model_src_path = "/path/to/mlflow_model"
438+
deployment_flavor = "pytorch"
439+
region = "us-west-2"
440+
instance_type = "ml.m5.xlarge"
441+
442+
mock_requirements_path = "/path/to/requirements.txt"
443+
mock_metadata_path = "/path/to/mlmodel"
444+
mock_flavor_metadata = {"pytorch": {"some_key": "some_value"}}
445+
mock_python_version = "3.8.6"
446+
447+
mock_generate_mlflow_artifact_path.side_effect = lambda path, artifact: (
448+
mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path
449+
)
450+
mock_get_all_flavor_metadata.return_value = mock_flavor_metadata
451+
mock_get_python_version_from_parsed_mlflow_model_file.return_value = mock_python_version
452+
mock_get_framework_version_from_requirements.return_value = None
453+
454+
with pytest.raises(
455+
ValueError,
456+
match="Unable to auto detect framework version. Please provide framework "
457+
"pytorch as part of the requirements.txt file for deployment flavor "
458+
"pytorch",
459+
):
460+
_select_container_for_mlflow_model(
461+
mlflow_model_src_path, deployment_flavor, region, instance_type
462+
)
463+
464+
mock_generate_mlflow_artifact_path.assert_any_call(
465+
mlflow_model_src_path, "requirements.txt"
466+
)
467+
mock_generate_mlflow_artifact_path.assert_any_call(mlflow_model_src_path, "MLmodel")
468+
mock_get_all_flavor_metadata.assert_called_once_with(mock_metadata_path)
469+
mock_get_framework_version_from_requirements.assert_called_once_with(
470+
deployment_flavor, mock_requirements_path
471+
)
472+
mock_cast_to_compatible_version.assert_not_called()
473+
mock_image_uris_retrieve.assert_not_called()
474+
475+
421476
def test_validate_input_for_mlflow():
422477
_validate_input_for_mlflow(ModelServer.TORCHSERVE, "pytorch")
423478

tests/unit/sagemaker/serve/utils/test_telemetry_logger.py

+36
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import unittest
1515
from unittest.mock import Mock, patch
1616
from sagemaker.serve import Mode, ModelServer
17+
from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
1718
from sagemaker.serve.utils.telemetry_logger import (
1819
_send_telemetry,
1920
_capture_telemetry,
@@ -32,9 +33,13 @@
3233
"763104351884.dkr.ecr.us-east-1.amazonaws.com/"
3334
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
3435
)
36+
MOCK_PYTORCH_CONTAINER = (
37+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310"
38+
)
3539
MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf"
3640
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
3741
MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test"
42+
MOCK_MODEL_METADATA_FOR_MLFLOW = {MLFLOW_MODEL_PATH: "s3://some_path"}
3843

3944

4045
class ModelBuilderMock:
@@ -239,3 +244,34 @@ def test_construct_url_with_failure_reason_and_extra_info(self):
239244
f"&x-extra={mock_extra_info}"
240245
)
241246
self.assertEquals(ret_url, expected_base_url)
247+
248+
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
249+
def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
250+
mock_model_builder = ModelBuilderMock()
251+
mock_model_builder.serve_settings.telemetry_opt_out = False
252+
mock_model_builder.image_uri = MOCK_PYTORCH_CONTAINER
253+
mock_model_builder._is_mlflow_model = True
254+
mock_model_builder.model_metadata = MOCK_MODEL_METADATA_FOR_MLFLOW
255+
mock_model_builder._is_custom_image_uri = False
256+
mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
257+
mock_model_builder.model_server = ModelServer.TORCHSERVE
258+
mock_model_builder.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
259+
260+
mock_model_builder.mock_deploy()
261+
262+
args = mock_send_telemetry.call_args.args
263+
latency = str(args[5]).split("latency=")[1]
264+
expected_extra_str = (
265+
f"{MOCK_FUNC_NAME}"
266+
"&x-modelServer=1"
267+
"&x-imageTag=pytorch-inference:2.0.1-cpu-py310"
268+
f"&x-sdkVersion={SDK_VERSION}"
269+
f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
270+
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
271+
f"&x-mlflowModelPathType=2"
272+
f"&x-latency={latency}"
273+
)
274+
275+
mock_send_telemetry.assert_called_once_with(
276+
"1", 3, MOCK_SESSION, None, None, expected_extra_str
277+
)

0 commit comments

Comments
 (0)