选择您的 Cookie 首选项

我们使用必要 Cookie 和类似工具提供我们的网站和服务。我们使用性能 Cookie 收集匿名统计数据,以便我们可以了解客户如何使用我们的网站并进行改进。必要 Cookie 无法停用,但您可以单击“自定义”或“拒绝”来拒绝性能 Cookie。

如果您同意,AWS 和经批准的第三方还将使用 Cookie 提供有用的网站功能、记住您的首选项并显示相关内容,包括相关广告。要接受或拒绝所有非必要 Cookie,请单击“接受”或“拒绝”。要做出更详细的选择,请单击“自定义”。

FlashAttention

聚焦模式
FlashAttention - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

SMP v2 支持FlashAttention内核,可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。请注意,如果您使用 v2.0 或更高版本的 FlashAttention 软件包,SMP 使用 FlashAttention v2;但是,在 v FlashAttention 1.x 中,Triton 闪光注意力默认为闪光注意内核,因此在 v1 中仅支持该内核。 FlashAttention

模块 (nn.Module) 是一种低级 API,用于定义模型的注意层。它应在模型创建后立即应用,例如从 AutoModelForCausalLM.from_config() API,并在使用 FSDP 对模型进行转换或封装之前应用。

使用 FlashAttention 内核来集中注意力

下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashSelfAttention API。

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

使用 FlashAttention 内核进行分组查询注意

SMP v2 还支持用于分组查询注意力 (GQA) 的FlashAttention内核,并且可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。与最初的注意架构不同,GQA 将查询磁头平均分为若干组,同一组中的查询磁头共享相同的键和值磁头。因此,q 和 kv 磁头被分别传入前向调用。注意:q 磁头的数量需要可以被 kv 磁头的数量整除。

使用示例 FlashGroupedQueryAttention

下面的代码片段显示了如何使用 SMP v2 提供的 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

SMP 库还提供 torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention,它在低级别使用 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API。Hugging Face 转换器在 4.36.0 版中也有一个名为 LlamaFlashAttention2 的类似实现。下面的代码片段显示了如何使用 SMP v2 LlamaFlashAttention API 或转换器 LlamaFlashAttention2 API 替换现有 Llama 模型的注意层。

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))
隐私网站条款Cookie 首选项
© 2025, Amazon Web Services, Inc. 或其附属公司。保留所有权利。