SMP の使用中にチェックポイントを保存してロードする - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

SMP の使用中にチェックポイントを保存してロードする

SMP ライブラリはチェックポイントの PyTorch APIs をサポートし、SMP ライブラリの使用中にチェックポイントを適切に行うのに役立つ APIs を提供します。

PyTorch FSDP は、フルチェックポイント、シャードチェックポイント、ローカルチェックポイントの 3 つのタイプをサポートしています。これらはさまざまな目的を果たします。フルチェックポイントは、トレーニング終了後にモデルをエクスポートする場合にのみ使用するのが理想的です。フルチェックポイントを生成するとコストがかかるためです。シャーディングチェックポイントは、トレーニング中にチェックポイントを保存およびロードするための推奨アプローチです。シャードチェックポイントを使用すると、トレーニングを再開するときにクラスターサイズを変更することもできます。ローカルチェックポイントはより制限されています。ローカルチェックポイントでは、同じ数の GPUs、SMP でテンソル並列処理を使用する場合、現在サポートされていません。FSDP によるチェックポイントは、FSx などの共有ネットワークファイルシステムに書き込む必要があることに注意してください。

シャーディングチェックポイント

次の手順では、SMP テンソル並列処理機能の有無にかかわらず、シャーディングチェックポイントを保存およびロードするためにトレーニングスクリプトを適応させるために必要なことに焦点を当てます。

  1. SMP torch.sagemakerパッケージをインポートします。

    import torch.sagemaker as tsm
  2. チェックポイントを保存およびロードするための補助変数を設定します。

    1. などの通信的な集合演算を実行するためのコーディネーターランクを設定しますAllReduce

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. torch.sagemaker.state 列挙を使用してアクションランクを設定し、ランクをチェックポイントに参加させるかどうかを決定します。また、SMP v2 テンソル並列処理の使用に応じて、チェックポイントを保存するための if ステートメントを追加します。

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = ""
  3. PyTorch FSDP チェックポイント APIsをそのまま使用します。

次のコード例は、 PyTorch FSDP チェックポイント APIs を使用した完全な FSDP トレーニングスクリプトを示しています。

import torch.distributed as dist from torch.distributed.checkpoint.optimizer import ( load_sharded_optimizer_state_dict ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, StateDictType ) import torch.sagemaker as tsm sharding_strategy, state_dict_type = ..., ... global_rank = dist.get_rank() # 0. Auxiliary variables to save and load checkpoints. # Used when performing comm collectives such as allreduce. coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group)) # To determine whether to take part in checkpointing. action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = "" # 1. Save checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # Save from one single replication group. if action_rank: dist.checkpoint.save_state_dict( state_dict=state_dict, storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) # 2. Load checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 2.1 Load model and everything else except the optimizer. state_dict = { # All states except optimizer state can be passed here. "model": model.state_dict() } dist.checkpoint.load_state_dict( state_dict=state_dict, storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # Potentially process more customized and non-optimizer dict states. # 2.2 Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group, ) optimizer.load_state_dict(flattened_optimizer_state)

完全なモデルチェックポイント

トレーニングの最後に、モデルのすべてのシャードを 1 つのモデルチェックポイントファイルに結合する完全なチェックポイントを保存できます。SMP ライブラリは PyTorch 完全なモデルチェックポイント API を完全にサポートしているため、変更を加える必要はありません。

SMP を使用する場合テンソル並列性、SMP ライブラリはモデルを変換することに注意してください。この場合、モデル全体をチェックポイントすると、SMP ライブラリはモデルをデフォルトで Hugging Face Transformers チェックポイント形式に変換します。

SMP テンソル並列処理を使用してトレーニングし、SMP 翻訳プロセスをオフにする場合は、 PyTorch FullStateDictConfig API の translate_on_save引数を使用して、必要に応じて SMP 自動翻訳のオンとオフを切り替えることができます。例えば、モデルのトレーニングに集中している場合、オーバーヘッドを増やす翻訳プロセスを追加する必要はありません。その場合は、 を設定することをお勧めしますtranslate_on_save=False。また、今後さらにトレーニングするためにモデルの SMP 変換を引き続き使用する予定の場合は、オフにして、後で使用するためにモデルの SMP 変換を保存できます。モデルのトレーニングをまとめ、それを推論に使用する場合は、モデルを Hugging Face Transformers モデルチェックポイント形式に戻す必要があります。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

オプションFullStateDictConfig(rank0_only=True, offload_to_cpu=True)は、0 ランクデバイスの CPU でモデルを収集して、大規模なモデルをトレーニングするときにメモリを節約することです。

推論のためにモデルをロードし直すには、次のコード例に示すようにロードします。クラスは、モデルによっては、 などAutoModelForSeq2SeqLM、Hugging Face Transformer の他の要素ビルダークラスに変更AutoModelForCausalLMされる可能性があることに注意してください。詳細については、「Hugging Face Transformers ドキュメント」を参照してください。

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)