Support FlashAttention - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

Support FlashAttention

Support FlashAttention 是僅適用於分散式變壓器模型的程式庫功能,分散式變壓器模型是用於 parallel 模型訓練所包裝smp.DistributedModel()的變壓器模型。此功能也相容 張量平行處理

只有當設定attention_head_size為 8 的倍數且小於 128 的值時,程式FlashAttention庫才支援模型。因此,當您訓練分散式變壓器並確保正常 FlashAttention 工作時,應該調整參數以使注意頭尺寸符合要求。如需詳細資訊,另請參閱FlashAttention GitHub儲存庫中的安裝和功能

例如,假設您使用 hidden_width=864num_heads=48 設定轉換器模型。的頭部大小計 FlashAttention 算方式為attention_head_size = hidden_width / num_heads = 864 / 48 = 18。要啟用 FlashAttention,您需要將num_heads參數調整為 54attention_head_size = hidden_width / num_heads = 864 / 54 = 16,以便是 8 的倍數。