Skip to content

Commit 92eb47d

Browse files
authored
Add support for PyTorch (#243)
Add support for PyTorch framework.
1 parent b8f00ff commit 92eb47d

17 files changed

+1609
-5
lines changed

CHANGELOG.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
CHANGELOG
33
=========
44

5-
1.4.3dev
6-
========
5+
1.5.0
6+
=====
7+
* feature: Add Support for PyTorch Framework
78
* feature: Estimators: add support for TensorFlow 1.7.0
89
* feature: Estimators: add support for TensorFlow 1.8.0
910
* feature: Allow Local Serving of Models in S3

README.rst

+15-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ You can install from source by cloning this repository and issuing a pip install
4949

5050
git clone https://github.com/aws/sagemaker-python-sdk.git
5151
python setup.py sdist
52-
pip install dist/sagemaker-1.4.2.tar.gz
52+
pip install dist/sagemaker-1.5.0.tar.gz
5353

5454
Supported Python versions
5555
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -236,6 +236,20 @@ More details at `Chainer SageMaker Estimators and Models`_.
236236
.. _Chainer SageMaker Estimators and Models: src/sagemaker/chainer/README.rst
237237

238238

239+
PyTorch SageMaker Estimators
240+
-------------------------------
241+
242+
With PyTorch Estimators, you can train and host PyTorch models on Amazon SageMaker.
243+
244+
Supported versions of PyTorch: ``0.4.0``
245+
246+
You can visit the PyTorch repository at https://github.com/pytorch/pytorch.
247+
248+
More details at `PyTorch SageMaker Estimators and Models`_.
249+
250+
.. _PyTorch SageMaker Estimators and Models: src/sagemaker/pytorch/README.rst
251+
252+
239253
AWS SageMaker Estimators
240254
------------------------
241255
Amazon SageMaker provides several built-in machine learning algorithms that you can use for a variety of problem types.

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def read(fname):
2323

2424

2525
setup(name="sagemaker",
26-
version="1.4.2",
26+
version="1.5.0",
2727
description="Open source library for training and deploying models on Amazon SageMaker.",
2828
packages=find_packages('src'),
2929
package_dir={'': 'src'},

src/sagemaker/fw_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def framework_name_from_image(image_name):
156156
else:
157157
# extract framework, python version and image tag
158158
# We must support both the legacy and current image name format.
159-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer):(.*?)-(.*?)-(py2|py3)$')
159+
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer|pytorch):(.*?)-(.*?)-(py2|py3)$')
160160
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
161161
name_match = name_pattern.match(sagemaker_match.group(8))
162162
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))

src/sagemaker/pytorch/README.rst

+711
Large diffs are not rendered by default.

src/sagemaker/pytorch/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
from sagemaker.pytorch.estimator import PyTorch
15+
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor
16+
17+
# Export names must be strings; listing the class objects makes ``from sagemaker.pytorch import *`` raise TypeError.
__all__ = ['PyTorch', 'PyTorchModel', 'PyTorchPredictor']

src/sagemaker/pytorch/defaults.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
PYTORCH_VERSION = '0.4'
16+
PYTHON_VERSION = 'py3'

