FlashAttention - Amazon SageMaker

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

FlashAttention

SMP v2 admite FlashAttentionnúcleos y facilita su aplicación a varios escenarios para los modelos Hugging Face Transformer. Tenga en cuenta que si usa el FlashAttention paquete v2.0 o posterior, SMP usa la FlashAttention v2; sin embargo, el núcleo Flash Attention de Triton utiliza de forma predeterminada el núcleo Flash Attention en la versión FlashAttention 1.x, por lo que es compatible exclusivamente con la versión 1. FlashAttention

El módulo (nn.Module) es una API de bajo nivel que define las capas de atención de un modelo. Debe aplicarse inmediatamente después de la creación del modelo, desde la AutoModelForCausalLM.from_config() API, por ejemplo, y antes de transformar o empaquetar el modelo con el FSDP.

Usa los FlashAttention núcleos para centrarte en ti mismo

El siguiente fragmento de código muestra cómo utilizar la torch.sagemaker.nn.attn.FlashSelfAttention API proporcionada por SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Usa los FlashAttention núcleos para la atención de las consultas agrupadas

SMP v2 también admite FlashAttentionnúcleos para la atención por consultas agrupadas (GQA) y facilita su aplicación a varios escenarios para los modelos de Hugging Face Transformer. A diferencia de la arquitectura de atención original, GQA divide en partes iguales los encabezados de consulta en grupos, y los encabezados de consulta del mismo grupo comparten los mismos encabezados clave y de valor. Por lo tanto, las cabeceras q y kv se pasan a la llamada directa por separado. Nota: El número de cabezas q debe ser divisible por el número de cabezas kv.

Ejemplo de uso FlashGroupedQueryAttention

El siguiente fragmento de código muestra cómo utilizar la torch.sagemaker.nn.attn.FlashGroupedQueryAttention API proporcionada por SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

La biblioteca SMP también proporcionatorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, que utiliza la torch.sagemaker.nn.attn.FlashGroupedQueryAttention API en un nivel bajo. Hugging Face Transformers tiene una implementación similar LlamaFlashAttention2llamada desde la v4.36.0. El siguiente fragmento de código muestra cómo usar la API SMP v2 o la LlamaFlashAttention API Transformers para reemplazar las capas de atención de un modelo de LlamaFlashAttention2 Llama existente.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))