Skip to content

Commit 731641c

Browse files
laurenyuChoiByungWook
authored andcommitted
Add tags for training jobs (#209)
1 parent 7f13df0 commit 731641c

File tree

8 files changed

+59
-11
lines changed

8 files changed

+59
-11
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* bug-fix: Remove __all__ and add noqa in __init__
99
* bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans
1010
* bug-fix: Estimators: Remove unused argument job_details for ``EstimatorBase.attach()``
11+
* feature: Estimators: add support for tagging training jobs
1112

1213
1.3.0
1314
=====

src/sagemaker/estimator.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4444

4545
def __init__(self, role, train_instance_count, train_instance_type,
4646
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
47-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
47+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
4848
"""Initialize an ``EstimatorBase`` instance.
4949
5050
Args:
@@ -73,13 +73,16 @@ def __init__(self, role, train_instance_count, train_instance_type,
7373
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
7474
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
7575
using the default AWS configuration chain.
76+
tags (list[dict]): List of tags for labeling a training job. For more, see
77+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
7678
"""
7779
self.role = role
7880
self.train_instance_count = train_instance_count
7981
self.train_instance_type = train_instance_type
8082
self.train_volume_size = train_volume_size
8183
self.train_max_run = train_max_run
8284
self.input_mode = input_mode
85+
self.tags = tags
8386

8487
if self.train_instance_type in ('local', 'local_gpu'):
8588
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
@@ -345,7 +348,8 @@ def start_new(cls, estimator, inputs):
345348
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
346349
input_config=input_config, role=role, job_name=estimator._current_job_name,
347350
output_config=output_config, resource_config=resource_config,
348-
hyperparameters=hyperparameters, stop_condition=stop_condition)
351+
hyperparameters=hyperparameters, stop_condition=stop_condition,
352+
tags=estimator.tags)
349353

350354
return cls(estimator.sagemaker_session, estimator._current_job_name)
351355

src/sagemaker/session.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def default_bucket(self):
203203
return self._default_bucket
204204

205205
def train(self, image, input_mode, input_config, role, job_name, output_config,
206-
resource_config, hyperparameters, stop_condition):
206+
resource_config, hyperparameters, stop_condition, tags):
207207
"""Create an Amazon SageMaker training job.
208208
209209
Args:
@@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
232232
keys and values, but ``str()`` will be called to convert them before training.
233233
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
234234
service like ``MaxRuntimeInSeconds``.
235+
tags (list[dict]): List of tags for labeling a training job. For more, see
236+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
235237
236238
Returns:
237239
str: ARN of the training job, if it is created.
@@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
242244
'TrainingImage': image,
243245
'TrainingInputMode': input_mode
244246
},
245-
# 'HyperParameters': hyperparameters,
246247
'InputDataConfig': input_config,
247248
'OutputDataConfig': output_config,
248249
'TrainingJobName': job_name,
@@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
253254

254255
if hyperparameters and len(hyperparameters) > 0:
255256
train_request['HyperParameters'] = hyperparameters
257+
258+
if tags is not None:
259+
train_request['Tags'] = tags
260+
256261
LOGGER.info('Creating training-job with name: {}'.format(job_name))
257262
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
258263
self.sagemaker_client.create_training_job(**train_request)

tests/unit/test_chainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def _create_train_job(version):
120120
},
121121
'stop_condition': {
122122
'MaxRuntimeInSeconds': 24 * 60 * 60
123-
}
123+
},
124+
'tags': None,
124125
}
125126

126127

tests/unit/test_estimator.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,18 @@ def test_fit_then_fit_again(sagemaker_session):
292292

293293
@patch('time.strftime', return_value=TIMESTAMP)
294294
def test_fit_verify_job_name(strftime, sagemaker_session):
295+
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
295296
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
296297
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
297-
enable_cloudwatch_metrics=True)
298+
enable_cloudwatch_metrics=True, tags=tags)
298299
fw.fit(inputs=s3_input('s3://mybucket/train'))
299300

300301
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
301302

302303
assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics']
303304
assert train_kwargs['image'] == IMAGE_NAME
304305
assert train_kwargs['input_mode'] == 'File'
306+
assert train_kwargs['tags'] == tags
305307
assert train_kwargs['job_name'] == JOB_NAME
306308
assert fw.latest_training_job.name == JOB_NAME
307309

@@ -475,7 +477,8 @@ def test_unsupported_type_in_dict():
475477
'InstanceType': INSTANCE_TYPE,
476478
'VolumeSizeInGB': 30
477479
},
478-
'stop_condition': {'MaxRuntimeInSeconds': 86400}
480+
'stop_condition': {'MaxRuntimeInSeconds': 86400},
481+
'tags': None,
479482
}
480483

481484

tests/unit/test_mxnet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def _create_train_job(version):
9393
},
9494
'stop_condition': {
9595
'MaxRuntimeInSeconds': 24 * 60 * 60
96-
}
96+
},
97+
'tags': None,
9798
}
9899

99100

tests/unit/test_session.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def test_s3_input_all_arguments():
142142
JOB_NAME = 'jobname'
143143

144144
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
145-
# 'HyperParameters': None,
146145
'OutputDataConfig': {
147146
'S3OutputPath': S3_OUTPUT
148147
},
@@ -224,12 +223,45 @@ def test_train_pack_to_request(sagemaker_session):
224223

225224
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
226225
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
227-
hyperparameters=None, stop_condition=stop_cond)
226+
hyperparameters=None, stop_condition=stop_cond, tags=None)
228227

229228
assert sagemaker_session.sagemaker_client.method_calls[0] == (
230229
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
231230

232231

232+
def test_train_pack_to_request_with_optional_params(sagemaker_session):
233+
in_config = [{
234+
'ChannelName': 'training',
235+
'DataSource': {
236+
'S3DataSource': {
237+
'S3DataDistributionType': 'FullyReplicated',
238+
'S3DataType': 'S3Prefix',
239+
'S3Uri': S3_INPUT_URI
240+
}
241+
}
242+
}]
243+
244+
out_config = {'S3OutputPath': S3_OUTPUT}
245+
246+
resource_config = {'InstanceCount': INSTANCE_COUNT,
247+
'InstanceType': INSTANCE_TYPE,
248+
'VolumeSizeInGB': MAX_SIZE}
249+
250+
stop_cond = {'MaxRuntimeInSeconds': MAX_TIME}
251+
252+
hyperparameters = {'foo': 'bar'}
253+
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
254+
255+
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
256+
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
257+
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags)
258+
259+
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
260+
261+
assert actual_train_args['HyperParameters'] == hyperparameters
262+
assert actual_train_args['Tags'] == tags
263+
264+
233265
@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
234266
def test_color_wrap(bio):
235267
color_wrap = sagemaker.logs.ColorWrap()

tests/unit/test_tf_estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def _create_train_job(tf_version):
101101
},
102102
'stop_condition': {
103103
'MaxRuntimeInSeconds': 24 * 60 * 60
104-
}
104+
},
105+
'tags': None,
105106
}
106107

107108

0 commit comments

Comments
 (0)