@@ -203,7 +203,7 @@ def default_bucket(self):
203
203
return self ._default_bucket
204
204
205
205
def train (self , image , input_mode , input_config , role , job_name , output_config ,
206
- resource_config , hyperparameters , stop_condition ):
206
+ resource_config , hyperparameters , stop_condition , tags ):
207
207
"""Create an Amazon SageMaker training job.
208
208
209
209
Args:
@@ -232,6 +232,8 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
232
232
keys and values, but ``str()`` will be called to convert them before training.
233
233
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
234
234
service like ``MaxRuntimeInSeconds``.
235
+ tags (list[dict]): List of tags for labeling a training job. For more, see
236
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
235
237
236
238
Returns:
237
239
str: ARN of the training job, if it is created.
@@ -242,7 +244,6 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
242
244
'TrainingImage' : image ,
243
245
'TrainingInputMode' : input_mode
244
246
},
245
- # 'HyperParameters': hyperparameters,
246
247
'InputDataConfig' : input_config ,
247
248
'OutputDataConfig' : output_config ,
248
249
'TrainingJobName' : job_name ,
@@ -253,6 +254,10 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
253
254
254
255
if hyperparameters and len (hyperparameters ) > 0 :
255
256
train_request ['HyperParameters' ] = hyperparameters
257
+
258
+ if tags is not None :
259
+ train_request ['Tags' ] = tags
260
+
256
261
LOGGER .info ('Creating training-job with name: {}' .format (job_name ))
257
262
LOGGER .debug ('train request: {}' .format (json .dumps (train_request , indent = 4 )))
258
263
self .sagemaker_client .create_training_job (** train_request )
0 commit comments