Activation offloading - Amazon SageMaker

Activation offloading

Important

In SMP v2.2.0, the activation offloading functionality of the SMP library doesn't work. Use the native PyTorch activation offloading instead.

Typically, the forward pass computes activations at each layer and keeps them in GPU memory until the backward pass for the corresponding layer finishes. Offloading these tensors to CPU memory after the forward pass, and fetching them back to the GPU when they are needed, can substantially reduce GPU memory usage. PyTorch supports offloading activations, but its implementation leaves GPUs idle while activations are fetched back from CPU memory during the backward pass. This causes significant performance degradation when you use activation offloading.
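As a minimal sketch of the native mechanism (not SMP's implementation), PyTorch's saved-tensor hooks can move each activation that autograd saves to CPU memory during the forward pass and copy it back when the backward pass asks for it. The pack_to_cpu and unpack_from_cpu names are our own; to our understanding, the built-in torch.autograd.graph.save_on_cpu context manager, which offload_wrapper builds on, follows the same pattern. Because each copy back happens only at the moment backward needs the tensor, compute waits on it, which is the idle time described above.

```python
import torch

# Counters that show when saved activations leave and return to the device.
stats = {"packed": 0, "unpacked": 0}


def pack_to_cpu(tensor):
    # Runs during the forward pass for each tensor autograd saves for backward.
    stats["packed"] += 1
    return tensor.device, tensor.to("cpu")


def unpack_from_cpu(packed):
    # Runs during the backward pass. The copy back to the original device
    # happens right when backward needs the tensor, so compute waits on it.
    stats["unpacked"] += 1
    device, cpu_tensor = packed
    return cpu_tensor.to(device)


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
).to(device)
x = torch.randn(4, 8, device=device)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()
loss.backward()  # unpack_from_cpu fires here, even outside the context manager

offloaded_grads = [p.grad.clone() for p in model.parameters()]
print(stats)
```

The gradients are the same as in a run without the hooks; only where the saved tensors live between the two passes changes.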

SMP v2 improves activation offloading by pre-fetching activations before the GPU needs them to start the backward pass on the corresponding layers. Pre-fetching keeps GPUs busy during training, so you get the benefit of lower memory usage without the performance degradation.
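The benefit of pre-fetching can be illustrated with a toy cost model (hypothetical, not SMP's implementation): each layer's backward step needs its activations copied back first, and copies overlap with compute only if a copy is allowed to start ahead of time. The function backward_time, its unit costs, and the prefetch_depth limit on how many layers' activations may sit in GPU memory at once are all assumptions for illustration; we do not claim prefetch_depth has the exact semantics of activation_loading_horizon.

```python
def backward_time(num_layers, fetch_cost, compute_cost, prefetch_depth=0):
    """Toy model of a backward pass over layers with offloaded activations.

    prefetch_depth=0: each layer's activations are fetched only when its
    backward step starts, so copy and compute never overlap.
    prefetch_depth=k>0: copies run alongside compute, but at most k layers'
    activations may be resident on the GPU at once (hypothetical limit).
    """
    if num_layers == 0:
        return 0.0
    if prefetch_depth == 0:
        return num_layers * (fetch_cost + compute_cost)

    fetch_end = [0.0] * num_layers
    compute_end = [0.0] * num_layers
    for i in range(num_layers):
        # Copies are serialized on one channel.
        start = fetch_end[i - 1] if i > 0 else 0.0
        if i >= prefetch_depth:
            # Wait until an earlier layer's backward frees a buffer slot.
            start = max(start, compute_end[i - prefetch_depth])
        fetch_end[i] = start + fetch_cost
        # Compute needs this layer's data and the previous backward step.
        compute_start = max(fetch_end[i], compute_end[i - 1] if i > 0 else 0.0)
        compute_end[i] = compute_start + compute_cost
    return compute_end[-1]


print(backward_time(4, 1.0, 1.0))                    # 8.0
print(backward_time(4, 1.0, 1.0, prefetch_depth=1))  # 8.0
print(backward_time(4, 1.0, 1.0, prefetch_depth=2))  # 5.0
```

With equal fetch and compute costs, four layers take 8 time units without pre-fetching, while allowing two layers to be resident hides every copy except the first (5 units). A depth of 1 gives no overlap, because the next copy cannot start until the current layer's buffer is freed.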

You can keep the native PyTorch modules for offloading activations in your training script and apply the SMP activation offloading feature on top of them. Note that activation offloading is applicable only when used together with activation checkpointing. To learn more about the native PyTorch checkpoint tools for activation offloading, see the checkpoint_wrapper module (torch.distributed.algorithms._checkpoint.checkpoint_wrapper) in the PyTorch source code.

You can apply the SMP activation offloading feature on top of PyTorch activation checkpointing by adding the sm_activation_offloading and activation_loading_horizon parameters to the SMP configuration dictionary during Step 2: Launch a training job.

The following code snippets show how to add the SMP initialization function torch.sagemaker.init() to your training script and how to set up the SMP configuration dictionary for the training job launcher, following the two-step process introduced in Get started with the SageMaker model parallelism library v2. You don't need to make any changes to your PyTorch model or PyTorch FSDP configuration. For more information about the sm_activation_offloading and activation_loading_horizon parameters, see SMP v2 core feature configuration parameters.

SMP configuration

```python
{
    "activation_loading_horizon": 2,
    "sm_activation_offloading": True
}
```
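For context, Step 2 passes this dictionary to the PyTorch estimator of the SageMaker Python SDK through its distribution argument. The sketch below assumes the nesting used for SMP v2 (torch_distributed enabled, and the SMP parameters under smdistributed, modelparallel, parameters) and only builds that structure. build_distribution is a hypothetical helper, and the estimator call is left as a comment because it needs AWS credentials and the other estimator arguments.

```python
def build_distribution(smp_parameters: dict) -> dict:
    """Hypothetical helper: nest SMP parameters into the estimator's
    `distribution` argument (structure assumed from the SMP v2 launch step)."""
    return {
        "torch_distributed": {"enabled": True},
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": smp_parameters,
            }
        },
    }


smp_config = {
    "activation_loading_horizon": 2,
    "sm_activation_offloading": True,
}
distribution = build_distribution(smp_config)

# With the SageMaker Python SDK installed and configured, for example:
# from sagemaker.pytorch import PyTorch
# estimator = PyTorch(..., distribution=distribution)
print(distribution)
```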

In training script

Note

When you enable the SMP activation offloading feature, make sure that you also apply the PyTorch offload_wrapper function to the root module. The SMP activation offloading feature uses the root module to determine when the forward pass is done so that it can start pre-fetching.
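The note does not describe how SMP detects the end of the forward pass internally. As a general illustration (not SMP's code), a forward hook registered on the root module fires only after every submodule's forward has returned, which is why the root module is a natural marker for the end of the forward pass. You don't need to add hooks like these to your training script; record is a hypothetical helper.

```python
import torch
import torch.nn as nn

events = []


def record(name):
    # Hypothetical helper: returns a forward hook that logs the module name.
    # Forward hooks run right after the module's forward() returns.
    def hook(module, inputs, output):
        events.append(name)
    return hook


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
for idx, child in enumerate(model):
    child.register_forward_hook(record(f"layer{idx}"))
model.register_forward_hook(record("root"))

model(torch.randn(3, 4))
print(events)  # ['layer0', 'layer1', 'layer2', 'root']
```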

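```python
import torch.sagemaker as tsm
tsm.init()

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Native PyTorch module for activation offloading
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    offload_wrapper,
)

model = FSDP(...)

# Activation offloading requires activation checkpointing.
# checkpoint_transformer_layers_policy is not defined in this snippet: supply
# your own function that returns True for each submodule to checkpoint.
apply_activation_checkpointing(
    model,
    check_fn=checkpoint_transformer_layers_policy,
)

# Wrap the root module so that SMP can detect the end of the forward pass.
model = offload_wrapper(model)
```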
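The training script passes check_fn=checkpoint_transformer_layers_policy, which the snippet does not define; it is presumably a function you write yourself. apply_activation_checkpointing calls check_fn on each submodule and wraps the ones for which it returns True. A minimal sketch, assuming your transformer blocks are torch.nn.TransformerEncoderLayer instances (substitute your model's own block class):

```python
import torch.nn as nn


def checkpoint_transformer_layers_policy(submodule: nn.Module) -> bool:
    # Hypothetical policy: checkpoint every transformer encoder layer and
    # leave all other submodules (attention, linear, norm, ...) unwrapped.
    return isinstance(submodule, nn.TransformerEncoderLayer)


# A small stand-in model with four transformer layers.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=16, nhead=2, batch_first=True),
    num_layers=4,
)

# The submodules that apply_activation_checkpointing would wrap.
selected = [m for m in model.modules() if checkpoint_transformer_layers_policy(m)]
print(len(selected))  # 4
```

Checkpointing at the granularity of whole transformer layers keeps the number of wrapped regions small, and offload_wrapper then applies to the root module as in the script above.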