FlashAttention
SMP v2 supports FlashAttention. The torch.sagemaker.nn.attn.FlashSelfAttention module (nn.Module) is a low-level API that defines a model's attention layer. It should be applied right after model creation, from the AutoModelForCausalLM.from_config() API for example, and before the model is transformed or wrapped with FSDP.
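To make the recommended ordering concrete, here is a minimal sketch, assuming a GPT-NeoX config as the example model and omitting the distributed setup details: the model is built from a config, its attention layers are patched (see the snippets in the following sections), and only then is it passed to torch.sagemaker's transform and wrapped with FSDP.

# Minimal ordering sketch; the model choice and setup are illustrative assumptions.
import torch.sagemaker as tsm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

tsm.init()  # initialize SMP v2 (configuration omitted for brevity)

# 1) Create the model from a config.
config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed example model
model = AutoModelForCausalLM.from_config(config)

# 2) Apply the FlashAttention module to the attention layers here,
#    before any transformation or wrapping (see the snippets below).

# 3) Only afterwards transform the model and wrap it with FSDP.
model = tsm.transform(model)
model = FSDP(model)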
Using FlashAttention kernels for attention
The following code snippet shows how to use the torch.sagemaker.nn.attn.FlashSelfAttention API provided by SMP v2.
import functools

import torch
import torch.sagemaker.nn.attn

# Replace the attention computation of a GPT-NeoX model with the SMP FlashSelfAttention kernel.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flash_mod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
Using FlashAttention kernels for grouped-query attention
SMP v2 also supports FlashAttention kernels for grouped-query attention (GQA).
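In GQA, the query heads are split into groups that share a smaller set of key and value heads, so the key/value tensors passed to the kernel can have fewer heads than the query tensor. The following is a rough sketch of calling the kernel directly with such shapes; the tensor sizes and the standalone call pattern are illustrative assumptions based on the arguments used in the integration example below.

import torch
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

# Illustrative shapes (assumed): 8 query heads share 2 key/value heads, layout "b h s d".
batch, seqlen, head_dim = 2, 1024, 128
num_q_heads, num_kv_heads = 8, 2

q = torch.randn(batch, num_q_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, num_kv_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, num_kv_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda")

attn = FlashGroupedQueryAttention(attention_dropout_prob=0.0)
out = attn(q, (k, v), causal=True, layout="b h s d")  # same call pattern as in the example below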
Example of using FlashGroupedQueryAttention
The following code snippet shows how to use the torch.sagemaker.nn.attn.FlashGroupedQueryAttention API provided by SMP v2.
from typing import Optional

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
The SMP library also provides torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, which uses the torch.sagemaker.nn.attn.FlashGroupedQueryAttention API at a low level. Hugging Face Transformers has a similar implementation, LlamaFlashAttention2. The following code snippet shows how to use the SMP v2 LlamaFlashAttention API or the Transformers LlamaFlashAttention2 API to replace the attention layers of an existing Llama model.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention  # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
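Note that the loop above installs a freshly initialized attention module built from model.config, while prev_layer still holds the original module. If the model already carries trained weights, one way to keep them, shown here as a hedged sketch rather than part of the documented snippet, is to copy the original module's state dict into the replacement before installing it:

# Hedged sketch: carry over the original projection weights into the replacement module.
# strict=False tolerates submodules of the new class that have no counterpart in the
# original state dict (such as a parameter-free flash kernel wrapper).
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    new_layer = flash_attn_class(model.config)
    new_layer.load_state_dict(prev_layer.state_dict(), strict=False)
    setattr(layer, attn_name, new_layer)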