本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
SMP v2 支持FlashAttention
模块 (nn.Module
) 是一种低级 API,用于定义模型的注意层。它应在模型创建后立即应用,例如从 AutoModelForCausalLM.from_config()
API,并在使用 FSDP 对模型进行转换或封装之前应用。
使用 FlashAttention 内核来集中注意力
下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashSelfAttention API。
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
return (
self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
None,
)
for layer in model.gpt_neox.layers:
layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
layer.attention._attn = functools.partial(new_attn, layer.attention)
使用 FlashAttention 内核进行分组查询注意
SMP v2 还支持用于分组查询注意力 (GQA) 的FlashAttention
使用示例 FlashGroupedQueryAttention
下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention
class LlamaFlashAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.flash_attn = FlashGroupedQueryAttention(
attention_dropout_prob=0.0,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
...
):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
...
kv = (key_states, value_states)
attn_output = self.flash_attn(
query_states,
kv,
attn_mask=attention_mask,
causal=True,
layout="b h s d",
)
...
attn_output = self.o_proj(attn_output)
...
return attn_output
SMP 库还提供 torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention,它在低级别使用 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。Hugging Face 转换器在 4.36.0 版中也有一个名为 LlamaFlashAttention2
LlamaFlashAttention
API 或转换器 LlamaFlashAttention2
API 替换现有 Llama 模型的注意层。
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2
attn_name = "self_attn"
for layer in model.model.layers:
prev_layer = getattr(layer, attn_name)
setattr(layer, attn_name, flash_attn_class(model.config))