FlashAttention - 아마존 SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

FlashAttention

SMP v2는 FlashAttention커널을 지원하므로 Hugging Face Transformer 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. 참고로 FlashAttention 패키지 v2.0 이상을 사용하는 경우 SMP는 FlashAttention v2를 사용하지만, 트리톤 플래시 어텐션은 FlashAttention v1.x의 플래시 어텐션 커널을 기본으로 사용하므로 v1에서만 독점적으로 지원됩니다. FlashAttention

모듈 (nn.Module) 은 모델의 어텐션 레이어를 정의하는 저수준 API입니다. 예를 들어 AutoModelForCausalLM.from_config() API에서 모델을 생성한 직후, 그리고 FSDP로 모델을 변환하거나 래핑하기 전에 적용해야 합니다.

FlashAttention 커널은 자기 주의를 끌기 위해 사용하세요.

다음 코드 스니펫은 SMP v2에서 제공하는 torch.sagemaker.nn.attn.FlashSelfAttention API를 사용하는 방법을 보여줍니다.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

그룹화된 쿼리 주의를 위해 FlashAttention 커널을 사용하세요.

또한 SMP v2는 GQA (그룹 쿼리 어텐션) 용 FlashAttention커널을 지원하므로 Hugging Face Transformer 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. 오리지널 어텐션 아키텍처와 달리 GQA는 쿼리 헤드를 그룹으로 균등하게 분할하고, 같은 그룹의 쿼리 헤드는 동일한 키 및 값 헤드를 공유합니다. 따라서 q 헤드와 kv 헤드는 별도로 전달 호출로 전달됩니다. 참고: q 헤드의 수는 kv 헤드의 수로 나눌 수 있어야 합니다.

사용 예시 FlashGroupedQueryAttention

다음 코드 스니펫은 SMP v2에서 제공하는 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API를 사용하는 방법을 보여줍니다.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

SMP 라이브러리는 낮은 수준에서 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API를 사용하는 기능도 제공합니다torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention. Hugging Face Transformer에는 v4.36.0에서 LlamaFlashAttention2호출된 유사한 구현이 있습니다. 다음 코드 스니펫은 SMP v2 LlamaFlashAttention API 또는 트랜스포머 LlamaFlashAttention2 API를 사용하여 기존 라마 모델의 어텐션 레이어를 대체하는 방법을 보여줍니다.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))