Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.
FlashAttention
SMPLa version v2 prend en charge FlashAttention
Le module (nn.Module
) est un niveau bas API qui définit les couches d'attention d'un modèle. Il doit être appliqué juste après la création du modèle, à partir du par AutoModelForCausalLM.from_config()
API exemple, et avant que le modèle ne soit transformé ou encapsuléFSDP.
Utilisez des FlashAttention noyaux pour vous concentrer
L'extrait de code suivant montre comment utiliser le code torch.sagemaker.nn.attn.FlashSelfAttention API fourni par SMP la v2.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Utiliser des FlashAttention noyaux pour attirer l'attention sur les requêtes groupées
SMPLa version v2 prend également en charge FlashAttention
Exemple d'utilisation FlashGroupedQueryAttention
L'extrait de code suivant montre comment utiliser le code torch.sagemaker.nn.attn.FlashGroupedQueryAttention API fourni par SMP la v2.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
La SMP bibliothèque fournit égalementtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, qui utilise le torch.sagemaker.nn.attn.FlashGroupedQueryAttention API à bas niveau. Hugging Face Transformers a une implémentation similaire LlamaFlashAttention2
LlamaFlashAttention
API ou les Transformers LlamaFlashAttention2
API pour remplacer les couches d'attention d'un modèle de lama existant.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))