Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.
SMP v2 prend en charge FlashAttention
Le module (nn.Module
) est une API de bas niveau qui définit les couches d'attention d'un modèle. Il doit être appliqué juste après la création du modèle, à partir de l'AutoModelForCausalLM.from_config()
API par exemple, et avant que le modèle ne soit transformé ou encapsulé avec FSDP.
Utilisez des FlashAttention noyaux pour vous concentrer
L'extrait de code suivant montre comment utiliser l'torch.sagemaker.nn.attn.FlashSelfAttentionAPI fournie par SMP v2.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
return (
self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
None,
)
for layer in model.gpt_neox.layers:
layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
layer.attention._attn = functools.partial(new_attn, layer.attention)
Utiliser des FlashAttention noyaux pour attirer l'attention sur les requêtes groupées
SMP v2 prend également en charge les FlashAttention
Exemple d'utilisation FlashGroupedQueryAttention
L'extrait de code suivant montre comment utiliser l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI fournie par SMP v2.
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention
class LlamaFlashAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.flash_attn = FlashGroupedQueryAttention(
attention_dropout_prob=0.0,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
...
):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
...
kv = (key_states, value_states)
attn_output = self.flash_attn(
query_states,
kv,
attn_mask=attention_mask,
causal=True,
layout="b h s d",
)
...
attn_output = self.o_proj(attn_output)
...
return attn_output
La bibliothèque SMP fournit égalementtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, qui utilise l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI à bas niveau. Hugging Face Transformers a une implémentation similaire LlamaFlashAttention2
LlamaFlashAttention
API Transformers LlamaFlashAttention2
pour remplacer les couches d'attention d'un modèle de lama existant.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2
attn_name = "self_attn"
for layer in model.model.layers:
prev_layer = getattr(layer, attn_name)
setattr(layer, attn_name, flash_attn_class(model.config))