Activation offloading
Important
In SMP v2.2.0, the activation offloading functionality of the SMP library doesn't work. Use the native PyTorch activation offloading instead.
Typically, the forward pass computes activations at each layer and keeps them in GPU memory until the backward pass for the corresponding layer finishes. Offloading these tensors to CPU memory after the forward pass and fetching them back to the GPU when they are needed can save substantial GPU memory. PyTorch supports offloading activations, but its implementation leaves GPUs idle while activations are fetched back from the CPU during the backward pass, which causes a major performance degradation when using activation offloading.
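For reference, one way to use this native PyTorch offloading is the torch.autograd.graph.save_on_cpu saved-tensor hook, which moves activations saved for backward to CPU memory during the forward pass and copies them back when the backward pass needs them. The following is a minimal sketch of that native approach (not of the SMP feature); the model, input shape, and pin_memory choice are placeholders for illustration.

import torch
import torch.nn as nn

# Placeholder model and batch for illustration.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
batch = torch.randn(8, 1024, device="cuda")

# Native PyTorch activation offloading: tensors saved for backward are kept
# in (pinned) CPU memory and copied back to the GPU when the backward pass
# needs them. Without pre-fetching, the GPU waits on these copies.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(batch).sum()
loss.backward()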
SMP v2 improves on this activation offloading. It pre-fetches activations ahead of time, before the GPU needs them to start the backward pass on those activations. This pre-fetching keeps training running efficiently without idle GPUs, so you get the benefit of lower memory usage without a performance degradation.
You can keep the native PyTorch modules for offloading activations in your training script. The following is an example structure for applying the SMP activation offloading feature in your script. Note that activation offloading is applicable only when used together with Activation checkpointing. To learn more about the native PyTorch checkpoint tools for activation offloading, see the following:
- checkpoint_wrapper.py in the PyTorch GitHub repository
- Activation Checkpointing in the PyTorch blog Scaling Multi-modal Foundation Models in TorchMultimodal with PyTorch Distributed
You can apply the SMP activation offloading feature on top of PyTorch activation checkpointing by adding the sm_activation_offloading and activation_loading_horizon parameters to the SMP configuration dictionary during Step 2: Launch a training job.
The following code snippets show how to add the SMP initialization module torch.sagemaker.init() to your training script and how to set up the SMP configuration dictionary in JSON format for the training job launcher, following the two-step process introduced in Get started with the SageMaker model parallelism library v2. You don't need to make any changes to your PyTorch model or PyTorch FSDP configuration. For more information about the sm_activation_offloading and activation_loading_horizon parameters, see SMP v2 core feature configuration parameters.
SMP configuration
{ "activation_loading_horizon": 2, "sm_activation_offloading": True }
In training script
Note
When you activate the SMP activation offloading feature, make sure that you also
use the PyTorch offload_wrapper function and apply it to the root
module. The SMP activation offloading feature uses the root module to determine when
the forward pass is done so that it can start pre-fetching.
import torch.sagemaker as tsm
tsm.init()

# Native PyTorch module for activation offloading
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    offload_wrapper,
)

model = FSDP(...)

# Activation offloading requires activation checkpointing.
apply_activation_checkpointing(
    model, check_fn=checkpoint_transformer_layers_policy,
)
model = offload_wrapper(model)
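The check_fn argument above, checkpoint_transformer_layers_policy, is referenced but not defined in this snippet. A minimal, hypothetical definition might look like the following, assuming your model exposes a transformer block class (here called TransformerLayer as a placeholder):

# Hypothetical checkpointing policy: checkpoint (and therefore offload)
# only the transformer block submodules. Replace TransformerLayer with
# the actual block class used by your model.
def checkpoint_transformer_layers_policy(module):
    return isinstance(module, TransformerLayer)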