Skip to content

Commit c0ab61a

Browse files
authored
add chainer, bump to 1.3 (#195)
1 parent 785f3b1 commit c0ab61a

19 files changed

+2048
-28
lines changed

CHANGELOG.rst

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.3.0
6+
=======
7+
8+
* feature: Add chainer
59

610
1.2.5
711
========

README.rst

+677-8
Large diffs are not rendered by default.

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def read(fname):
2323

2424

2525
setup(name="sagemaker",
26-
version="1.2.5",
26+
version="1.3.0",
2727
description="Open source library for training and deploying models on Amazon SageMaker.",
2828
packages=find_packages('src'),
2929
package_dir={'': 'src'},

src/sagemaker/chainer/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.chainer.estimator import Chainer
16+
from sagemaker.chainer.model import ChainerModel, ChainerPredictor
17+
18+
__all__ = [Chainer, ChainerModel, ChainerPredictor]

src/sagemaker/chainer/defaults.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
CHAINER_VERSION = '4.0.0'

src/sagemaker/chainer/estimator.py

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.estimator import Framework
16+
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
17+
from sagemaker.chainer.defaults import CHAINER_VERSION
18+
from sagemaker.chainer.model import ChainerModel
19+
20+
21+
class Chainer(Framework):
22+
"""Handle end-to-end training and deployment of custom Chainer code."""
23+
24+
__framework_name__ = "chainer"
25+
26+
# Hyperparameters
27+
_use_mpi = "sagemaker_use_mpi"
28+
_num_processes = "sagemaker_num_processes"
29+
_process_slots_per_host = "sagemaker_process_slots_per_host"
30+
_additional_mpi_options = "sagemaker_additional_mpi_options"
31+
32+
def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
33+
additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3',
34+
framework_version=CHAINER_VERSION, **kwargs):
35+
"""
36+
This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker
37+
Training Job. The managed Chainer environment is an Amazon-built Docker container that executes functions
38+
defined in the supplied ``entry_point`` Python script.
39+
40+
Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
41+
After training is complete, calling :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
42+
hosted SageMaker endpoint and returns an :class:`~sagemaker.amazon.chainer.model.ChainerPredictor` instance
43+
that can be used to perform inference against the hosted model.
44+
45+
Technical documentation on preparing Chainer scripts for SageMaker training and using the Chainer Estimator is
46+
available on the project home-page: https://github.com/aws/sagemaker-python-sdk
47+
48+
Args:
49+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
50+
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
51+
use_mpi (bool): If true, entry point is run as an MPI script. By default, the Chainer Framework runs
52+
the entry point with 'mpirun' if more than one instance is used.
53+
num_processes (int): Total number of processes to run the entry point with. By default, the Chainer
54+
Framework runs one process per GPU (on GPU instances), or one process per host (on CPU instances).
55+
process_slots_per_host (int): The number of processes that can run on each instance. By default, this is
56+
set to the number of GPUs on the instance (on GPU instances), or one (on CPU instances).
57+
additional_mpi_options (str): String of options to the 'mpirun' command used to run the entry point.
58+
For example, '-X NCCL_DEBUG=WARN' will pass that option string to the mpirun command.
59+
source_dir (str): Path (absolute or relative) to a directory with any other training
60+
source code dependencies aside from tne entry point file (default: None). Structure within this
61+
directory are preserved when training on Amazon SageMaker.
62+
hyperparameters (dict): Hyperparameters that will be used for training (default: None).
63+
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
64+
For convenience, this accepts other types for keys and values, but ``str()`` will be called
65+
to convert them before training.
66+
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
67+
One of 'py2' or 'py3'.
68+
framework_version (str): Chainer version you want to use for executing your model training code.
69+
List of supported versions https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators
70+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
71+
"""
72+
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters, **kwargs)
73+
self.py_version = py_version
74+
self.framework_version = framework_version
75+
self.use_mpi = use_mpi
76+
self.num_processes = num_processes
77+
self.process_slots_per_host = process_slots_per_host
78+
self.additional_mpi_options = additional_mpi_options
79+
80+
def hyperparameters(self):
81+
"""Return hyperparameters used by your custom Chainer code during training."""
82+
hyperparameters = super(Chainer, self).hyperparameters()
83+
84+
additional_hyperparameters = {Chainer._use_mpi: self.use_mpi,
85+
Chainer._num_processes: self.num_processes,
86+
Chainer._process_slots_per_host: self.process_slots_per_host,
87+
Chainer._additional_mpi_options: self.additional_mpi_options}
88+
89+
# remove unset keys.
90+
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
91+
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
92+
return hyperparameters
93+
94+
def train_image(self):
95+
"""Return the Docker image to use for training.
96+
97+
The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to
98+
find the image to use for model training.
99+
100+
Returns:
101+
str: The URI of the Docker image.
102+
"""
103+
104+
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
105+
self.train_instance_type, framework_version=self.framework_version,
106+
py_version=self.py_version)
107+
108+
def create_model(self, model_server_workers=None):
109+
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``.
110+
111+
Args:
112+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
113+
If None, server will use one worker per vCPU.
114+
115+
Returns:
116+
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel`` object.
117+
See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
118+
"""
119+
return ChainerModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
120+
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
121+
container_log_level=self.container_log_level, code_location=self.code_location,
122+
py_version=self.py_version, framework_version=self.framework_version,
123+
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
124+
125+
@classmethod
126+
def _prepare_init_params_from_job_description(cls, job_details):
127+
"""Convert the job description to init params that can be handled by the class constructor
128+
129+
Args:
130+
job_details: the returned job details from a describe_training_job API call.
131+
132+
Returns:
133+
dictionary: The transformed init_params
134+
135+
"""
136+
init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details)
137+
138+
for argument in [Chainer._use_mpi, Chainer._num_processes, Chainer._process_slots_per_host,
139+
Chainer._additional_mpi_options]:
140+
141+
value = init_params['hyperparameters'].pop(argument, None)
142+
if value:
143+
init_params[argument[len('sagemaker_'):]] = value
144+
145+
framework, py_version, tag = framework_name_from_image(init_params.pop('image'))
146+
147+
init_params['py_version'] = py_version
148+
init_params['framework_version'] = framework_version_from_tag(tag)
149+
150+
training_job_name = init_params['base_job_name']
151+
152+
if framework != cls.__framework_name__:
153+
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
154+
return init_params

src/sagemaker/chainer/model.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sagemaker
16+
from sagemaker.fw_utils import create_image_uri
17+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
18+
from sagemaker.chainer.defaults import CHAINER_VERSION
19+
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
20+
from sagemaker.utils import name_from_image
21+
22+
23+
class ChainerPredictor(RealTimePredictor):
24+
"""A RealTimePredictor for inference against Chainer Endpoints.
25+
26+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for Chainer
27+
inference."""
28+
29+
def __init__(self, endpoint_name, sagemaker_session=None):
30+
"""Initialize an ``ChainerPredictor``.
31+
32+
Args:
33+
endpoint_name (str): The name of the endpoint to perform inference on.
34+
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
35+
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
36+
using the default AWS configuration chain.
37+
"""
38+
super(ChainerPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)
39+
40+
41+
class ChainerModel(FrameworkModel):
42+
"""An Chainer SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
43+
44+
__framework_name__ = 'chainer'
45+
46+
def __init__(self, model_data, role, entry_point, image=None, py_version='py3', framework_version=CHAINER_VERSION,
47+
predictor_cls=ChainerPredictor, model_server_workers=None, **kwargs):
48+
"""Initialize an ChainerModel.
49+
50+
Args:
51+
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
52+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
53+
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
54+
After the endpoint is created, the inference code might use the IAM role,
55+
if it needs to access an AWS resource.
56+
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
57+
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
58+
image (str): A Docker image URI (default: None). If not specified, a default image for Chainer will be used.
59+
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
60+
framework_version (str): Chainer version you want to use for executing your model training code.
61+
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
62+
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
63+
invoking this function on the created endpoint name.
64+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
65+
If None, server will use one worker per vCPU.
66+
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
67+
"""
68+
super(ChainerModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
69+
**kwargs)
70+
self.py_version = py_version
71+
self.framework_version = framework_version
72+
self.model_server_workers = model_server_workers
73+
74+
def prepare_container_def(self, instance_type):
75+
"""Return a container definition with framework configuration set in model environment variables.
76+
77+
Args:
78+
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
79+
80+
Returns:
81+
dict[str, str]: A container definition object usable with the CreateModel API.
82+
"""
83+
deploy_image = self.image
84+
if not deploy_image:
85+
region_name = self.sagemaker_session.boto_session.region_name
86+
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
87+
self.framework_version, self.py_version)
88+
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
89+
self._upload_code(deploy_key_prefix)
90+
deploy_env = dict(self.env)
91+
deploy_env.update(self._framework_env_vars())
92+
93+
if self.model_server_workers:
94+
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
95+
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

src/sagemaker/content_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
CONTENT_TYPE_JSON = 'application/json'
1616
CONTENT_TYPE_CSV = 'text/csv'
1717
CONTENT_TYPE_OCTET_STREAM = 'application/octet-stream'
18+
CONTENT_TYPE_NPY = 'application/x-npy'

src/sagemaker/fw_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def framework_name_from_image(image_name):
156156
else:
157157
# extract framework, python version and image tag
158158
# We must support both the legacy and current image name format.
159-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet):(.*?)-(.*?)-(py2|py3)$')
159+
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet|chainer):(.*?)-(.*?)-(py2|py3)$')
160160
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
161161
name_match = name_pattern.match(sagemaker_match.group(8))
162162
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))

0 commit comments

Comments
 (0)