Modify a PyTorch Training Script - Amazon SageMaker

Modify a PyTorch Training Script

In the SageMaker data parallel library v1.4.0 and later, the library is available as a backend option for the PyTorch distributed package. You only need to import the library once at the top of your training script and set it as the PyTorch distributed backend during initialization. With the single line of backend specification, you can keep your PyTorch training script unchanged and directly use the PyTorch distributed modules. To find the latest API documentation for the library, see the SageMaker distributed data parallel APIs for PyTorch in the SageMaker Python SDK documentation. To learn more about the PyTorch distributed package and backend options, see Distributed communication package - torch.distributed.


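Because the SageMaker distributed data parallelism library v1.4.0 and later works as a backend of PyTorch distributed, the following smdistributed APIs for the PyTorch distributed package are deprecated.

  - smdistributed.dataparallel.torch.distributed is deprecated. Use the torch.distributed package instead.
  - smdistributed.dataparallel.torch.parallel.DistributedDataParallel is deprecated. Use the torch.nn.parallel.DistributedDataParallel API instead.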

If you need to use the previous versions of the library (v1.3.0 or before), see the archived SageMaker distributed data parallel library documentation in the SageMaker Python SDK documentation.

Use the SageMaker Distributed Data Parallel Library as the Backend of torch.distributed

To use the SageMaker distributed data parallel library, import the library's PyTorch client (smdistributed.dataparallel.torch.torch_smddp), which registers smddp as a backend for PyTorch. Then, when you initialize the PyTorch distributed process group with the torch.distributed.init_process_group API, specify 'smddp' as the backend argument.

import smdistributed.dataparallel.torch.torch_smddp
import torch.distributed as dist

dist.init_process_group(backend='smddp')

The smddp backend currently does not support creating process subgroups with the torch.distributed.new_group() API. You also cannot use the smddp backend concurrently with other process group backends, such as NCCL and Gloo.
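If you also run the same script outside SageMaker (for example, a quick local test where the library isn't installed), one option is to choose the backend at startup. The following is a minimal sketch under that assumption: the helper name select_backend and the fallback to NCCL or Gloo are illustrative choices, not part of the library. Because smddp cannot run concurrently with other backends, the helper returns exactly one backend name to pass to init_process_group.

```python
import importlib


def select_backend(cuda_available: bool) -> str:
    """Hypothetical helper: prefer smddp, otherwise fall back to one stock backend.

    The fallback to NCCL (GPU) or Gloo (CPU) is an assumption for local runs;
    it is not behavior provided by the SageMaker library.
    """
    try:
        # Importing the client registers 'smddp' as a torch.distributed backend.
        importlib.import_module("smdistributed.dataparallel.torch.torch_smddp")
        return "smddp"
    except ImportError:
        # Library not installed: pick exactly one stock backend.
        return "nccl" if cuda_available else "gloo"


# Usage in a training script (requires PyTorch):
# import torch
# import torch.distributed as dist
# dist.init_process_group(backend=select_backend(torch.cuda.is_available()))
```

Passing cuda_available as an argument keeps the helper free of a hard PyTorch dependency, so the selection logic can be checked on its own.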

If you already have a working PyTorch script and only need to add the backend specification, you can proceed to Using the SageMaker Framework Estimators For PyTorch and TensorFlow in the Step 2: Launch a SageMaker Distributed Training Job Using the SageMaker Python SDK topic.

If you still need to modify your training script to properly use the PyTorch distributed package, follow the rest of the procedures on this page.

Preparing a PyTorch Training Script for Distributed Training

The following steps provide additional tips on how to prepare your training script to successfully run a distributed training job using PyTorch.


In v1.4.0, the SageMaker distributed data parallel library supports the following collective communication primitives of the torch.distributed interface: all_reduce, broadcast, reduce, all_gather, and barrier.
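As an illustration of one of these primitives, the following sketch averages a scalar (such as a validation loss) across all workers with all_reduce. The helper name average_across_workers is hypothetical. With smddp, the tensor is assumed to live on the process's GPU; the device is a parameter so that the same function can also be tried on CPU with the Gloo backend.

```python
import torch
import torch.distributed as dist


def average_across_workers(value: float, device: torch.device) -> float:
    """Hypothetical helper: return the mean of `value` over all processes in the group."""
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    # all_reduce with SUM adds the tensors from every process, in place.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Divide the sum by the number of processes to get the mean.
    return tensor.item() / dist.get_world_size()
```

In a training script using smddp, you might call it as average_across_workers(loss.item(), torch.device("cuda", local_rank)) after the process group is initialized. Every process must make the call, because a collective blocks until all ranks participate.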

  1. Import the PyTorch distributed modules.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
  2. After parsing arguments and defining a batch size parameter (for example, batch_size=args.batch_size), add two lines of code to resize the batch size per worker (GPU). PyTorch's DataLoader operation does not automatically handle the batch resizing for distributed training.

    batch_size //= dist.get_world_size()
    batch_size = max(batch_size, 1)
  3. Pin each GPU to a single SageMaker data parallel library process with local_rank, which is the relative rank of the process within a given node.

    You can retrieve the rank of the process from the LOCAL_RANK environment variable.

    import os

    # LOCAL_RANK is a string environment variable; convert it to an int device index.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
  4. After defining a model, wrap it with the PyTorch DistributedDataParallel API.

    model = ...

    # Wrap the model with the PyTorch DistributedDataParallel API
    model = DDP(model)
  5. When you create a DistributedSampler for your training dataset, specify the total number of processes (GPUs) participating in training across all nodes in the cluster. This is called world_size, and you can retrieve it with the torch.distributed.get_world_size() API. Also specify the rank of each process among all processes, which you can retrieve with the torch.distributed.get_rank() API.

    from torch.utils.data.distributed import DistributedSampler

    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank()
    )
  6. Modify your script to save checkpoints only on the leader process (rank 0). The leader process has a synchronized model, and saving only from it prevents the other processes from overwriting, and possibly corrupting, the checkpoints.

    if dist.get_rank() == 0:
        torch.save(...)  # Save the checkpoint only on the leader process
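The batch-size resizing in step 2 is plain integer arithmetic, so it is worth checking what it does to the effective global batch size. The following sketch uses hypothetical helper names and an example world size of 16 (for instance, 2 nodes with 8 GPUs each) and assumes no gradient accumulation.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Same arithmetic as step 2: floor-divide, but keep at least one sample."""
    return max(global_batch_size // world_size, 1)


def effective_global_batch_size(global_batch_size: int, world_size: int) -> int:
    """Samples processed per optimizer step across all workers."""
    return per_worker_batch_size(global_batch_size, world_size) * world_size


if __name__ == "__main__":
    for global_bs in (64, 100, 8):
        per_worker = per_worker_batch_size(global_bs, world_size=16)
        effective = effective_global_batch_size(global_bs, world_size=16)
        print(f"global={global_bs} per_worker={per_worker} effective={effective}")
```

When the configured batch size is not a multiple of the world size, the floor division drops the remainder (100 becomes 96 in total), and when it is smaller than the world size, the max(..., 1) floor raises it (8 becomes 16), so the effective global batch size can differ from the one you configured.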

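To see why step 5 passes both num_replicas and rank, the following pure-Python sketch mimics how DistributedSampler splits indices when shuffle=False and drop_last=False: it pads the index list cyclically from the start until its length divides evenly, then gives each rank every num_replicas-th index starting at its own rank. This is a simplified re-creation for illustration (the function name shard_indices is hypothetical), not the PyTorch source; with shuffle=True, PyTorch permutes the indices first.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Simplified re-creation of DistributedSampler's split (shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Pad cyclically with indices from the start so every rank gets num_samples items.
    indices += [indices[i % dataset_len] for i in range(total_size - dataset_len)]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(4):
        print(r, shard_indices(10, num_replicas=4, rank=r))
```

With 10 samples and 4 replicas, each rank gets 3 indices and samples 0 and 1 are seen twice per epoch because of the padding. If every process used the same rank, they would all train on the same shard.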
The following example code shows the structure of a PyTorch training script with smddp as the backend.

import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# SageMaker data parallel: Import the library's PyTorch client (registers the smddp backend)
import smdistributed.dataparallel.torch.torch_smddp

# SageMaker data parallel: Import PyTorch's distributed API
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# SageMaker data parallel: Initialize the process group
dist.init_process_group(backend='smddp')


class Net(nn.Module):
    ...  # Define model


def train(*args, **kwargs):
    ...  # Model training


def test(*args, **kwargs):
    ...  # Model evaluation


def main():
    args = ...  # Parse command-line arguments

    # SageMaker data parallel: Scale batch size by world size
    batch_size = args.batch_size
    batch_size //= dist.get_world_size()
    batch_size = max(batch_size, 1)

    # Prepare dataset
    train_dataset = torchvision.datasets.MNIST(...)

    # SageMaker data parallel: Set num_replicas and rank in DistributedSampler
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
    )
    # The sampler already shuffles and shards, so don't also pass shuffle=True
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

    # SageMaker data parallel: Pin each GPU to a single library process
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # SageMaker data parallel: Move the model to this process's GPU, then wrap it with PyTorch DDP
    model = DDP(Net().cuda(local_rank), device_ids=[local_rank])

    # Train
    optimizer = optim.Adadelta(...)
    scheduler = StepLR(...)
    for epoch in range(1, args.epochs + 1):
        train(...)
        if dist.get_rank() == 0:
            test(...)
        scheduler.step()

    # SageMaker data parallel: Save the model only on the leader process (rank 0)
    if dist.get_rank() == 0:
        torch.save(...)


if __name__ == '__main__':
    main()
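Because DistributedDataParallel wraps your model, the wrapper's state_dict keys carry a "module." prefix. The following sketch (with a hypothetical helper name, save_checkpoint) saves the underlying model's weights only on rank 0, as in step 6, so the checkpoint can later be loaded into an unwrapped model. It then calls barrier, one of the supported primitives, so that no rank moves on before the file is written. It can be tried on CPU with the Gloo backend and a world size of 1; in a SageMaker job, you would keep the smddp process group and GPU placement from the script above.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def save_checkpoint(model: DDP, path: str) -> None:
    """Hypothetical helper: write the unwrapped model's weights from rank 0 only."""
    if dist.get_rank() == 0:
        # model.module is the original nn.Module, so its keys have no "module." prefix.
        torch.save(model.module.state_dict(), path)
    # Every rank waits here until the leader has finished writing.
    dist.barrier()
```

To restore, load the file into a fresh, unwrapped model (for example, net.load_state_dict(torch.load(path))) before wrapping it with DDP again.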

After you have completed adapting your training script, proceed to Step 2: Launch a SageMaker Distributed Training Job Using the SageMaker Python SDK.