Support pour FlashAttention - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Support pour FlashAttention

Support de FlashAttention est une fonctionnalité de la bibliothèque applicable uniquement au modèle de transformateur distribué, qui est un modèle de transformateur intégré smp.DistributedModel()pour l'apprentissage parallèle entre modèles. Cette fonctionnalité est également compatible avec Parallélisme de tenseur.

La FlashAttentionbibliothèque ne prend en charge les modèles que lorsqu'elle attention_head_size est définie sur une valeur multiple de 8 et inférieure à 128. Par conséquent, lorsque vous entraînez un transformateur distribué et que vous vous assurez qu'il FlashAttention fonctionne correctement, vous devez ajuster les paramètres pour que la taille de la tête d'attention soit conforme aux exigences. Pour plus d'informations, voir également Installation et fonctionnalités du FlashAttention GitHubréférentiel.

Supposons, par exemple, que vous configurez un modèle Transformer avec hidden_width=864 et num_heads=48. La taille de la tête de FlashAttention est calculée comme suitattention_head_size = hidden_width / num_heads = 864 / 48 = 18. Pour l'activer FlashAttention, vous devez ajuster le num_heads paramètre à54, de sorte queattention_head_size = hidden_width / num_heads = 864 / 54 = 16, soit un multiple de 8.