The SageMaker model parallel library v2 reference

The following are references for the SageMaker model parallel library v2 (SMP v2).

SMP v2 core feature configuration parameters

The following is a complete list of parameters to activate and configure the Core features of the SageMaker model parallelism library v2. These parameters must be written in JSON format and passed to the PyTorch estimator in the SageMaker Python SDK or saved as a JSON file for SageMaker HyperPod. An example configuration sketch follows the parameter list.

{ "hybrid_shard_degree": Integer, "sm_activation_offloading": Boolean, "activation_loading_horizon": Integer, "fsdp_cache_flush_warnings": Boolean, "allow_empty_shards": Boolean, "tensor_parallel_degree": Integer, "expert_parallel_degree": Integer, "random_seed": Integer }
  • hybrid_shard_degree (Integer) – Specifies a sharded parallelism degree. The value must be an integer between 0 and world_size. The default value is 0.

    • If set to 0, it falls back to the native PyTorch implementation and API in the script when tensor_parallel_degree is 1. Otherwise, it computes the largest possible hybrid_shard_degree based on tensor_parallel_degree and world_size. When falling back to the native PyTorch FSDP use cases, if FULL_SHARD is the strategy you use, it shards across the whole cluster of GPUs. If HYBRID_SHARD or _HYBRID_SHARD_ZERO2 is your strategy, it is equivalent to a hybrid_shard_degree of 8. When tensor parallelism is enabled, it shards based on the revised hybrid_shard_degree.

    • If set to 1, it falls back to the native PyTorch implementation and API for NO_SHARD in the script when tensor_parallel_degree is 1. Otherwise, it's equivalent to NO_SHARD within any given tensor parallel groups.

    • If set to an integer between 2 and world_size, sharding happens across the specified number of GPUs. If you don't set up sharding_strategy in the FSDP script, it gets overridden to HYBRID_SHARD. If you set _HYBRID_SHARD_ZERO2, the sharding_strategy you specify is used.

  • sm_activation_offloading (Boolean) – Specifies whether to enable the SMP activation offloading implementation. If False, offloading uses the native PyTorch implementation. If True, it uses the SMP activation offloading implementation. You also need to use the PyTorch activation offload wrapper (torch.distributed.algorithms._checkpoint.checkpoint_wrapper.offload_wrapper) in your script. To learn more, see Activation offloading. The default value is True.

  • activation_loading_horizon (Integer) – An integer specifying the activation offloading horizon type for FSDP. This is the maximum number of checkpointed or offloaded layers whose inputs can be in the GPU memory simultaneously. To learn more, see Activation offloading. The input value must be a positive integer. The default value is 2.

  • fsdp_cache_flush_warnings (Boolean) – Detects and warns if cache flushes happen in the PyTorch memory manager, because they can degrade computational performance. The default value is True.

  • allow_empty_shards (Boolean) – Whether to allow empty shards when sharding tensors whose size is not evenly divisible across ranks. This is an experimental fix for a crash during checkpointing in certain scenarios. Disabling this falls back to the original PyTorch behavior. The default value is False.

  • tensor_parallel_degree (Integer) – Specifies a tensor parallelism degree. The value must be between 1 and world_size. The default value is 1. Passing a value greater than 1 does not enable tensor parallelism automatically. You also need to use the torch.sagemaker.transform API to wrap the model in your training script. To learn more, see Tensor parallelism.

  • expert_parallel_degree (Integer) – Specifies an expert parallelism degree. The value must be between 1 and world_size. The default value is 1. Passing a value greater than 1 does not enable expert parallelism automatically; make sure that you wrap the MoE model with the torch.sagemaker.transform API in your training script.

  • random_seed (Integer) – A seed number for the random operations in the distributed modules of SMP tensor parallelism or expert parallelism. This seed is added to the tensor-parallel or expert-parallel rank to set the actual seed for each rank, so it is unique for each tensor-parallel and expert-parallel rank. SMP v2 makes sure that the random numbers generated across tensor-parallel and expert-parallel ranks match the non-tensor-parallelism and non-expert-parallelism cases respectively.
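
The following is a minimal sketch of passing these parameters through the PyTorch estimator in the SageMaker Python SDK. The degree values, instance settings, and training script name are illustrative, and the exact distribution keys should be confirmed against Get started with the SageMaker model parallelism library v2.

    from sagemaker.pytorch import PyTorch

    # Illustrative SMP v2 parameters; adjust the degrees to your cluster size.
    smp_parameters = {
        "hybrid_shard_degree": 16,
        "sm_activation_offloading": True,
        "activation_loading_horizon": 2,
        "tensor_parallel_degree": 4,
        "random_seed": 12345,
    }

    estimator = PyTorch(
        entry_point="train.py",        # hypothetical training script
        role="<your-iam-role>",
        instance_type="ml.p4d.24xlarge",
        instance_count=2,
        framework_version="2.2.0",     # assumed to be an SMP v2-supported PyTorch version
        py_version="py310",
        distribution={
            "torch_distributed": {"enabled": True},
            "smdistributed": {
                "modelparallel": {
                    "enabled": True,
                    "parameters": smp_parameters,
                }
            },
        },
    )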

Reference for the SMP v2 torch.sagemaker package

This section is a reference for the torch.sagemaker package provided by SMP v2.

torch.sagemaker.delayed_param.DelayedParamIniter

An API for applying Delayed parameter initialization to a PyTorch model.

class torch.sagemaker.delayed_param.DelayedParamIniter(
    model: nn.Module,
    init_method_using_config: Callable = None,
    verbose: bool = False,
)

Parameters

  • model (nn.Module) – A PyTorch model to wrap and apply the delayed parameter initialization functionality of SMP v2.

  • init_method_using_config (Callable) – If you use the tensor parallel implementation of SMP v2 or supported Hugging Face Transformer models compatible with the SMP tensor parallelism, keep this parameter at the default value, which is None. By default, the DelayedParamIniter API finds out how to initialize the given model correctly. For any other models, you need to create a custom parameter initialization function and add it to your script. The following code snippet is the default init_method_using_config function that SMP v2 implements for the Hugging Face Transformer models compatible with the SMP tensor parallelism. Use it as a reference for creating your own initialization configuration function, adding it to your script, and passing it to the init_method_using_config parameter of the SMP DelayedParamIniter API.

    from torch.sagemaker.utils.module_utils import empty_module_params, move_buffers_to_device

    # Define a custom init config function.
    def custom_init_method_using_config(module):
        d = torch.cuda.current_device()

        empty_module_params(module, device=d)
        if isinstance(module, (nn.Linear, Conv1D)):
            module.weight.data.normal_(mean=0.0, std=config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)
        move_buffers_to_device(module, device=d)

    delayed_initer = DelayedParamIniter(model, init_method_using_config=custom_init_method_using_config)

    For more information about the torch.sagemaker.module_util functions in the preceding code snippet, see torch.sagemaker util functions and properties.

  • verbose (Boolean) – Whether to enable more detailed logging during initialization and validation. The default value is False.

Methods

  • get_param_init_fn() – Returns the parameter initialization function that you can pass to the param_init_fn argument of the PyTorch FSDP wrapper class.

  • get_post_param_init_fn() – Returns the parameter initialization function that you can pass to the post_param_init_fn argument of the PyTorch FSDP wrapper class. This is needed when you have tied weights in the model. The model must implement the method tie_weights. For more information, see the Notes on tied weight in Delayed parameter initialization.

  • count_num_params (module: nn.Module, *args: Tuple[nn.Parameter]) – Tracks how many parameters are being initialized by the parameter initialization function. This helps implement the following validate_params_and_buffers_inited method. You usually don’t need to call this function explicitly, because the validate_params_and_buffers_inited method implicitly calls this method in the backend.

  • validate_params_and_buffers_inited (enabled: bool=True) – A context manager that validates that the number of parameters initialized matches the total number of parameters in the model, and that all parameters and buffers are on GPU devices instead of meta devices. It raises AssertionError if these conditions are not met. This context manager is optional; you aren't required to use it to initialize parameters. A usage sketch follows this list.
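
The following is a minimal sketch of wiring these methods into the PyTorch FSDP wrapper. It assumes that the model parameters still live on the meta device (see Delayed parameter initialization) and that the remaining FSDP arguments, such as the auto wrap policy, are configured elsewhere.

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.sagemaker.delayed_param import DelayedParamIniter

    # model: an nn.Module created with delayed parameter initialization
    delayed_initer = DelayedParamIniter(model)

    # Optionally validate that all parameters and buffers get initialized on GPU.
    with delayed_initer.validate_params_and_buffers_inited():
        model = FSDP(
            model,
            param_init_fn=delayed_initer.get_param_init_fn(),
            post_param_init_fn=delayed_initer.get_post_param_init_fn(),  # needed for tied weights
            sync_module_states=True,
        )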

torch.sagemaker.moe.moe_config.MoEConfig

A configuration class for setting up the SMP implementation of Mixture-of-Experts (MoE). You can specify MoE configuration values through this class and pass it to the torch.sagemaker.transform API call (a usage sketch follows the parameter list below). To learn more about the usage of this class for training MoE models, see Expert parallelism.

class torch.sagemaker.moe.moe_config.MoEConfig(
    smp_moe=True,
    random_seed=12345,
    moe_load_balancing="sinkhorn",
    global_token_shuffle=False,
    moe_all_to_all_dispatcher=True,
    moe_aux_loss_coeff=0.001,
    moe_z_loss_coeff=0.001
)
  • smp_moe (Boolean) - Whether to use the SMP-implementation of MoE. The default value is True.

  • random_seed (Integer) - A seed number for the random operations in expert-parallel distributed modules. This seed will be added to the expert parallel rank to set the actual seed for each rank. It is unique for each expert parallel rank. The default value is 12345.

  • moe_load_balancing (String) - Specifies the load balancing type of the MoE router. Valid options are aux_loss, sinkhorn, balanced, and none. The default value is sinkhorn.

  • global_token_shuffle (Boolean) - Whether to shuffle tokens across EP ranks within the same EP group. The default value is False.

  • moe_all_to_all_dispatcher (Boolean) - Whether to use all-to-all dispatcher for the communications in MoE. The default value is True.

  • moe_aux_loss_coeff (Float) - A coefficient for auxiliary load balancing loss. The default value is 0.001.

  • moe_z_loss_coeff (Float) - Coefficient for z-loss. The default value is 0.001.
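
The following is a minimal sketch of passing an MoEConfig object to the torch.sagemaker.transform call when preparing an MoE model for expert parallelism. The model variable and the configuration values are illustrative; see Expert parallelism for the supported end-to-end setup.

    import torch.sagemaker as tsm
    from torch.sagemaker.moe.moe_config import MoEConfig

    moe_config = MoEConfig(
        smp_moe=True,
        random_seed=12345,
        moe_load_balancing="sinkhorn",
        global_token_shuffle=False,
        moe_all_to_all_dispatcher=True,
    )

    # model: a supported Hugging Face Transformers MoE model
    model = tsm.transform(model, config=moe_config)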

torch.sagemaker.nn.attn.FlashSelfAttention

An API for using FlashAttention with SMP v2.

class torch.sagemaker.nn.attn.FlashSelfAttention(
    attention_dropout_prob: float = 0.0,
    scale: Optional[float] = None,
    triton_flash_attention: bool = False,
    use_alibi: bool = False,
)

Parameters

  • attention_dropout_prob (float) – The dropout probability to apply to attention. The default value is 0.0.

  • scale (float) – If passed, this scale factor will be applied for softmax. If set to None (which is also the default value), the scale factor is 1 / sqrt(attention_head_size). The default value is None.

  • triton_flash_attention (bool) – If set to True, the Triton implementation of FlashAttention is used. This is necessary to support Attention with Linear Biases (ALiBi) (see the following use_alibi parameter). This version of the kernel doesn’t support dropout. The default value is False.

  • use_alibi (bool) – If set to True, it enables Attention with Linear Biases (ALiBi) using the provided mask. When using ALiBi, an attention mask prepared as follows is required. The default value is False.

    def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
                                 num_attention_heads, alibi_bias_max=8):

        device, dtype = attention_mask.device, attention_mask.dtype
        alibi_attention_mask = torch.zeros(
            1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
        )

        alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
            1, 1, 1, seq_length
        )
        m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
        m.mul_(alibi_bias_max / num_attention_heads)
        alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))

        alibi_attention_mask.add_(alibi_bias)
        alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
        if attention_mask is not None and attention_mask.bool().any():
            alibi_attention_mask.masked_fill(
                attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
            )

        return alibi_attention_mask

Methods

  • forward(self, qkv, attn_mask=None, causal=False, cast_dtype=None, layout="b h s d") – A regular PyTorch module function. When module(x) is called, SMP runs this function automatically.

    • qkv – torch.Tensor of the following form: (batch_size x seqlen x (3 x num_heads) x head_size) or (batch_size x (3 x num_heads) x seqlen x head_size); or a tuple of torch.Tensors, each of which might be of shape (batch_size x seqlen x num_heads x head_size) or (batch_size x num_heads x seqlen x head_size). An appropriate layout argument must be passed based on the shape.

    • attn_mask – torch.Tensor of the following form: (batch_size x 1 x 1 x seqlen). Enabling this attention mask parameter requires triton_flash_attention=True and use_alibi=True. To learn how to generate an attention mask using this method, see the code examples at FlashAttention. The default value is None.

    • causal – When set to False, no mask is applied. When set to True, the forward method uses the standard lower triangular mask. The default value is False.

    • cast_dtype – When set to a particular dtype, it casts the qkv tensors to that dtype before attn. This is useful for implementations such as the Hugging Face Transformer GPT-NeoX model, which has q and k with fp32 after rotary embeddings. If set to None, no cast is applied. The default value is None.

    • layout (string) – Available values are b h s d or b s h d. This should be set to the layout of qkv tensors passed, so appropriate transformations can be applied for attn. The default value is b h s d.

Returns

A single torch.Tensor with shape (batch_size x num_heads x seq_len x head_size).
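
The following is a minimal sketch of calling FlashSelfAttention directly with a packed qkv tensor in the "b s h d" layout. The sizes and dtype are illustrative, and a GPU that supports the FlashAttention kernels is assumed.

    import torch
    from torch.sagemaker.nn.attn import FlashSelfAttention

    batch_size, seqlen, num_heads, head_size = 2, 1024, 16, 64

    attn = FlashSelfAttention(attention_dropout_prob=0.0)

    # Packed qkv tensor: (batch_size x seqlen x (3 x num_heads) x head_size)
    qkv = torch.randn(
        batch_size, seqlen, 3 * num_heads, head_size,
        device="cuda", dtype=torch.bfloat16,
    )

    # Output shape: (batch_size x num_heads x seqlen x head_size)
    out = attn(qkv, causal=True, layout="b s h d")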

torch.sagemaker.nn.attn.FlashGroupedQueryAttention

An API for using FlashGroupedQueryAttention with SMP v2. To learn more about the usage of this API, see Use FlashAttention kernels for grouped-query attention.

class torch.sagemaker.nn.attn.FlashGroupedQueryAttention(
    attention_dropout_prob: float = 0.0,
    scale: Optional[float] = None,
)

Parameters

  • attention_dropout_prob (float) – The dropout probability to apply to attention. The default value is 0.0.

  • scale (float) – If passed, this scale factor is applied for softmax. If set to None, 1 / sqrt(attention_head_size) is used as the scale factor. The default value is None.

Methods

  • forward(self, q, kv, causal=False, cast_dtype=None, layout="b s h d") – A regular PyTorch module function. When module(x) is called, SMP runs this function automatically.

    • q – torch.Tensor of the following form: (batch_size x seqlen x num_heads x head_size) or (batch_size x num_heads x seqlen x head_size). An appropriate layout argument must be passed based on the shape.

    • kv – torch.Tensor of the following form: (batch_size x seqlen x (2 x num_heads) x head_size) or (batch_size x (2 x num_heads) x seqlen x head_size); or a tuple of two torch.Tensors, each of which might be of shape (batch_size x seqlen x num_heads x head_size) or (batch_size x num_heads x seqlen x head_size). An appropriate layout argument must also be passed based on the shape.

    • causal – When set to False, no mask is applied. When set to True, the forward method uses the standard lower triangular mask. The default value is False.

    • cast_dtype – When set to a particular dtype, it casts the qkv tensors to that dtype before attn. This is useful for implementations such as Hugging Face Transformers GPT-NeoX, which has q and k in fp32 after rotary embeddings. If set to None, no cast is applied. The default value is None.

    • layout (string) – Available values are "b h s d" or "b s h d". This should be set to the layout of qkv tensors passed, so appropriate transformations can be applied for attn. The default value is "b h s d".

Returns

Returns a single torch.Tensor (batch_size x num_heads x seq_len x head_size) that represents the output of attention computation.
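
The following is a minimal sketch of calling FlashGroupedQueryAttention with separate q and kv tensors in the "b s h d" layout, where kv carries fewer heads than q. The sizes and dtype are illustrative, and a GPU that supports the FlashAttention kernels is assumed.

    import torch
    from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

    batch_size, seqlen, head_size = 2, 1024, 64
    num_q_heads, num_kv_heads = 16, 4  # grouped-query attention: fewer kv heads

    attn = FlashGroupedQueryAttention()

    q = torch.randn(
        batch_size, seqlen, num_q_heads, head_size,
        device="cuda", dtype=torch.bfloat16,
    )
    # Packed kv tensor: (batch_size x seqlen x (2 x num_kv_heads) x head_size)
    kv = torch.randn(
        batch_size, seqlen, 2 * num_kv_heads, head_size,
        device="cuda", dtype=torch.bfloat16,
    )

    # Output shape: (batch_size x num_heads x seqlen x head_size)
    out = attn(q, kv, causal=True, layout="b s h d")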

torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention

An API that supports FlashAttention for the Llama model. This API uses the torch.sagemaker.nn.attn.FlashGroupedQueryAttention API at a lower level. To learn how to use this, see Use FlashAttention kernels for grouped-query attention.

class torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention(config: LlamaConfig)

Parameters

  • config – A FlashAttention configuration for the Llama model.

Methods

  • forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)

    • hidden_states (torch.Tensor) – The hidden states tensor, in the form of (batch_size x seq_len x num_heads x head_size).

    • attention_mask (torch.LongTensor) – Mask to avoid performing attention on padding token indices, in the form of (batch_size x seqlen). The default value is None.

    • position_ids (torch.LongTensor) – When not None, a tensor in the form of (batch_size x seqlen), indicating the indices of the position of each input sequence token in the position embeddings. The default value is None.

    • past_key_value (Cache) – Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks). The default value is None.

    • output_attentions (bool) – Indicates whether to return the attentions tensors of all attention layers. The default value is False.

    • use_cache (bool) – Indicates whether to return past_key_values key value states. The default value is False.

Returns

Returns a single torch.Tensor (batch_size x num_heads x seq_len x head_size) that represents the output of attention computation.
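
As a rough sketch only, and as an assumption rather than the documented integration path (see Use FlashAttention kernels for grouped-query attention for the supported approach), one way to exercise this module is to build it from a LlamaConfig and substitute it for the stock attention module in each decoder layer of a Hugging Face Llama model.

    from transformers import LlamaConfig, LlamaForCausalLM
    from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention

    config = LlamaConfig()            # illustrative; use your model's configuration
    model = LlamaForCausalLM(config)

    # Assumption: the standard Hugging Face Llama module layout,
    # model.model.layers[i].self_attn.
    for layer in model.model.layers:
        layer.self_attn = LlamaFlashAttention(config)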

torch.sagemaker.transform

SMP v2 provides this torch.sagemaker.transform() API for transforming Hugging Face Transformer models to SMP model implementations and enabling the SMP tensor parallelism.

torch.sagemaker.transform(
    model: nn.Module,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    config: Optional[Dict] = None,
    load_state_dict_from_rank0: bool = False
)

SMP v2 maintains transformation policies for the Hugging Face Transformer models compatible with the SMP tensor parallelism by converting the configuration of the Hugging Face Transformer models to the SMP transformer configuration.

Parameters

  • model (torch.nn.Module) – A model from Hugging Face Transformer models compatible with the SMP tensor parallelism to transform and apply the tensor parallelism feature of the SMP library.

  • device (torch.device) – If passed, a new model is created on this device. If the original module has any parameter on meta device (see Delayed parameter initialization), then the transformed module will also be created on meta device, ignoring the argument passed here. The default value is None.

  • dtype (torch.dtype) – If passed, sets this as the dtype context manager for the creation of the model and creates a model with this dtype. This is typically unnecessary, as we want to create the model with fp32 when using MixedPrecision, and fp32 is the default dtype in PyTorch. The default value is None.

  • config (dict) – This is a dictionary for configuring the SMP transformer. The default value is None.

  • load_state_dict_from_rank0 (Boolean) – By default, this module creates a new instance of the model with new weights. When this argument is set to True, SMP tries to load the state dictionary of the original PyTorch model from the 0th rank into transformed model for the tensor parallel group that the 0th rank is part of. When this is set to True, rank 0 can’t have any parameters on meta device. Only the first tensor parallel group populates the weights from the 0th rank after this transform call. You need to set sync_module_states to True in the FSDP wrapper to get these weights from the first tensor parallel group to all other processes. With this activated, the SMP library loads the state dictionary from the original model. The SMP library takes the state_dict of the model before transform, converts it to match the structure of the transformed model, shards it for each tensor parallel rank, communicates this state from the 0th rank to other ranks in the tensor parallel group that the 0th rank is part of, and loads it. The default value is False.

Returns

Returns a transformed model that you can wrap with PyTorch FSDP. When load_state_dict_from_rank0 is set to True, the tensor parallel group that involves rank 0 has weights loaded from the original state dictionary on rank 0. When using Delayed parameter initialization on the original model, only these ranks have the actual tensors on CPUs for the parameters and buffers of the transformed model. The rest of the ranks continue to have the parameters and buffers on the meta device to save memory.
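
The following is a minimal sketch of the typical transform-then-wrap flow. It assumes that torch.sagemaker.init() has already run with tensor_parallel_degree greater than 1 in the SMP configuration, that the model name is illustrative, and that the remaining FSDP arguments are configured elsewhere.

    import torch
    import torch.sagemaker as tsm
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative

    # Convert the Hugging Face model to the SMP transformer implementation.
    model = tsm.transform(
        model,
        device=torch.device("cuda", torch.cuda.current_device()),
    )

    model = FSDP(model, sync_module_states=True)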

torch.sagemaker util functions and properties

torch.sagemaker util functions
  • torch.sagemaker.init(config: Optional[Union[str, Dict[str, Any]]] = None) -> None – Initializes the PyTorch training job with SMP.

  • torch.sagemaker.is_initialized() -> bool – Checks whether the training job is initialized with SMP. When falling back to the native PyTorch while the job is initialized with SMP, some of the properties are not relevant and become None, as indicated in the following Properties list.

  • torch.sagemaker.utils.module_utils.empty_module_params(module: nn.Module, device: Optional[torch.device] = None, recurse: bool = False) -> nn.Module – Creates empty parameters on the given device, if specified, and can recurse into all nested modules when recurse is set to True.

  • torch.sagemaker.utils.module_utils.move_buffers_to_device(module: nn.Module, device: torch.device, recurse: bool = False) -> nn.Module – Moves module buffers to the given device, and can recurse into all nested modules when recurse is set to True. A short sketch using these utilities follows this list.
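
The following is a minimal sketch of the two module utilities applied to a toy module; the module and the device choice are illustrative.

    import torch
    import torch.nn as nn
    from torch.sagemaker.utils.module_utils import empty_module_params, move_buffers_to_device

    module = nn.Sequential(nn.Linear(1024, 1024), nn.LayerNorm(1024))
    device = torch.device("cuda", torch.cuda.current_device())

    # Re-create the parameters as empty tensors on the GPU, recursing into
    # child modules, then move any registered buffers to the same device.
    empty_module_params(module, device=device, recurse=True)
    move_buffers_to_device(module, device=device, recurse=True)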

Properties

torch.sagemaker.state holds multiple useful properties after the initialization of SMP with torch.sagemaker.init. A short sketch that reads these properties follows the list.

  • torch.sagemaker.state.hybrid_shard_degree (int) – The sharded data parallelism degree, a copy from user input in the SMP configuration passed to torch.sagemaker.init(). To learn more, see Get started with the SageMaker model parallelism library v2.

  • torch.sagemaker.state.rank (int) – The global rank for the device, in the range of [0, world_size).

  • torch.sagemaker.state.rep_rank_process_group (torch.distributed.ProcessGroup) – The process group including all devices with the same replication rank. Note the subtle but fundamental difference with torch.sagemaker.state.tp_process_group. When falling back to native PyTorch, it returns None.

  • torch.sagemaker.state.tensor_parallel_degree (int) – The tensor parallelism degree, a copy from user input in the SMP configuration passed to torch.sagemaker.init(). To learn more, see Get started with the SageMaker model parallelism library v2.

  • torch.sagemaker.state.tp_size (int) – An alias to torch.sagemaker.state.tensor_parallel_degree.

  • torch.sagemaker.state.tp_rank (int) – The tensor parallelism rank for the device in the range of [0, tp_size), determined by the tensor parallelism degree and the ranking mechanism.

  • torch.sagemaker.state.tp_process_group (torch.distributed.ProcessGroup) – The tensor parallel process group including all devices with the same rank in other dimensions (for example, sharded data parallelism and replication) but unique tensor parallel ranks. When falling back to native PyTorch, it returns None.

  • torch.sagemaker.state.world_size (int) – The total number of devices used in training.
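
The following is a minimal sketch of initializing SMP and reading a few of these properties, for example to restrict logging to global rank 0. It assumes the SMP configuration is supplied at job launch.

    import torch.sagemaker as tsm

    tsm.init()  # reads the SMP configuration passed to the training job

    if tsm.is_initialized():
        state = tsm.state
        if state.rank == 0:
            print(
                f"world_size={state.world_size} "
                f"hybrid_shard_degree={state.hybrid_shard_degree} "
                f"tp_size={state.tp_size} tp_rank={state.tp_rank}"
            )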

Upgrade from SMP v1 to SMP v2

To move from SMP v1 to SMP v2, you must make script changes to remove the SMP v1 APIs and apply the SMP v2 APIs. Instead of starting from your SMP v1 script, we recommend you start from a PyTorch FSDP script, and follow the instructions at Get started with the SageMaker model parallelism library v2.

To bring SMP v1 models to SMP v2, in SMP v1 you must collect the full model state dictionary and apply the translation functions on the model state dictionary to convert it into the Hugging Face Transformers model checkpoint format. Then in SMP v2, as discussed in Save and load checkpoints while using SMP, you can load the Hugging Face Transformers model checkpoints, and then continue with using the PyTorch checkpoint APIs with SMP v2. To use SMP with your PyTorch FSDP model, make sure that you move to SMP v2 and make changes to your training script to use PyTorch FSDP and other latest features.

import smdistributed.modelparallel.torch as smp

# Create model
model = ...
model = smp.DistributedModel(model)

# Run training
...

# Save v1 full checkpoint
if smp.rdp_rank() == 0:
    model_dict = model.state_dict(gather_to_rank0=True)  # save the full model

    # Get the corresponding translation function in smp v1 and translate
    if model_type == "gpt_neox":
        from smdistributed.modelparallel.torch.nn.huggingface.gptneox import translate_state_dict_to_hf_gptneox
        translated_state_dict = translate_state_dict_to_hf_gptneox(model_dict, max_seq_len=None)

    # Save the checkpoint
    checkpoint_path = "checkpoint.pt"
    if smp.rank() == 0:
        smp.save(
            {"model_state_dict": translated_state_dict},
            checkpoint_path,
            partial=False,
        )

To find available translation functions in SMP v1, see Support for Hugging Face Transformer Models.

For instructions on saving and loading model checkpoints in SMP v2, see Save and load checkpoints while using SMP. A rough sketch of the SMP v2 loading flow follows.
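
The following is a rough sketch of the SMP v2 side under the assumptions above: load the translated Hugging Face-format state dictionary into a Transformers model, then transform and wrap it with PyTorch FSDP. The checkpoint file name, model configuration source, and FSDP arguments are illustrative; follow Save and load checkpoints while using SMP for the supported procedure.

    import torch
    import torch.sagemaker as tsm
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoConfig, AutoModelForCausalLM

    tsm.init()

    # Build the Hugging Face model and load the checkpoint translated in SMP v1.
    hf_config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")  # illustrative
    model = AutoModelForCausalLM.from_config(hf_config)
    state_dict = torch.load("checkpoint.pt")["model_state_dict"]
    model.load_state_dict(state_dict)

    # Continue with the SMP v2 flow: transform, then wrap with PyTorch FSDP.
    model = tsm.transform(model)
    model = FSDP(model, sync_module_states=True)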