Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.
FlashAttention
SMP v2 unterstützt FlashAttention
Das Modul (nn.Module
) ist eine Low-Level-API, die die Aufmerksamkeitsebenen eines Modells definiert. Es sollte direkt nach der Modellerstellung angewendet werden, beispielsweise über die AutoModelForCausalLM.from_config()
API, und bevor das Modell transformiert oder mit FSDP umschlossen wird.
Benutze FlashAttention Kernel zur Selbstaufmerksamkeit
Der folgende Codeausschnitt zeigt, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashSelfAttention API verwendet wird.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Verwenden Sie FlashAttention Kernel für die Bearbeitung von Gruppenabfragen
SMP v2 unterstützt auch FlashAttention
Beispiel für die Verwendung FlashGroupedQueryAttention
Der folgende Codeausschnitt zeigt, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashGroupedQueryAttention API verwendet wird.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
Die SMP-Bibliothek bietet auchtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, die die torch.sagemaker.nn.attn.FlashGroupedQueryAttention API auf niedriger Ebene verwendet. Hugging Face Transformers hat eine ähnliche Implementierung, die LlamaFlashAttention2
LlamaFlashAttention
v2-API oder die Transformers-API verwendet werden, um die Aufmerksamkeitsebenen eines LlamaFlashAttention2
vorhandenen Lama-Modells zu ersetzen.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))