기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.
FlashAttention
SMP v2는 FlashAttention
모듈 (nn.Module
) 은 모델의 어텐션 레이어를 정의하는 저수준 API입니다. 예를 들어 AutoModelForCausalLM.from_config()
API에서 모델을 생성한 직후, 그리고 FSDP로 모델을 변환하거나 래핑하기 전에 적용해야 합니다.
FlashAttention 커널은 자기 주의를 끌기 위해 사용하세요.
다음 코드 스니펫은 SMP v2에서 제공하는 torch.sagemaker.nn.attn.FlashSelfAttention API를 사용하는 방법을 보여줍니다.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
그룹화된 쿼리 주의를 위해 FlashAttention 커널을 사용하세요.
또한 SMP v2는 GQA (그룹 쿼리 어텐션) 용 FlashAttention
사용 예시 FlashGroupedQueryAttention
다음 코드 스니펫은 SMP v2에서 제공하는 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API를 사용하는 방법을 보여줍니다.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
SMP 라이브러리는 낮은 수준에서 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API를 사용하는 기능도 제공합니다torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention. Hugging Face Transformer에는 v4.36.0에서 LlamaFlashAttention2
LlamaFlashAttention
API 또는 트랜스포머 LlamaFlashAttention2
API를 사용하여 기존 라마 모델의 어텐션 레이어를 대체하는 방법을 보여줍니다.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))