Support for FlashAttention

Support for FlashAttention is a feature of the library that applies only to the distributed transformer model, which is a Transformer model wrapped by smp.DistributedModel() for model-parallel training, as shown in the sketch below. This feature is also compatible with tensor parallelism.
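The following is a minimal sketch, assuming a PyTorch training script, of how a model is wrapped with smp.DistributedModel() so that the library's distributed transformer implementation, and with it FlashAttention support, applies. The model class MyTransformer is a hypothetical placeholder for your own Transformer definition; only smp.init() and smp.DistributedModel() are taken from the library.

import smdistributed.modelparallel.torch as smp

# Initialize the library using the configuration passed to the SageMaker job.
smp.init()

# Hypothetical Transformer model definition (placeholder for your own model).
model = MyTransformer()

# Wrap the model for model-parallel training. FlashAttention support applies
# to this distributed transformer model.
model = smp.DistributedModel(model)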

The FlashAttention library supports models only when attention_head_size is set to a value that's a multiple of 8 and less than 128. Therefore, when you train a distributed transformer, make sure that FlashAttention works properly by adjusting the model parameters so that the attention head size complies with these requirements. For more information, see also Installation and features in the FlashAttention GitHub repository.

For example, assume that you configure a Transformer model with hidden_width=864 and num_heads=48. The attention head size is calculated as attention_head_size = hidden_width / num_heads = 864 / 48 = 18, which is not a multiple of 8. To enable FlashAttention, adjust the num_heads parameter to 54, so that attention_head_size = hidden_width / num_heads = 864 / 54 = 16, which is a multiple of 8.
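To illustrate this check, the following Python sketch uses a hypothetical helper function (not part of the library) that computes the attention head size from hidden_width and num_heads and verifies the FlashAttention requirement that it be a multiple of 8 and less than 128.

def is_flash_attention_compatible(hidden_width: int, num_heads: int) -> bool:
    """Return True if hidden_width / num_heads yields a head size that is
    a multiple of 8 and less than 128 (the FlashAttention requirement)."""
    if hidden_width % num_heads != 0:
        return False
    head_size = hidden_width // num_heads
    return head_size % 8 == 0 and head_size < 128

# Original configuration: head size 864 / 48 = 18, not a multiple of 8.
print(is_flash_attention_compatible(864, 48))   # False

# Adjusted configuration: head size 864 / 54 = 16, compatible.
print(is_flash_attention_compatible(864, 54))   # True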