FlashAttention - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

FlashAttention

SMP v2 mendukung FlashAttentionkernel dan membuatnya mudah untuk menerapkannya ke berbagai skenario untuk model Hugging Face Transformer. Perhatikan bahwa jika Anda menggunakan FlashAttention paket v2.0 atau yang lebih baru, SMP menggunakan FlashAttention v2; Namun, perhatian flash Triton default ke kernel perhatian flash di FlashAttention v1.x, membuatnya didukung secara eksklusif di v1. FlashAttention

Module (nn.Module) adalah API tingkat rendah yang mendefinisikan lapisan perhatian model. Ini harus diterapkan tepat setelah pembuatan model, dari AutoModelForCausalLM.from_config() API misalnya, dan sebelum model diubah atau dibungkus dengan FSDP.

Gunakan FlashAttention kernel untuk perhatian diri

Cuplikan kode berikut menunjukkan cara menggunakan torch.sagemaker.nn.attn.FlashSelfAttention API yang disediakan oleh SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Gunakan FlashAttention kernel untuk perhatian kueri yang dikelompokkan

SMP v2 juga mendukung FlashAttentionkernel untuk grouped-query attention (GQA) dan membuatnya mudah untuk menerapkannya ke berbagai skenario untuk model Hugging Face Transformer. Berbeda dari arsitektur perhatian asli, GQA sama-sama mempartisi kepala kueri ke dalam grup, dan kepala kueri dalam grup yang sama berbagi kunci dan kepala nilai yang sama. Oleh karena itu, kepala q dan kv diteruskan ke panggilan maju secara terpisah. Catatan: Jumlah kepala q harus habis dibagi dengan jumlah kepala kv.

Contoh penggunaan FlashGroupedQueryAttention

Cuplikan kode berikut menunjukkan cara menggunakan torch.sagemaker.nn.attn.FlashGroupedQueryAttention API yang disediakan oleh SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

Pustaka SMP juga menyediakantorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, yang menggunakan torch.sagemaker.nn.attn.FlashGroupedQueryAttention API pada tingkat rendah. Hugging Face Transformers memiliki LlamaFlashAttention2implementasi serupa yang disebut dari v4.36.0. Cuplikan kode berikut menunjukkan cara menggunakan API SMP v2 atau Transformers LlamaFlashAttention LlamaFlashAttention2 API untuk mengganti lapisan perhatian model Llama yang ada.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))