Set up managed tier checkpointing
This section describes the setup process for managed tier checkpointing for Amazon SageMaker HyperPod. You'll learn how to enable the capability on your cluster and implement checkpointing in your training code.
Prerequisites
Before setting up managed tier checkpointing, ensure you have:
- An Amazon EKS HyperPod cluster with sufficient CPU memory available for checkpoint allocation
- PyTorch training workloads and DCP jobs (both are supported)
- Appropriate IAM permissions for cluster management, including:
  - Amazon CloudWatch and Amazon S3 write permissions for the training pod to read/write checkpoints and push metrics
  - These permissions can be configured via EKS OIDC setup
Step 1: Enable managed tier checkpointing for your cluster
Important
You must opt in to use managed tier checkpointing.
Enable managed tier checkpointing through the HyperPod API when creating or updating your cluster. The service automatically installs the memory management system when you specify the TieredStorageConfig parameter. For example, to enable it on an existing cluster:
aws sagemaker update-cluster \
    --cluster-name my-training-cluster \
    --tiered-storage-config '{
        "Mode": "Enable",
        "InstanceMemoryAllocationPercentage": percentage
    }'
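To confirm the setting took effect, you can describe the cluster and inspect the tiered storage configuration in the response (the cluster name is the illustrative one from above):

aws sagemaker describe-cluster --cluster-name my-training-cluster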
The InstanceMemoryAllocationPercentage parameter specifies the percentage (an integer) of cluster memory to allocate for checkpointing. The valid range is 20-100.
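As a rough illustration, setting the value to 20 on instances with 1,024 GiB of memory would reserve about 205 GiB per instance for the in-memory checkpoint tier.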
Step 2: Install the Python library in your training image
Install the Amazon SageMaker checkpointing library:

# Add this line to your training image Dockerfile
RUN pip install amzn-sagemaker-checkpointing
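To sanity-check the build, you can confirm the package is present in the resulting image (the image name here is a hypothetical placeholder):

docker run --rm my-training-image:latest pip show amzn-sagemaker-checkpointing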
Step 3: Create a checkpoint configuration
Create a SageMakerCheckpointConfig object to specify checkpoint behavior. This includes:
- Checkpoint locations
- Checkpoint frequency
- Namespace names
The following example shows a checkpoint configuration:
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import SageMakerCheckpointConfig
from amzn_sagemaker_checkpointing.checkpointing.filesystem import SageMakerTieredStorageWriter, SageMakerTieredStorageReader

checkpoint_config = SageMakerCheckpointConfig(
    world_size=100,
    in_memory_namespace="my-ml-workload",  # Logical grouping for checkpoints
    s3_base_path="s3://bucket-name/checkpointing-path-prefix/",
    s3_every_n_steps=100,  # Every 100 steps, save to S3
)
Step 4: Define a SageMaker file system writer
Define your checkpointing file system writer. You can optionally specify a step number during initialization.
Basic writer (step specified in save call):
smWriter = SageMakerTieredStorageWriter(checkpoint_config)
Writer with step parameter (step specified at initialization):
smWriter = SageMakerTieredStorageWriter(
    checkpoint_config,
    step=step_number
)
Note
When you specify the step parameter during writer initialization, the checkpoint_id parameter in the save call becomes optional. The step parameter takes precedence over the checkpoint directory format.
Step 5: Save checkpoints in your training loop
In your training loop, save checkpoints using PyTorch DCP with FileSystemWriter.
Use PyTorch DCP with FileSystemWriter
Call the dist_cp.save_state_dict() method with the SageMaker file system writer as input:
Option 1: Using checkpoint_id with step format (when step not specified in writer)
# Construct checkpoint directory with step number
checkpoint_dir = f"step_{step_number}"

dist_cp.save_state_dict(
    state_dict=state_dict,  # Dictionary containing model parameters, optimizer state, etc.
    checkpoint_id=checkpoint_dir,  # Should contain the step number
    storage_writer=smWriter
)
Option 2: Using writer with step parameter (checkpoint_id becomes optional)
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=smWriter  # Step already specified in writer initialization
)
Note
The checkpoint_id value (or checkpoint_dir string) must have the format step_number, for example step_5. When you use the step parameter in writer initialization, checkpoint_id becomes optional.
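Putting the pieces together, the following is a minimal sketch of a save wired into a training loop. The model, optimizer, train_step function, and step counts are hypothetical placeholders, not part of the library:

import torch.distributed.checkpoint as dist_cp

total_steps = 1000       # Hypothetical training length
save_every_n_steps = 50  # Hypothetical in-memory save interval

for step in range(1, total_steps + 1):
    train_step(model, optimizer)  # Hypothetical per-step training logic

    if step % save_every_n_steps == 0:
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        # Re-create the writer each save so the current step is recorded (Option 2 style)
        smWriter = SageMakerTieredStorageWriter(checkpoint_config, step=step)
        dist_cp.save_state_dict(state_dict=state_dict, storage_writer=smWriter)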
Step 6: Load checkpoints for recovery
When you need to load a checkpoint, use PyTorch DCP with FileSystemReader.
Use PyTorch DCP with FileSystemReader
Call the DCP load method with FileSystemReader as input:
# Define the reader
smReader = SageMakerTieredStorageReader(config=checkpoint_config)

# Load checkpoint
dist_cp.load_state_dict(
    state_dict=state_dict,
    checkpoint_id=checkpoint_dir,
    storage_reader=smReader
)
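Because DCP restores in place, the state_dict you pass to the load call must already contain the structures to restore into. A minimal resume sketch, assuming model and optimizer already exist and resume_step is the step you want to load (both hypothetical):

# Pre-populate the state_dict with the structures to restore into
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}

smReader = SageMakerTieredStorageReader(config=checkpoint_config)
dist_cp.load_state_dict(
    state_dict=state_dict,
    checkpoint_id=f"step_{resume_step}",  # Hypothetical step to resume from
    storage_reader=smReader
)

# Apply the restored values back to the live objects
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])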
Monitoring and validation
You can monitor and validate your managed tier checkpointing operations through metrics and logs.
Custom logging (optional)
You can integrate checkpointing logs with other logs by passing a custom logger to the library. For example, you can add a custom logger to your training code so that all logs from the library are also collected in the training logger.
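The exact integration point depends on your library version; the following is a minimal sketch that assumes the config accepts a standard Python logger (the logger parameter is an assumption, not a confirmed API; check the library documentation for your installed version):

import logging

# Reuse your existing training logger so library logs land in the same stream
training_logger = logging.getLogger("my_training_job")  # Hypothetical logger name

checkpoint_config = SageMakerCheckpointConfig(
    world_size=100,
    in_memory_namespace="my-ml-workload",
    s3_base_path="s3://bucket-name/checkpointing-path-prefix/",
    s3_every_n_steps=100,
    logger=training_logger,  # ASSUMPTION: logger keyword on the config
)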
Enhanced service logging (optional)
For enhanced debugging and service visibility, you can mount the checkpointing log path /var/log/sagemaker_checkpointing from within your pod to the path /var/logs/sagemaker_checkpointing on your host. This keeps library-specific logs collected separately and gives the service team enhanced visibility for debugging and support.
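In a Kubernetes pod spec, this mount might look like the following fragment (volume and container names are illustrative):

# Hypothetical pod spec fragment (under spec:)
volumes:
  - name: checkpoint-logs
    hostPath:
      path: /var/logs/sagemaker_checkpointing    # Path on the host
      type: DirectoryOrCreate
containers:
  - name: training
    volumeMounts:
      - name: checkpoint-logs
        mountPath: /var/log/sagemaker_checkpointing  # Log path inside the pod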