アクティベーションチェックポイント - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

アクティベーションチェックポイント

アクティベーションチェックポイントは、特定のレイヤーのアクティベーションをクリアし、バックワードパス中に再計算することで、メモリ使用量を削減する手法です。これにより、メモリ使用量を削減するために、実質的に追加の計算時間が交換されます。モジュールがチェックポイントされている場合、フォワードパスの最後に、モジュールへの初期入力とモジュールからの最終出力のみがメモリに残ります。 はフォワードパス中に、そのモジュール内の計算の一部である中間テンソルを PyTorch 解放します。チェックポイントされたモジュールのバックワードパス中に、 はこれらのテンソル PyTorch を再計算します。この時点で、このチェックポイントされたモジュールを超えるレイヤーはバックワードパスを完了したため、チェックポイントによるピークメモリ使用量が低くなります。

SMP v2 はアクティベーション PyTorch チェックポイントモジュール をサポートしますapply_activation_checkpointing。Hugging Face GPT-NeoX モデルのアクティベーションチェックポイントの例を次に示します。

Hugging Face GPT-NeoX モデルのトランスフォーマーレイヤーのチェックポイント

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Hugging Face GPT-NeoX モデルの他のすべての Transformer レイヤーのチェックポイント

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

または、 にはチェックポイント用の torch.utils.checkpoint モジュール PyTorch もあります。これは Hugging Face Transformers モデルのサブセットで使用されます。このモジュールは SMP v2 でも動作します。ただし、チェックポイントラッパーを追加するためのモデル定義にアクセスする必要があります。したがって、 apply_activation_checkpointingメソッドを使用することをお勧めします。