FlashAttention
SMP v2 supports FlashAttention kernels and makes it easy to apply them to Hugging Face Transformer models.
The module (nn.Module) is a low-level API that defines the attention layers of a model. It should be applied right after model creation, for example from the AutoModelForCausalLM.from_config() API, and before the model is transformed or wrapped with FSDP.
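To make the ordering concrete, the following sketch (an illustrative assumption, not taken from the SMP documentation) creates a GPT-NeoX model from a config, applies the FlashAttention replacement, and only then transforms the model with SMP v2 and wraps it with FSDP. The model ID is only an example, and apply_flash_self_attention() is a hypothetical helper standing in for the replacement code shown in the next section.

import torch.sagemaker as tsm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

# 1. Create the model from a config (no pretrained weights are loaded here).
config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")  # example model ID
model = AutoModelForCausalLM.from_config(config)

# 2. Apply the FlashAttention module replacement (hypothetical helper; the
#    actual replacement logic is shown in the next section).
apply_flash_self_attention(model)

# 3. Only afterwards transform the model with SMP v2 and wrap it with FSDP
#    (assumes the distributed process group is already initialized).
model = tsm.transform(model)
model = FSDP(model)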
Use FlashAttention kernels for self-attention
The following code snippet shows how to use the torch.sagemaker.nn.attn.FlashSelfAttention API provided by SMP v2.
import functools

import torch
import torch.sagemaker  # the torch.sagemaker namespace is installed by SMP v2

# Replacement for the GPT-NeoX attention computation that delegates to the SMP
# FlashSelfAttention module attached below. GPT-NeoX's _attn returns a
# (attn_output, attn_weights) tuple, so None is returned for the weights.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flash_mod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
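The following is a minimal smoke test (an illustrative sketch, not from the SMP documentation) that runs one forward pass through the patched model to confirm the replaced attention path executes. FlashAttention kernels require a CUDA device.

import torch

# One forward pass through the patched GPT-NeoX model.
model = model.to("cuda")
input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
with torch.no_grad():
    output = model(input_ids)
print(output.logits.shape)  # expected: torch.Size([1, 128, vocab_size])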
Use FlashAttention kernels for grouped-query attention
SMP v2 also supports applying FlashAttention kernels to grouped-query attention (GQA).
Example of using FlashGroupedQueryAttention
The following code snippet shows how to use the torch.sagemaker.nn.attn.FlashGroupedQueryAttention API provided by SMP v2.
from typing import Optional

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # SMP module that runs grouped-query attention with FlashAttention kernels.
        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
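A minimal usage sketch (assumed, not part of the SMP documentation): once the class above is defined, you can swap it into an existing Hugging Face Llama model by replacing each decoder layer's self_attn module, in the same way as the snippet at the end of this section. The small config is illustrative and uses fewer key-value heads than attention heads to exercise grouped-query attention.

from transformers import AutoModelForCausalLM, LlamaConfig

# Small illustrative config with grouped-query attention
# (num_key_value_heads < num_attention_heads).
llama_config = LlamaConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,
)
model = AutoModelForCausalLM.from_config(llama_config)

# Replace the stock attention module in every decoder layer with the custom
# FlashAttention-backed class defined above.
for layer in model.model.layers:
    layer.self_attn = LlamaFlashAttention(model.config)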
The SMP library also provides torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, which uses the torch.sagemaker.nn.attn.FlashGroupedQueryAttention API at a low level. Hugging Face Transformers has a similar implementation called LlamaFlashAttention2. The following code snippet shows how to use the SMP v2 LlamaFlashAttention API or the Transformers LlamaFlashAttention2 API to replace the attention layers of an existing Llama model.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention  # or flash_attn_class = LlamaFlashAttention2
attn_name = "self_attn"

# Replace the self-attention module of every decoder layer in the Llama model.
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
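As a quick sanity check (an optional sketch, not from the SMP documentation), you can verify the replacement before transforming the model with SMP or wrapping it with FSDP.

# Confirm every decoder layer now uses the FlashAttention-backed class.
for layer in model.model.layers:
    assert isinstance(layer.self_attn, flash_attn_class)
print(type(model.model.layers[0].self_attn).__name__)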