Instructions for Checkpointing with Tensor Parallelism
The SageMaker model parallelism library supports saving partial or full checkpoints with tensor parallelism. The following guide shows how to modify your script to save and load a checkpoint when you use tensor parallelism.
- Prepare a model object and wrap it with the library's wrapper function smp.DistributedModel().

    model = MyModel(...)
    model = smp.DistributedModel(model)
- Prepare an optimizer for the model. Optimizer functions require an iterable of model parameters as an argument. To prepare that iterable, you must process model.parameters() so that each model parameter has a unique ID. If the model parameter iterable contains parameters with duplicated IDs, loading the checkpointed optimizer state fails. To create an iterable of model parameters with unique IDs for your optimizer, see the following:

    unique_params = []
    unique_params_set = set()
    for p in model.parameters():
        if p not in unique_params_set:
            unique_params.append(p)
            unique_params_set.add(p)
    del unique_params_set

    optimizer = MyOpt(unique_params, ...)
- Wrap the optimizer using the library's wrapper function smp.DistributedOptimizer().

    optimizer = smp.DistributedOptimizer(optimizer)
- Save the model and the optimizer state using smp.save(). Depending on how you want to save checkpoints, choose one of the following two options. (End-to-end sketches that combine the save and load steps into a training loop appear at the end of this page.)

  - Option 1: Save a partial model on each mp_rank for a single MP_GROUP.

      model_dict = model.local_state_dict()      # save a partial model
      opt_dict = optimizer.local_state_dict()    # save a partial optimizer state

      # Save the dictionaries at rdp_rank 0 as a checkpoint
      if smp.rdp_rank() == 0:
          smp.save(
              {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
              "/checkpoint.pt",
              partial=True,
          )
    With tensor parallelism, the library saves checkpoint files named in the following format: checkpoint.pt_{pp_rank}_{tp_rank}. For example, with a pipeline parallel degree of 2 and a tensor parallel degree of 2, the partial checkpoints are saved as checkpoint.pt_0_0, checkpoint.pt_0_1, checkpoint.pt_1_0, and checkpoint.pt_1_1.

    Note: With tensor parallelism, make sure you set the if statement as if smp.rdp_rank() == 0 instead of if smp.dp_rank() == 0. When the optimizer state is sharded with tensor parallelism, all reduced-data parallel ranks must save their own partition of the optimizer state. Using the wrong if statement for checkpointing might result in a stalling training job. For more information about using if smp.dp_rank() == 0 without tensor parallelism, see General Instruction for Saving and Loading in the SageMaker Python SDK documentation.
  - Option 2: Save the full model.

      if smp.rdp_rank() == 0:
          model_dict = model.state_dict(gather_to_rank0=True)  # save the full model
          if smp.rank() == 0:
              smp.save(
                  {"model_state_dict": model_dict},
                  "/checkpoint.pt",
                  partial=False,
              )
    Note: Consider the following for full checkpointing:

    - If you set gather_to_rank0=True, all ranks other than 0 return empty dictionaries.
    - For full checkpointing, you can only checkpoint the model. Full checkpointing of optimizer states is currently not supported.
    - The full model only needs to be saved at smp.rank() == 0.
- Load the checkpoints using smp.load(). Depending on how you checkpointed in the previous step, choose one of the following two options:

  - Option 1: Load the partial checkpoints.

      checkpoint = smp.load("/checkpoint.pt", partial=True)
      model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)
      optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    You can set same_partition_load=True in model.load_state_dict() for a faster load if you know that the partition will not change.

  - Option 2: Load the full checkpoints.

      if smp.rdp_rank() == 0:
          checkpoint = smp.load("/checkpoint.pt", partial=False)
          model.load_state_dict(checkpoint["model_state_dict"])
    The if smp.rdp_rank() == 0 condition is not required, but it can help avoid redundant loading among different MP_GROUPs. Full checkpointing of the optimizer state dict is currently not supported with tensor parallelism.
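The following is a minimal end-to-end sketch that combines steps 1 through 4, writing partial checkpoints at a fixed step interval during training. It is a sketch under assumptions, not the library's reference example: MyModel, MyOpt, loss_fn, get_dataloader, save_interval, and the output path under /opt/ml/checkpoints are hypothetical placeholders, and the smp.step-decorated training function assumes a typical smdistributed.modelparallel.torch training script.

    import smdistributed.modelparallel.torch as smp

    smp.init()

    # Steps 1-3: wrap the model and the optimizer.
    model = smp.DistributedModel(MyModel(...))                              # hypothetical model class
    unique_params = list({id(p): p for p in model.parameters()}.values())   # equivalent to the dedup loop in step 2
    optimizer = smp.DistributedOptimizer(MyOpt(unique_params, ...))         # hypothetical optimizer class

    @smp.step
    def train_step(inputs, targets):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)    # hypothetical loss function
        model.backward(loss)                # the library's replacement for loss.backward()
        return loss

    for step, (inputs, targets) in enumerate(get_dataloader()):   # hypothetical data loader
        optimizer.zero_grad()
        train_step(inputs, targets)
        optimizer.step()

        # Step 4, Option 1: every reduced-data parallel rank 0 saves its own partition.
        if step % save_interval == 0 and smp.rdp_rank() == 0:
            smp.save(
                {
                    "model_state_dict": model.local_state_dict(),
                    "optimizer_state_dict": optimizer.local_state_dict(),
                    "step": step,
                },
                "/opt/ml/checkpoints/checkpoint.pt",   # hypothetical output path
                partial=True,
            )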
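To resume training from the partial checkpoints written by the loop above, you can follow step 5, Option 1 after wrapping the model and optimizer in the same way. The path and the "step" entry are the hypothetical values from the previous sketch, and same_partition_load=True assumes the partition does not change between the two runs.

    # Restore the partition owned by this rank and continue from the saved step.
    checkpoint = smp.load("/opt/ml/checkpoints/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=True)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_step = checkpoint["step"] + 1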
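To see how the if smp.rdp_rank() == 0 condition in the save step maps ranks to partial checkpoint files, you can print the file name each writer would produce. This is an illustrative sketch only; it assumes smp.init() has already been called and that the rank query functions smp.pp_rank() and smp.tp_rank() are available in your version of the library.

    import smdistributed.modelparallel.torch as smp

    # Every process with reduced-data parallel rank 0 writes one shard, named
    # checkpoint.pt_{pp_rank}_{tp_rank} as described in the save step.
    if smp.rdp_rank() == 0:
        print(f"global rank {smp.rank()} writes checkpoint.pt_{smp.pp_rank()}_{smp.tp_rank()}")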