Modify a PyTorch Training Script
In the SageMaker data parallel library v1.4.0 and later, the library is available as a backend option for the PyTorch distributed package.
Important
Because the SageMaker distributed data parallelism library v1.4.0 and later works as a backend of PyTorch distributed, the following smdistributed APIs are deprecated.
- smdistributed.dataparallel.torch.distributed is deprecated. Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.parallel.DistributedDataParallel is deprecated. Use the torch.nn.parallel.DistributedDataParallel API instead.
If you need to use the previous versions of the library (v1.3.0 or before), see the archived SageMaker distributed data parallel library documentation.
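For reference, the following is a minimal before-and-after sketch of the deprecated imports and their replacements. The deprecated lines are shown as comments for comparison only; the exact import form used in a v1.3.0 script may differ, so treat them as illustrative.
# Deprecated (library v1.3.0 and earlier), shown for comparison only:
#   import smdistributed.dataparallel.torch.distributed as dist
#   from smdistributed.dataparallel.torch.parallel import DistributedDataParallel as DDP

# Replacement (library v1.4.0 and later):
import smdistributed.dataparallel.torch.torch_smddp  # registers the 'smddp' backend
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='smddp')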
Use the SageMaker Distributed Data Parallel Library as the Backend of torch.distributed
To use the SageMaker distributed data parallel library, you only need to import the library's PyTorch client (smdistributed.dataparallel.torch.torch_smddp). The client registers smddp as a backend for PyTorch. When you initialize the PyTorch distributed process group using the torch.distributed.init_process_group API, make sure you specify 'smddp' for the backend argument.
import smdistributed.dataparallel.torch.torch_smddp
import torch.distributed as dist

dist.init_process_group(backend='smddp')
Note
The smddp backend currently does not support creating subprocess groups with the torch.distributed.new_group() API. You cannot use the smddp backend concurrently with other process group backends such as NCCL and Gloo.
If you already have a working PyTorch script and only need to add the backend specification, you can proceed to Using the SageMaker Framework Estimators For PyTorch and TensorFlow in the Step 2: Launch a SageMaker Distributed Training Job Using the SageMaker Python SDK topic.
If you still need to modify your training script to properly use the PyTorch distributed package, follow the rest of the procedures on this page.
Preparing a PyTorch Training Script for Distributed Training
The following steps provide additional tips on how to prepare your training script to successfully run a distributed training job using PyTorch.
Note
In v1.4.0, the SageMaker distributed data parallel library supports the following collective primitives of the torch.distributed interface: all_reduce, broadcast, reduce, all_gather, and barrier.
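As a quick illustration, the following sketch calls two of these collectives through the standard torch.distributed API. It assumes the process group has already been initialized with backend='smddp' and that the process has a GPU pinned to it.
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(backend='smddp') has already run.
t = torch.ones(1).cuda()
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # after the sum, t holds the world size
dist.barrier()  # wait for every worker before continuing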
- Import the PyTorch distributed modules.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
- After parsing arguments and defining a batch size parameter (for example, batch_size=args.batch_size), add two lines of code to resize the batch size per worker (GPU). PyTorch's DataLoader operation does not automatically handle the batch resizing for distributed training.
batch_size //= dist.get_world_size()
batch_size = max(batch_size, 1)
- Pin each GPU to a single SageMaker data parallel library process with local_rank, the relative rank of the process within a given node. You can retrieve the rank of the process from the LOCAL_RANK environment variable.
import os
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
- After defining a model, wrap it with the PyTorch DistributedDataParallel API.
model = ...

# Wrap the model with the PyTorch DistributedDataParallel API
model = DDP(model)
- When you call the torch.utils.data.distributed.DistributedSampler API, specify the total number of processes (GPUs) participating in training across all the nodes in the cluster. This is called world_size, and you can retrieve the number from the torch.distributed.get_world_size() API. Also, specify the rank of each process among all processes using the torch.distributed.get_rank() API. (The sketch after this list shows how the sampler plugs into a DataLoader.)
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
    train_dataset,
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank()
)
- Modify your script to save checkpoints only on the leader process (rank 0). The leader process has a synchronized model, and saving only on the leader also prevents other processes from overwriting and possibly corrupting the checkpoints.
if dist.get_rank() == 0:
    torch.save(...)
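The following is a minimal sketch of how the DistributedSampler from the previous steps typically plugs into a DataLoader. The names train_dataset, train_sampler, batch_size, and num_epochs are placeholders from the steps above, not part of the library's API.
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,   # the per-worker batch size computed earlier
    shuffle=False,           # the sampler handles shuffling and sharding
    sampler=train_sampler,
)

for epoch in range(1, num_epochs + 1):
    # Reset the sampler's seed each epoch so shuffling differs between epochs.
    train_sampler.set_epoch(epoch)
    for data, target in train_loader:
        ...  # forward/backward pass on this worker's shard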
The following example code shows the structure of a PyTorch training script with smddp as the backend.
import os
import torch

# SageMaker data parallel: Import the library PyTorch API
import smdistributed.dataparallel.torch.torch_smddp

# SageMaker data parallel: Import PyTorch's distributed API
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# SageMaker data parallel: Initialize the process group
dist.init_process_group(backend='smddp')

class Net(nn.Module):
    ...
    # Define model

def train(...):
    ...
    # Model training

def test(...):
    ...
    # Model evaluation

def main():

    # SageMaker data parallel: Scale batch size by world size
    batch_size //= dist.get_world_size()
    batch_size = max(batch_size, 1)

    # Prepare dataset
    train_dataset = torchvision.datasets.MNIST(...)

    # SageMaker data parallel: Set num_replicas and rank in DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank())

    train_loader = torch.utils.data.DataLoader(...)

    # SageMaker data parallel: Wrap the PyTorch model with the library's DDP
    model = DDP(Net().to(device))

    # SageMaker data parallel: Pin each GPU to a single library process.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)

    # Train
    optimizer = optim.Adadelta(...)
    scheduler = StepLR(...)
    for epoch in range(1, args.epochs + 1):
        train(...)
        if dist.get_rank() == 0:
            test(...)
        scheduler.step()

    # SageMaker data parallel: Save model on master node.
    if dist.get_rank() == 0:
        torch.save(...)

if __name__ == '__main__':
    main()
After you have completed adapting your training script, proceed to Step 2: Launch a SageMaker Distributed Training Job Using the SageMaker Python SDK.