Checkpointing and Fine-Tuning a Model with Model Parallelism
The SageMaker AI model parallelism library provides checkpointing APIs that save the model state and the optimizer state split according to the model parallelism strategies in use, and that load checkpoints so you can resume training or fine-tune from where you left off. The APIs also support saving the model and optimizer states partially or fully.
Checkpointing a distributed model
Choose one of the following topics depending on the framework you use (PyTorch or TensorFlow) and the version of the SageMaker AI model parallelism library.
Topics
- Checkpointing a distributed PyTorch model (for the SageMaker AI model parallelism library v1.10.0 and later)
- Checkpointing a distributed PyTorch model (for the SageMaker AI model parallelism library between v1.6.0 and v1.9.0)
- Checkpointing a distributed TensorFlow model
Checkpointing a distributed PyTorch model (for the SageMaker AI model parallelism library v1.10.0 and later)
The SageMaker AI model parallelism library provides checkpoint APIs to save and load full or partial checkpoints of the distributed model state and its optimizer state.
Note
This checkpointing method is recommended if you use PyTorch and the SageMaker AI model parallelism library v1.10.0 or later.
Partial checkpointing
To save checkpoints of a model trained with model parallelism, use the smdistributed.modelparallel.torch.save_checkpoint API with the partial checkpointing option (partial=True). This saves each model partition individually. In addition to the model and the optimizer state, you can also save any additional custom data through the user_content argument. The checkpointed model, optimizer, and user content are saved as separate files. The save_checkpoint API call creates checkpoint folders in the following structure.
- path
  - ${tag}_partial (folder for partial checkpoints)
    - model_rankinfo.pt
    - optimizer_rankinfo.pt
    - fp16_states_rankinfo.pt
    - user_content.pt
  - $tag (checkpoint file for full checkpoints)
  - user_content_$tag (user_content file for full checkpoints)
  - newest (a file that indicates the newest checkpoint)
To resume training from partial checkpoints, use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=True, and specify the checkpoint directory and the tag used while saving the partial checkpoints. Note that the actual loading of model weights happens after model partitioning, during the first run of the smdistributed.modelparallel.torch.step-decorated training step function.
When saving a partial checkpoint, the library also saves the model partition decision as files with the .pt file extension. Conversely, when resuming from the partial checkpoint, the library loads the partition decision files together. Once the partition decision is loaded, you can't change the partition.
The following code snippet shows how to set the checkpoint APIs in a PyTorch training script.
import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)

optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)

user_content = ...  # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=True,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    partial=True
)
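As noted above, when you resume from a partial checkpoint the model weights are only loaded during the first run of the smdistributed.modelparallel.torch.step-decorated training step. The following is a minimal sketch of such a training step and loop; the train_step function, the train_loader data loader, and the loss computation are illustrative placeholders, not part of the checkpoint API.

import torch.nn.functional as F
import smdistributed.modelparallel.torch as smp

# Hypothetical smp.step-decorated training step (placeholder names).
@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = F.nll_loss(output, target, reduction="mean")
    # With the library, call model.backward() instead of loss.backward().
    model.backward(loss)
    return loss

for data, target in train_loader:  # hypothetical data loader
    optimizer.zero_grad()
    # The first call after smp.resume_from_checkpoint(partial=True)
    # partitions the model and loads the checkpointed weights.
    loss_mb = train_step(model, data, target)
    optimizer.step()
    loss = loss_mb.reduce_mean()  # average the StepOutput across microbatches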
Full checkpointing
To save the final model artifact for inference purposes, use the smdistributed.modelparallel.torch.save_checkpoint API with partial=False, which combines the model partitions to create a single model artifact. Note that this does not combine the optimizer states.
To initialize training with particular weights, given a full model checkpoint, you can use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=False. Note that this does not load optimizer states.
Note
With tensor parallelism, in general, the state_dict must be translated between the original model implementation and the DistributedModel implementation. Optionally, you can provide the state_dict translation function as an argument to smdistributed.modelparallel.torch.resume_from_checkpoint. However, for Supported Models Out of the Box, the library takes care of this translation automatically.
The following code shows an example of how to use the checkpoint APIs for fully checkpointing a PyTorch model trained with model parallelism.
import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)

optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)

user_content = ...  # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=False,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    partial=False
)
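A common pattern is to save partial checkpoints periodically during training and one full checkpoint at the end for inference, using the two modes shown above. The following sketch assumes a hypothetical training loop, a save_freq interval, and a total_steps counter; only the smp.save_checkpoint calls come from the API described in this section.

save_freq = 100  # hypothetical checkpoint interval, in steps

for total_steps, batch in enumerate(train_loader):  # hypothetical training loop
    ...  # run the smp.step-decorated training step and the optimizer update

    if total_steps % save_freq == 0:
        # Partial checkpoints are written per partition and can be resumed from.
        smp.save_checkpoint(
            path=checkpoint_path,
            tag=f"total_steps{total_steps}",
            partial=True,
            model=model,
            optimizer=optimizer,
            num_kept_partial_checkpoints=5
        )

# After training, combine the model partitions into a single artifact for inference.
# The optimizer states are not combined in a full checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag="fullmodel.pt",
    partial=False,
    model=model,
    optimizer=optimizer
)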
Checkpointing a distributed PyTorch model (for the SageMaker AI model parallelism library between v1.6.0 and v1.9.0)
The SageMaker AI model parallelism library provides Python functions for saving partial or full checkpoints for training jobs with tensor parallelism. The following procedure shows how to use smp.save() and smp.load(). A consolidated sketch that ties these steps together follows the procedure.
Note
This checkpointing method is recommended if you use PyTorch, Tensor Parallelism, and the SageMaker AI model parallelism library between v1.6.0 and v1.9.0.
- Prepare a model object and wrap it with the library's wrapper function smp.DistributedModel().

  model = MyModel(...)
  model = smp.DistributedModel(model)
- Prepare an optimizer for the model. A set of model parameters is an iterable argument required by optimizer functions. To prepare a set of model parameters, you must process model.parameters() to assign unique IDs to individual model parameters.

  If there are parameters with duplicated IDs in the model parameter iterable, loading the checkpointed optimizer state fails. To create an iterable of model parameters with unique IDs for your optimizer, see the following:

  unique_params = []
  unique_params_set = set()
  for p in model.parameters():
      if p not in unique_params_set:
          unique_params.append(p)
          unique_params_set.add(p)
  del unique_params_set

  optimizer = MyOpt(unique_params, ...)
- Wrap the optimizer using the library's wrapper function smp.DistributedOptimizer().

  optimizer = smp.DistributedOptimizer(optimizer)
- Save the model and the optimizer state using smp.save(). Depending on how you want to save checkpoints, choose one of the following two options:

  - Option 1: Save a partial model on each mp_rank for a single MP_GROUP.

    model_dict = model.local_state_dict()  # save a partial model
    opt_dict = optimizer.local_state_dict()  # save a partial optimizer state

    # Save the dictionaries at rdp_rank 0 as a checkpoint
    if smp.rdp_rank() == 0:
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            f"/checkpoint.pt",
            partial=True,
        )
    With tensor parallelism, the library saves checkpointed files named in the following format: checkpoint.pt_{pp_rank}_{tp_rank}.

    Note

    With tensor parallelism, make sure you set the if statement as if smp.rdp_rank() == 0 instead of if smp.dp_rank() == 0. When the optimizer state is sharded with tensor parallelism, all reduced-data parallel ranks must save their own partition of the optimizer state. Using a wrong if statement for checkpointing might result in a stalling training job. For more information about using if smp.dp_rank() == 0 without tensor parallelism, see General Instruction for Saving and Loading in the SageMaker Python SDK documentation.
  - Option 2: Save the full model.

    if smp.rdp_rank() == 0:
        model_dict = model.state_dict(gather_to_rank0=True)  # save the full model
        if smp.rank() == 0:
            smp.save(
                {"model_state_dict": model_dict},
                "/checkpoint.pt",
                partial=False,
            )
    Note

    Consider the following for full checkpointing:

    - If you set gather_to_rank0=True, all ranks other than 0 return empty dictionaries.
    - For full checkpointing, you can only checkpoint the model. Full checkpointing of optimizer states is currently not supported.
    - The full model only needs to be saved at smp.rank() == 0.
- Load the checkpoints using smp.load(). Depending on how you checkpointed in the previous step, choose one of the following two options:

  - Option 1: Load the partial checkpoints.
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    You can set same_partition_load=True in model.load_state_dict() for a faster load, if you know that the partition will not change.

  - Option 2: Load the full checkpoints.
    if smp.rdp_rank() == 0:
        checkpoint = smp.load("/checkpoint.pt", partial=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    The if smp.rdp_rank() == 0 condition is not required, but it can help avoid redundant loading among different MP_GROUPs. Full checkpointing of the optimizer state dict is currently not supported with tensor parallelism.
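The following sketch ties the preceding steps together for the partial-checkpoint path (Option 1). The MyModel and MyOpt classes, the training loop, and the /checkpoint.pt path are illustrative placeholders carried over from the steps above.

import smdistributed.modelparallel.torch as smp

# Step 1: wrap the model.
model = MyModel(...)  # hypothetical model class
model = smp.DistributedModel(model)

# Step 2: build an iterable of model parameters with unique IDs.
unique_params = []
unique_params_set = set()
for p in model.parameters():
    if p not in unique_params_set:
        unique_params.append(p)
        unique_params_set.add(p)
del unique_params_set

# Step 3: wrap the optimizer.
optimizer = MyOpt(unique_params, ...)  # hypothetical optimizer class
optimizer = smp.DistributedOptimizer(optimizer)

...  # training loop

# Step 4, Option 1: save a partial checkpoint at reduced-data parallel rank 0.
if smp.rdp_rank() == 0:
    smp.save(
        {
            "model_state_dict": model.local_state_dict(),
            "optimizer_state_dict": optimizer.local_state_dict(),
        },
        "/checkpoint.pt",
        partial=True,
    )

# Step 5, Option 1: resume from the partial checkpoint.
checkpoint = smp.load("/checkpoint.pt", partial=True)
model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])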
Checkpointing a distributed TensorFlow model
To save a TensorFlow model while training with model parallelism, use the saving functions provided by the SageMaker AI model parallelism library for TensorFlow.
Fine-tuning a distributed model
Fine-tuning needs to be configured in your training script. The following code snippet shows an example structure of a training script that uses the Hugging Face Transformers AutoModelForCausalLM class together with the smdistributed.modelparallel.torch modules and settings for fine-tuning.
Note
Fine-tuning a distributed transformer (a Transformer model wrapped by smp.DistributedModel()) with the smp.delayed_param_initialization option needs the pretrained weights to be saved to disk first and then loaded through smp.resume_from_checkpoint() after the model is wrapped, as shown in the following example.
import argparse
import os

import torch
from transformers import AutoModelForCausalLM

import smdistributed.modelparallel
import smdistributed.modelparallel.torch as smp


def parse_args():
    parser = argparse.ArgumentParser()

    # set an arg group for model
    model_grp = parser.add_argument_group(
        title="model", description="arguments to describe model configuration"
    )

    ...  # set up numerous args to parse from the configuration dictionary to the script for training

    # add arg for activating fine-tuning
    model_grp.add_argument(
        "--fine_tune",
        type=int,
        default=0,
        help="Fine-tune model from checkpoint or pretrained model",
    )


def main():
    """Main function to train GPT."""
    args = parse_args()

    ...  # parse numerous args

    if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0:
        pretrained_model = AutoModelForCausalLM.from_pretrained(
            args.model_name or args.model_dir
        )
        model_state_dict = pretrained_model.state_dict()
        path = os.path.join(args.model_dir, "fullmodel.pt")
        torch.save(model_state_dict, path)

    # create a Transformer model and wrap it with smp.model_creation()
    # with options to configure model parallelism parameters offered by SageMaker AI
    with smp.model_creation(
        tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0,
        zero_init=args.use_distributed_transformer == 0,
        dtype=dtype,
        distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1,
        use_alibi=args.alibi > 0,
        attention_in_fp32=args.attention_in_fp32 > 0,
        fp32_residual_addition=args.residual_addition_in_fp32 > 0,
        query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1,
        fused_softmax=args.fused_softmax > 0,
        fused_dropout=args.fused_dropout > 0,
        fused_bias_gelu=args.fused_bias_gelu > 0,
        flash_attention=args.flash_attention > 0,
    ):
        if args.fine_tune > 0 and args.delayed_param == 0:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name or args.model_dir
            )
        else:
            model = AutoModelForCausalLM.from_config(model_config)

    # wrap the model with smp.DistributedModel() to apply SageMaker AI model parallelism
    model = smp.DistributedModel(
        model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation
    )

    # wrap the optimizer with smp.DistributedOptimizer() to apply SageMaker AI model parallelism
    optimizer = ...  # define an optimizer
    optimizer = smp.DistributedOptimizer(
        optimizer,
        static_loss_scale=None,
        dynamic_loss_scale=True,
        dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
    )

    # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model
    if args.fine_tune > 0 and args.delayed_param > 0:
        smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
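To launch a fine-tuning script like this on SageMaker, you enable the library through the distribution argument of the SageMaker PyTorch estimator. The following is a minimal sketch; the entry point name, instance type, framework and Python versions, parallelism degrees, hyperparameters, and S3 paths are placeholder values to adjust for your own job.

from sagemaker.pytorch import PyTorch

smp_options = {
    "enabled": True,
    "parameters": {  # placeholder model parallelism settings
        "pipeline_parallel_degree": 1,
        "tensor_parallel_degree": 4,
        "ddp": True,
    },
}

estimator = PyTorch(
    entry_point="train_gpt.py",  # hypothetical script implementing main() above
    role=role,  # your SageMaker execution role
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    framework_version="1.13.1",
    py_version="py39",
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        "mpi": {"enabled": True, "processes_per_host": 8},
    },
    hyperparameters={"fine_tune": 1},  # activates the fine-tuning branch in the script
)

estimator.fit("s3://amzn-s3-demo-bucket/training-data")  # hypothetical dataset location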
For a complete example of training scripts and Jupyter notebooks, see the GPT-2 examples for PyTorch.