
Commit 854bafc

beniericpintaoz-aws authored and committed
Add model trainer documentation (#1639)
1 parent efc9aa8 commit 854bafc

6 files changed: +118 −49 lines


doc/api/training/index.rst (+1)

@@ -5,6 +5,7 @@ Training APIs
 .. toctree::
     :maxdepth: 4
 
+    model_trainer
     algorithm
     analytics
     automl

doc/api/training/model_trainer.rst (+17)

@@ -0,0 +1,17 @@
+ModelTrainer
+------------
+
+.. autoclass:: sagemaker.modules.train.model_trainer.ModelTrainer
+    :members:
+
+Configs
+~~~~~~~
+
+.. automodule:: sagemaker.modules.configs
+    :members:
+
+Distributed
+~~~~~~~~~~~
+
+.. automodule:: sagemaker.modules.distributed
+    :members:

doc/overview.rst (+43, −2)

@@ -4,6 +4,7 @@ Using the SageMaker Python SDK
 
 SageMaker Python SDK provides several high-level abstractions for working with Amazon SageMaker. These are:
 
+- **ModelTrainer**: New interface encapsulating training on SageMaker.
 - **Estimators**: Encapsulate training on SageMaker.
 - **Models**: Encapsulate built ML models.
 - **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker endpoint.

@@ -24,8 +25,8 @@ Train a Model with the SageMaker Python SDK
 To train a model by using the SageMaker Python SDK, you:
 
 1. Prepare a training script
-2. Create an estimator
-3. Call the ``fit`` method of the estimator
+2. Create a ModelTrainer or Estimator
+3. Call the ``train`` method of the ModelTrainer or the ``fit`` method of the Estimator
 
 After you train a model, you can save it, and then serve the model as an endpoint to get real-time inferences or get inferences for an entire dataset by using batch transform.
 

@@ -85,6 +86,46 @@ If you want to use, for example, boolean hyperparameters, you need to specify ``
 For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_.
 
 
+Using ModelTrainer
+==================
+
+To use the ModelTrainer class, you need to provide a few essential parameters such as the training image URI and the source code configuration. The class allows you to spin up a SageMaker training job with minimal parameters, particularly by specifying the source code and training image.
+
+For more information about class definitions see `ModelTrainer <https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html>`_.
+
+Example: Launching a Training Job with Custom Script
+
+.. code:: python
+
+    from sagemaker.modules.train import ModelTrainer
+    from sagemaker.modules.configs import SourceCode, InputData
+
+    # Image URI for the training job
+    pytorch_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
+
+    # Define the script to be run
+    source_code = SourceCode(
+        source_dir="basic-script-mode",
+        requirements="requirements.txt",
+        entry_script="custom_script.py",
+    )
+
+    # Define the ModelTrainer
+    model_trainer = ModelTrainer(
+        training_image=pytorch_image,
+        source_code=source_code,
+        base_job_name="script-mode",
+    )
+
+    # Pass the input data
+    input_data = InputData(
+        channel_name="train",
+        data_source=training_input_path,  # S3 path where training data is stored
+    )
+
+    # Start the training job
+    model_trainer.train(input_data_config=[input_data], wait=False)
+
 Using Estimators
 ================
 
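Editor's note: as a companion to the documentation example added above, the sketch below (not part of this commit) shows the same job with a ``Compute`` config attached and with ``wait=True`` so the call blocks until the job finishes. The ``instance_count`` and ``volume_size_in_gb`` fields are assumed to come from ``sagemaker_core.shapes.ResourceConfig`` and do not appear in this diff; the bucket path is a placeholder.

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.configs import SourceCode, Compute, InputData

    # Same image and source layout as the documentation example above
    pytorch_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
    source_code = SourceCode(
        source_dir="basic-script-mode",
        requirements="requirements.txt",
        entry_script="custom_script.py",
    )

    # Compute is documented in src/sagemaker/modules/configs.py below;
    # instance_count and volume_size_in_gb are assumed ResourceConfig fields.
    compute = Compute(
        instance_type="ml.m5.xlarge",
        instance_count=1,
        volume_size_in_gb=30,
    )

    model_trainer = ModelTrainer(
        training_image=pytorch_image,
        source_code=source_code,
        compute=compute,
        base_job_name="script-mode",
    )

    # Block until the training job completes (the doc example uses wait=False)
    model_trainer.train(
        input_data_config=[InputData(channel_name="train", data_source="s3://my-bucket/train/")],
        wait=True,
    )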
src/sagemaker/modules/configs.py (+17, −17)

@@ -10,12 +10,12 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""This module provides the configuration classes used in `sagemaker.modules`.
+"""This module provides the configuration classes used in ``sagemaker.modules``.
 
-Some of these classes are re-exported from `sagemaker_core.shapes`. For convinence,
-users can import these classes directly from `sagemaker.modules.configs`.
+Some of these classes are re-exported from ``sagemaker_core.shapes``. For convinence,
+users can import these classes directly from ``sagemaker.modules.configs``.
 
-For more documentation on `sagemaker_core.shapes`, see:
+For more documentation on ``sagemaker_core.shapes``, see:
   - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes
 """
 

@@ -80,14 +80,14 @@ class SourceCode(BaseModel):
     The SourceCode class allows the user to specify the source code location, dependencies,
     entry script, or commands to be executed in the training job container.
 
-    Attributes:
+    Parameters:
         source_dir (Optional[str]):
             The local directory containing the source code to be used in the training job container.
         requirements (Optional[str]):
-            The path within `source_dir` to a `requirements.txt` file. If specified, the listed
+            The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
             requirements will be installed in the training job container.
         entry_script (Optional[str]):
-            The path within `source_dir` to the entry script that will be executed in the training
+            The path within ``source_dir`` to the entry script that will be executed in the training
             job container. If not specified, command must be provided.
         command (Optional[str]):
             The command(s) to execute in the training job container. Example: "python my_script.py".

@@ -103,10 +103,10 @@ class SourceCode(BaseModel):
 class Compute(shapes.ResourceConfig):
     """Compute.
 
-    The Compute class is a subclass of `sagemaker_core.shapes.ResourceConfig`
+    The Compute class is a subclass of ``sagemaker_core.shapes.ResourceConfig``
     and allows the user to specify the compute resources for the training job.
 
-    Attributes:
+    Parameters:
         instance_type (Optional[str]):
             The ML compute instance type. For information about available instance types,
             see https://aws.amazon.com/sagemaker/pricing/.

@@ -152,10 +152,10 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
 class Networking(shapes.VpcConfig):
     """Networking.
 
-    The Networking class is a subclass of `sagemaker_core.shapes.VpcConfig ` and
+    The Networking class is a subclass of ``sagemaker_core.shapes.VpcConfig`` and
     allows the user to specify the networking configuration for the training job.
 
-    Attributes:
+    Parameters:
         security_group_ids (Optional[List[str]]):
             The VPC security group IDs, in the form sg-xxxxxxxx. Specify the
             security groups for the VPC that is specified in the Subnets field.

@@ -199,15 +199,15 @@ class InputData(BaseModel):
 
     This config allows the user to specify an input data source for the training job.
 
-    Will be found at `/opt/ml/input/data/<channel_name>` within the training container.
+    Will be found at ``/opt/ml/input/data/<channel_name>`` within the training container.
     For convience, can be referenced inside the training container like:
 
-    ```python
-    import os
-    input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
-    ```
-
-    Attributes:
+    .. code:: python
+
+        import os
+        input_data_dir = os.environ['SM_CHANNEL_<channel_name>']
+
+    Parameters:
         channel_name (str):
             The name of the input data source channel.
         data_source (Union[str, S3DataSource, FileSystemDataSource]):
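
Editor's note: the docstrings in this file describe small, declarative config objects. A minimal usage sketch follows (not part of the commit); the ``subnets`` field on ``Networking`` and the extra ``Compute`` fields are assumed from the ``sagemaker_core.shapes`` base classes, and all resource IDs and S3 paths are placeholders.

    from sagemaker.modules.configs import SourceCode, Compute, Networking, InputData

    # Source code layout: source_dir is copied into the training container
    source_code = SourceCode(
        source_dir="my-training-code",
        requirements="requirements.txt",  # installed before the entry script runs
        entry_script="train.py",
    )

    # Compute extends sagemaker_core.shapes.ResourceConfig; fields beyond
    # instance_type are assumed from that base class.
    compute = Compute(instance_type="ml.g5.xlarge", instance_count=1)

    # Networking extends sagemaker_core.shapes.VpcConfig; placeholder IDs shown.
    networking = Networking(
        security_group_ids=["sg-0123456789abcdef0"],
        subnets=["subnet-0123456789abcdef0"],
    )

    # Mounted at /opt/ml/input/data/train and reachable via SM_CHANNEL_TRAIN
    train_channel = InputData(channel_name="train", data_source="s3://my-bucket/train/")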

src/sagemaker/modules/distributed.py (+6, −6)

@@ -25,7 +25,7 @@ class SMP(BaseModel):
     For more information on the model parallelism parameters, see:
     https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config
 
-    Attributes:
+    Parameters:
         hybrid_shard_degree (Optional[int]):
             Specifies a sharded parallelism degree for the model.
         sm_activation_offloading (Optional[bool]):

@@ -87,10 +87,10 @@ def model_dump(self, *args, **kwargs):
 class Torchrun(DistributedConfig):
     """Torchrun.
 
-    The Torchrun class configures a job that uses `torchrun` or
-    `torch.distributed.launch` in the backend to launch distributed training.
+    The Torchrun class configures a job that uses ``torchrun`` or
+    ``torch.distributed.launch`` in the backend to launch distributed training.
 
-    Attributes:
+    Parameters:
         process_count_per_node (int):
             The number of processes to run on each node in the training job.
             Will default to the number of GPUs available in the container.

@@ -107,10 +107,10 @@ class Torchrun(DistributedConfig):
 class MPI(DistributedConfig):
     """MPI.
 
-    The MPI class configures a job that uses `mpirun` in the backend to launch
+    The MPI class configures a job that uses ``mpirun`` in the backend to launch
    distributed training.
 
-    Attributes:
+    Parameters:
         process_count_per_node (int):
             The number of processes to run on each node in the training job.
             Will default to the number of GPUs available in the container.
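
Editor's note: a hedged sketch (not from this commit) of how one of these runners is wired into a training job. Per the ModelTrainer docstring in the next file, ``source_code`` must be provided whenever ``distributed`` is set; the GPU image URI and the ``instance_count`` field are placeholders/assumptions.

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.configs import SourceCode, Compute
    from sagemaker.modules.distributed import Torchrun

    model_trainer = ModelTrainer(
        # Placeholder GPU training image
        training_image="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-gpu-py310",
        source_code=SourceCode(source_dir="ddp-code", entry_script="train_ddp.py"),
        # process_count_per_node defaults to the number of GPUs in the container
        distributed=Torchrun(),
        compute=Compute(instance_type="ml.g5.12xlarge", instance_count=2),  # instance_count assumed
        base_job_name="torchrun-ddp",
    )

    model_trainer.train()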

src/sagemaker/modules/train/model_trainer.py (+34, −24)

@@ -115,25 +115,31 @@ class ModelTrainer(BaseModel):
     """Class that trains a model using AWS SageMaker.
 
     Example:
-    ```python
-    from sagemaker.modules.train import ModelTrainer
-    from sagemaker.modules.configs import SourceCode, Compute, InputData
-
-    source_code = SourceCode(source_dir="source", entry_script="train.py")
-    training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
-    model_trainer = ModelTrainer(
-        training_image=training_image,
-        source_code=source_code,
-    )
-
-    train_data = InputData(channel_name="train", data_source="s3://bucket/train")
-    model_trainer.train(input_data_config=[train_data])
-    ```
-
-    Attributes:
+
+    .. code:: python
+
+        from sagemaker.modules.train import ModelTrainer
+        from sagemaker.modules.configs import SourceCode, Compute, InputData
+
+        source_code = SourceCode(source_dir="source", entry_script="train.py")
+        training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
+        model_trainer = ModelTrainer(
+            training_image=training_image,
+            source_code=source_code,
+        )
+
+        train_data = InputData(channel_name="train", data_source="s3://bucket/train")
+        model_trainer.train(input_data_config=[train_data])
+
+        training_job = model_trainer._latest_training_job
+
+    Parameters:
+        training_mode (Mode):
+            The training mode. Valid values are "Mode.LOCAL_CONTAINER" or
+            "Mode.SAGEMAKER_TRAINING_JOB".
         sagemaker_session (Optiona(Session)):
             The SageMakerCore session. For convinience, can be imported like:
-            `from sagemaker.modules import Session`.
+            ``from sagemaker.modules import Session``.
             If not specified, a new session will be created.
             If the default bucket for the artifacts needs to be updated, it can be done by
             passing it in the Session object.

@@ -149,7 +155,7 @@ class ModelTrainer(BaseModel):
             running the training job.
         distributed (Optional[Union[MPI, Torchrun]]):
             The distributed runner for the training job. This is used to configure
-            a distributed training job. If specifed, `source_code` must also
+            a distributed training job. If specifed, ``source_code`` must also
             be provided.
         compute (Optional[Compute]):
             The compute configuration. This is used to specify the compute resources for

@@ -176,7 +182,7 @@ class ModelTrainer(BaseModel):
             The output data configuration. This is used to specify the output data location
             for the training job.
             If not specified in the session, will default to
-            `s3://<default_bucket>/<default_prefix>/<base_job_name>/`.
+            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
         input_data_config (Optional[List[Union[Channel, InputData]]]):
             The input data config for the training job.
             Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI

@@ -194,6 +200,9 @@ class ModelTrainer(BaseModel):
         tags (Optional[List[Tag]]):
             An array of key-value pairs. You can use tags to categorize your AWS resources
             in different ways, for example, by purpose, owner, or environment.
+        local_container_root (Optional[str]):
+            The local root directory to store artifacts from a training job launched in
+            "LOCAL_CONTAINER" mode.
     """
 
     model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

@@ -652,9 +661,10 @@ def create_input_data_channel(
             key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
                 Only applicable when data_source is a local file path string.
                 If not specified, local data will be uploaded to:
-                s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/
+                ``s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/``
+
                 If specified, local data will be uploaded to:
-                s3://<default_bucket_path>/<key_prefix>/<channel_name>/
+                ``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
         """
         channel = None
         if isinstance(data_source, str):

@@ -881,7 +891,7 @@ def from_recipe(
             output_data_config (Optional[OutputDataConfig]):
                 The output data configuration. This is used to specify the output data location
                 for the training job.
-                If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
+                If not specified, will default to ``s3://<default_bucket>/<base_job_name>/output/``.
            input_data_config (Optional[List[Union[Channel, InputData]]]):
                 The input data config for the training job.
                 Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI

@@ -910,7 +920,7 @@ def from_recipe(
         """
         if compute.instance_type is None:
             raise ValueError(
-                "Must set `instance_type` in compute_config when using training recipes."
+                "Must set ``instance_type`` in compute_config when using training recipes."
             )
         device_type = _determine_device_type(compute.instance_type)
         if device_type == "cpu":

@@ -970,7 +980,7 @@ def with_tensorboard_output_config(
         """Set the TensorBoard output configuration.
 
         Args:
-            tensorboard_output_config (TensorBoardOutputConfig):
+            tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
                 The TensorBoard output configuration.
         """
         self._tensorboard_output_config = tensorboard_output_config
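
Editor's note: to round out the docstring changes above, here is a hedged sketch of the new ``training_mode`` and ``local_container_root`` parameters in use. It is not from this commit: the import location of the ``Mode`` enum is assumed to be ``sagemaker.modules.train.model_trainer``, and the image URI and paths are placeholders.

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.train.model_trainer import Mode  # assumed location of the Mode enum
    from sagemaker.modules.configs import SourceCode

    model_trainer = ModelTrainer(
        training_image="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310",
        source_code=SourceCode(source_dir="basic-script-mode", entry_script="custom_script.py"),
        training_mode=Mode.LOCAL_CONTAINER,        # run in a local container instead of a SageMaker job
        local_container_root="./local-artifacts",  # where local-mode artifacts are written
    )

    model_trainer.train(wait=True)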
