FlashAttention - Amazon SageMaker

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

FlashAttention

SMP v2 unterstützt FlashAttentionKernel und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Beachten Sie, dass SMP FlashAttention v2 verwendet, wenn Sie FlashAttention Paket v2.0 oder höher verwenden. Triton Flash Attention verwendet jedoch standardmäßig den Flash Attention-Kernel in FlashAttention v1.x, sodass er ausschließlich in Version 1 unterstützt wird. FlashAttention

Das Modul (nn.Module) ist eine Low-Level-API, die die Aufmerksamkeitsebenen eines Modells definiert. Es sollte direkt nach der Modellerstellung angewendet werden, beispielsweise über die AutoModelForCausalLM.from_config() API, und bevor das Modell transformiert oder mit FSDP umschlossen wird.

Benutze FlashAttention Kernel zur Selbstaufmerksamkeit

Der folgende Codeausschnitt zeigt, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashSelfAttention API verwendet wird.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Verwenden Sie FlashAttention Kernel für die Bearbeitung von Gruppenabfragen

SMP v2 unterstützt auch FlashAttentionKernel für Grouped-Query Attention (GQA) und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Im Unterschied zur ursprünglichen Attention-Architektur partitioniert GQA Abfrageköpfe gleichermaßen in Gruppen, und Abfrageköpfe in derselben Gruppe verwenden dieselben Schlüssel- und Wertüberschriften. Daher werden Q- und KV-Heads getrennt an Forward Call übergeben. Hinweis: Die Anzahl der Q-Köpfe muss durch die Anzahl der kv-Köpfe teilbar sein.

Beispiel für die Verwendung FlashGroupedQueryAttention

Der folgende Codeausschnitt zeigt, wie die von SMP v2 bereitgestellte torch.sagemaker.nn.attn.FlashGroupedQueryAttention API verwendet wird.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

Die SMP-Bibliothek bietet auchtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, die die torch.sagemaker.nn.attn.FlashGroupedQueryAttention API auf niedriger Ebene verwendet. Hugging Face Transformers hat eine ähnliche Implementierung, die LlamaFlashAttention2ab Version 4.36.0 aufgerufen wird. Der folgende Codeausschnitt zeigt, wie die SMP LlamaFlashAttention v2-API oder die Transformers-API verwendet werden, um die Aufmerksamkeitsebenen eines LlamaFlashAttention2 vorhandenen Lama-Modells zu ersetzen.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))