19
19
from sagemaker.amazon import validation
20
20
from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
21
21
from sagemaker.amazon.common import write_numpy_to_dense_tensor
22
- from sagemaker.estimator import EstimatorBase
22
+ from sagemaker.estimator import EstimatorBase, _TrainingJob
23
23
from sagemaker.session import s3_input
24
24
from sagemaker.utils import sagemaker_timestamp
25
25
@@ -92,11 +92,38 @@ def _prepare_init_params_from_job_description(cls, job_details):
92
92
del init_params ['image' ]
93
93
return init_params
94
94
95
- def fit(self, records, mini_batch_size=None, **kwargs):
95
+ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
96
+ """Set hyperparameters needed for training.
97
+
98
+ Args:
99
+ * records (:class:`~RecordSet`): The records to train this ``Estimator`` on.
100
+ * mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
101
+ default value will be used.
102
+ * job_name (str): Name of the training job to be created. If not specified, one is generated,
103
+ using the base name given to the constructor if applicable.
104
+ """
105
+ super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
106
+
107
+ feature_dim = None
108
+
109
+ if isinstance (records , list ):
110
+ for record in records :
111
+ if record .channel == 'train' :
112
+ feature_dim = record .feature_dim
113
+ break
114
+ if feature_dim is None :
115
+ raise ValueError ('Must provide train channel.' )
116
+ else :
117
+ feature_dim = records .feature_dim
118
+
119
+ self .feature_dim = feature_dim
120
+ self .mini_batch_size = mini_batch_size
121
+
122
+ def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None):
96
123
"""Fit this Estimator on serialized Record objects, stored in S3.
97
124
98
125
``records`` should be an instance of :class:`~RecordSet`. This defines a collection of
99
- s3 data files to train this ``Estimator`` on.
126
+ S3 data files to train this ``Estimator`` on.
100
127
101
128
Training data is expected to be encoded as dense or sparse vectors in the "values" feature
102
129
on each Record. If the data is labeled, the label is expected to be encoded as a list of
@@ -110,15 +137,19 @@ def fit(self, records, mini_batch_size=None, **kwargs):
110
137
111
138
Args:
112
139
records (:class:`~RecordSet`): The records to train this ``Estimator`` on
113
- mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
140
+ mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
114
141
default value will be used.
142
+ wait (bool): Whether the call should wait until the job completes (default: True).
143
+ logs (bool): Whether to show the logs produced by the job.
144
+ Only meaningful when wait is True (default: True).
145
+ job_name (str): Training job name. If not specified, the estimator generates a default job name,
146
+ based on the training image name and current timestamp.
115
147
"""
116
- self .feature_dim = records .feature_dim
117
- self .mini_batch_size = mini_batch_size
148
+ self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
118
149
119
- data = {records.channel: s3_input(records.s3_data, distribution='ShardedByS3Key',
- s3_data_type=records.s3_data_type)}
121
- super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs)
150
+ self.latest_training_job = _TrainingJob.start_new(self, records)
151
+ if wait:
+     self.latest_training_job.wait(logs=logs)
122
153
123
154
def record_set (self , train , labels = None , channel = "train" ):
124
155
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
@@ -180,6 +211,14 @@ def __repr__(self):
180
211
"""Return an unambiguous representation of this RecordSet"""
181
212
return str ((RecordSet , self .__dict__ ))
182
213
214
+ def data_channel (self ):
215
+ """Return a dictionary to represent the training data in a channel for use with ``fit()``"""
216
+ return {self .channel : self .records_s3_input ()}
217
+
218
+ def records_s3_input (self ):
219
+ """Return a s3_input to represent the training data"""
220
+ return s3_input(self.s3_data, distribution='ShardedByS3Key', s3_data_type=self.s3_data_type)
221
+
183
222
184
223
def _build_shards (num_shards , array ):
185
224
if num_shards < 1 :
0 commit comments