src/sagemaker/pytorch/estimator.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
from sagemaker.estimator import Framework
15+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
16+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
17+
from sagemaker.pytorch.model import PyTorchModel
18+
19+
20+
class PyTorch(Framework):
21+
"""Handle end-to-end training and deployment of custom PyTorch code."""
22+
23+
__framework_name__ = "pytorch"
24+
25+
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
26+
framework_version=PYTORCH_VERSION, **kwargs):
27+
"""
28+
This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment, within a SageMaker
29+
Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions
30+
defined in the supplied ``entry_point`` Python script.
31+
32+
Training is started by calling :meth:`~sagemaker.estimator.Framework.fit` on this Estimator.
33+
After training is complete, calling :meth:`~sagemaker.estimator.Framework.deploy` creates a
34+
hosted SageMaker endpoint and returns a :class:`~sagemaker.pytorch.model.PyTorchPredictor` instance
35+
that can be used to perform inference against the hosted model.
36+
37+
Technical documentation on preparing PyTorch scripts for SageMaker training and using the PyTorch Estimator is
38+
available on the project home-page: https://github.com/aws/sagemaker-python-sdk
39+
40+
Args:
41+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
42+
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
43+
source_dir (str): Path (absolute or relative) to a directory with any other training
44+
source code dependencies aside from the entry point file (default: None). Structure within this
45+
directory is preserved when training on Amazon SageMaker.
46+
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
47+
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
48+
For convenience, this accepts other types for keys and values, but ``str()`` will be called
49+
to convert them before training.
50+
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
51+
One of 'py2' or 'py3'.
52+
framework_version (str): PyTorch version you want to use for executing your model training code.
53+
List of supported versions https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators
54+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
55+
"""
56+
super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
57+
self.py_version = py_version
58+
self.framework_version = framework_version
59+
60+
def train_image(self):
61+
"""Return the Docker image to use for training.
62+
63+
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
64+
find the image to use for model training.
65+
66+
Returns:
67+
str: The URI of the Docker image.
68+
"""
69+
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
70+
self.train_instance_type, framework_version=self.framework_version,
71+
py_version=self.py_version)
72+
73+
def create_model(self, model_server_workers=None):
74+
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
75+
76+
Args:
77+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
78+
If None, server will use one worker per vCPU.
79+
80+
Returns:
81+
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel`` object.
82+
See :class:`~sagemaker.pytorch.model.PyTorchModel` for full details.
83+
"""
84+
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
85+
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
86+
container_log_level=self.container_log_level, code_location=self.code_location,
87+
py_version=self.py_version, framework_version=self.framework_version,
88+
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
89+
90+
@classmethod
91+
def _prepare_init_params_from_job_description(cls, job_details):
92+
"""Convert the job description to init params that can be handled by the class constructor
93+
94+
Args:
95+
job_details: the returned job details from a describe_training_job API call.
96+
97+
Returns:
98+
dictionary: The transformed init_params
99+
100+
"""
101+
init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details)
102+
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
103+
104+
init_params['py_version'] = py_version
105+
init_params['framework_version'] = framework_version_from_tag(tag)
106+
107+
training_job_name = init_params['base_job_name']
108+
109+
if framework != cls.__framework_name__:
110+
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
111+
112+
return init_params

src/sagemaker/pytorch/model.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import sagemaker
15+
from sagemaker.fw_utils import create_image_uri
16+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
17+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
18+
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
19+
from sagemaker.utils import name_from_image
20+
21+
22+
class PyTorchPredictor(RealTimePredictor):
23+
"""A RealTimePredictor for inference against PyTorch Endpoints.
24+
25+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for PyTorch
26+
inference."""
27+
28+
def __init__(self, endpoint_name, sagemaker_session=None):
29+
"""Initialize a ``PyTorchPredictor``.
30+
31+
Args:
32+
endpoint_name (str): The name of the endpoint to perform inference on.
33+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
34+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
35+
using the default AWS configuration chain.
36+
"""
37+
super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)
38+
39+
40+
class PyTorchModel(FrameworkModel):
41+
"""A PyTorch SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
42+
43+
__framework_name__ = 'pytorch'
44+
45+
def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION,
46+
framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor,
47+
model_server_workers=None, **kwargs):
48+
"""Initialize a PyTorchModel.
49+
50+
Args:
51+
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
52+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
53+
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
54+
After the endpoint is created, the inference code might use the IAM role,
55+
if it needs to access an AWS resource.
56+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
57+
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
58+
image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
59+
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
60+
framework_version (str): PyTorch version you want to use for executing your model training code.
61+
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
62+
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
63+
invoking this function on the created endpoint name.
64+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
65+
If None, server will use one worker per vCPU.
66+
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
67+
"""
68+
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
69+
self.py_version = py_version
70+
self.framework_version = framework_version
71+
self.model_server_workers = model_server_workers
72+
73+
def prepare_container_def(self, instance_type):
74+
"""Return a container definition with framework configuration set in model environment variables.
75+
76+
Args:
77+
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
78+
79+
Returns:
80+
dict[str, str]: A container definition object usable with the CreateModel API.
81+
"""
82+
deploy_image = self.image
83+
if not deploy_image:
84+
region_name = self.sagemaker_session.boto_session.region_name
85+
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
86+
self.framework_version, self.py_version)
87+
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
88+
self._upload_code(deploy_key_prefix)
89+
deploy_env = dict(self.env)
90+
deploy_env.update(self._framework_env_vars())
91+
92+
if self.model_server_workers:
93+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
94+
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def mxnet_version(request):
8181
return request.param
8282

8383

84+
@pytest.fixture(scope='module', params=["0.4", "0.4.0"])
85+
def pytorch_version(request):
86+
return request.param
87+
88+
8489
@pytest.fixture(scope='module', params=['4.0', '4.0.0'])
8590
def chainer_version(request):
8691
return request.param
@@ -96,6 +101,11 @@ def mxnet_full_version(request):
96101
return request.param
97102

98103

104+
@pytest.fixture(scope='module', params=["0.4.0"])
105+
def pytorch_full_version(request):
106+
return request.param
107+
108+
99109
@pytest.fixture(scope='module', params=['4.0.0'])
100110
def chainer_full_version(request):
101111
return request.param
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
if __name__ == '__main__':
2+
"""For use with integration tests expecting failures."""
3+
raise Exception('This failure is expected.')

0 commit comments

Comments
 (0)