Sélectionner vos préférences de cookies

Nous utilisons des cookies essentiels et des outils similaires qui sont nécessaires au fonctionnement de notre site et à la fourniture de nos services. Nous utilisons des cookies de performance pour collecter des statistiques anonymes afin de comprendre comment les clients utilisent notre site et d’apporter des améliorations. Les cookies essentiels ne peuvent pas être désactivés, mais vous pouvez cliquer sur « Personnaliser » ou « Refuser » pour refuser les cookies de performance.

Si vous êtes d’accord, AWS et les tiers approuvés utiliseront également des cookies pour fournir des fonctionnalités utiles au site, mémoriser vos préférences et afficher du contenu pertinent, y compris des publicités pertinentes. Pour accepter ou refuser tous les cookies non essentiels, cliquez sur « Accepter » ou « Refuser ». Pour effectuer des choix plus détaillés, cliquez sur « Personnaliser ».

FlashAttention

Mode de mise au point
FlashAttention - Amazon SageMaker AI

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

SMP v2 prend en charge FlashAttentionles noyaux et permet de les appliquer facilement à différents scénarios pour les modèles Hugging Face Transformer. Notez que si vous utilisez le FlashAttention package v2.0 ou une version ultérieure, SMP utilise la version FlashAttention v2 ; toutefois, le Triton Flash Attention utilise par défaut le noyau Flash Attention dans la FlashAttention version v1.x, ce qui le rend exclusivement pris en charge dans la version v1. FlashAttention

Le module (nn.Module) est une API de bas niveau qui définit les couches d'attention d'un modèle. Il doit être appliqué juste après la création du modèle, à partir de l'AutoModelForCausalLM.from_config()API par exemple, et avant que le modèle ne soit transformé ou encapsulé avec FSDP.

Utilisez des FlashAttention noyaux pour vous concentrer

L'extrait de code suivant montre comment utiliser l'torch.sagemaker.nn.attn.FlashSelfAttentionAPI fournie par SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Utiliser des FlashAttention noyaux pour attirer l'attention sur les requêtes groupées

SMP v2 prend également en charge les FlashAttentionnoyaux pour l'attention par requêtes groupées (GQA) et permet de les appliquer facilement à différents scénarios pour les modèles Hugging Face Transformer. Contrairement à l'architecture d'attention originale, GQA divise également les têtes de requête en groupes, et les têtes de requête d'un même groupe partagent les mêmes têtes de clé et de valeur. Par conséquent, les têtes q et kv sont transmises séparément à l'appel direct. Remarque : Le nombre de têtes q doit être divisible par le nombre de têtes kv.

Exemple d'utilisation FlashGroupedQueryAttention

L'extrait de code suivant montre comment utiliser l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI fournie par SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

La bibliothèque SMP fournit égalementtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, qui utilise l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI à bas niveau. Hugging Face Transformers a une implémentation similaire LlamaFlashAttention2appelée v4.36.0. L'extrait de code suivant montre comment utiliser l'API SMP v2 ou l'LlamaFlashAttentionAPI Transformers LlamaFlashAttention2 pour remplacer les couches d'attention d'un modèle de lama existant.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))

Rubrique précédente :

Affinement
ConfidentialitéConditions d'utilisation du sitePréférences de cookies
© 2025, Amazon Web Services, Inc. ou ses affiliés. Tous droits réservés.