FlashAttention - Amazon SageMaker

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

FlashAttention

SMP v2 supporta i FlashAttentionkernel e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. Nota che se utilizzi il FlashAttention pacchetto v2.0 o successivo, SMP utilizza la FlashAttention v2; tuttavia, Triton flash attention utilizza per impostazione predefinita il kernel flash attention nella v1.x, rendendolo supportato esclusivamente nella v1. FlashAttention FlashAttention

Il module (nn.Module) è un'API di basso livello che definisce i livelli di attenzione di un modello. Dovrebbe essere applicato subito dopo la creazione del modello, ad esempio dall'AutoModelForCausalLM.from_config()API, e prima che il modello venga trasformato o confezionato con FSDP.

Usa i FlashAttention kernel per l'attenzione personale

Il seguente frammento di codice mostra come utilizzare l'torch.sagemaker.nn.attn.FlashSelfAttentionAPI fornita da SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Usa i kernel per attirare l'attenzione sulle query raggruppate FlashAttention

SMP v2 supporta anche i FlashAttentionkernel for grouped-query attention (GQA) e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. A differenza dell'architettura di attenzione originale, GQA suddivide equamente le testine di interrogazione in gruppi e le testine di query dello stesso gruppo condividono le stesse chiavi e valori. Pertanto, le testine q e kv vengono passate separatamente alla chiamata in avanti. Nota: il numero di teste q deve essere divisibile per il numero di teste kv.

Esempio di utilizzo FlashGroupedQueryAttention

Il seguente frammento di codice mostra come utilizzare l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI fornita da SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

Fornisce anche la libreria SMPtorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, che utilizza l'torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI a basso livello. Hugging Face Transformers ha un'implementazione simile chiamata dalla v4.36.0. LlamaFlashAttention2 Il seguente frammento di codice mostra come utilizzare l'API SMP v2 o l'LlamaFlashAttentionAPI LlamaFlashAttention2 Transformers per sostituire i livelli di attenzione di un modello Llama esistente.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))