Simpan dan muat pos pemeriksaan saat menggunakan SMP - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Simpan dan muat pos pemeriksaan saat menggunakan SMP

Pustaka SMP mendukung PyTorch API untuk pos pemeriksaan, dan menyediakan API yang membantu pos pemeriksaan dengan benar saat menggunakan pustaka SMP.

PyTorch FSDP mendukung tiga jenis pos pemeriksaan: penuh, sharded dan lokal. Ini melayani tujuan yang berbeda. Pos pemeriksaan penuh idealnya hanya digunakan saat mengekspor model setelah pelatihan selesai, karena mahal untuk menghasilkan pos pemeriksaan penuh. Pos pemeriksaan sharded adalah pendekatan yang direkomendasikan untuk menyimpan dan memuat pos pemeriksaan selama pelatihan. Menggunakan pos pemeriksaan sharded, Anda juga dapat mengubah ukuran cluster saat melanjutkan pelatihan. Pos pemeriksaan lokal lebih ketat. Dengan pos pemeriksaan lokal, Anda perlu melanjutkan pelatihan dengan jumlah GPU yang sama dan saat ini tidak didukung saat menggunakan paralelisme tensor dengan SMP. Perhatikan bahwa pos pemeriksaan oleh FSDP memerlukan penulisan ke sistem file jaringan bersama, seperti FSx.

Pos pemeriksaan sharded

Prosedur berikut menyoroti apa yang perlu Anda lakukan untuk menyesuaikan skrip pelatihan Anda untuk menyimpan dan memuat pos pemeriksaan sharded dengan atau tanpa fitur paralelisme tensor SMP.

  1. Impor torch.sagemaker paket SMP.

    import torch.sagemaker as tsm
  2. Siapkan variabel tambahan untuk menyimpan dan memuat pos pemeriksaan.

    1. Siapkan peringkat koordinator untuk melakukan operasi kolektif komunikatif seperti. AllReduce

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. Menggunakan torch.sagemaker.state enumerasi, atur peringkat tindakan untuk menentukan apakah akan membiarkan peringkat mengambil bagian dalam pemeriksaan. Dan tambahkan pernyataan if untuk menyimpan pos pemeriksaan tergantung pada penggunaan paralelisme tensor SMP v2.

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = ""
  3. Tetap gunakan API pos pemeriksaan PyTorch FSDP apa adanya.

Contoh kode berikut menunjukkan skrip pelatihan PyTorch FSDP lengkap dengan API pos pemeriksaan FSDP.

import torch.distributed as dist from torch.distributed.checkpoint.optimizer import ( load_sharded_optimizer_state_dict ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, StateDictType ) import torch.sagemaker as tsm sharding_strategy, state_dict_type = ..., ... global_rank = dist.get_rank() # 0. Auxiliary variables to save and load checkpoints. # Used when performing comm collectives such as allreduce. coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group)) # To determine whether to take part in checkpointing. action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = "" # 1. Save checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # Save from one single replication group. if action_rank: dist.checkpoint.save_state_dict( state_dict=state_dict, storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) # 2. Load checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 2.1 Load model and everything else except the optimizer. state_dict = { # All states except optimizer state can be passed here. "model": model.state_dict() } dist.checkpoint.load_state_dict( state_dict=state_dict, storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # Potentially process more customized and non-optimizer dict states. # 2.2 Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group, ) optimizer.load_state_dict(flattened_optimizer_state)

Pos pemeriksaan model lengkap

Di akhir pelatihan, Anda dapat menyimpan pos pemeriksaan lengkap yang menggabungkan semua pecahan model menjadi satu file pos pemeriksaan model. Pustaka SMP sepenuhnya mendukung API pos pemeriksaan model PyTorch lengkap, jadi Anda tidak perlu melakukan perubahan apa pun.

Perhatikan bahwa jika Anda menggunakan SMPParalelisme tensor, perpustakaan SMP mengubah model. Saat memeriksa model lengkap dalam kasus ini, pustaka SMP menerjemahkan model kembali ke format pos pemeriksaan Hugging Face Transformers secara default.

Jika Anda berlatih dengan paralelisme tensor SMP dan mematikan proses penerjemahan SMP, Anda dapat menggunakan translate_on_save argumen PyTorch FullStateDictConfig API untuk mengaktifkan atau menonaktifkan terjemahan otomatis SMP sesuai kebutuhan. Misalnya, jika Anda berfokus pada pelatihan model, Anda tidak perlu menambahkan proses terjemahan yang menambahkan overhead. Dalam hal ini, kami sarankan Anda untuk mengaturtranslate_on_save=False. Juga, jika Anda berencana untuk tetap menggunakan terjemahan SMP model untuk pelatihan lebih lanjut di masa depan, Anda dapat mematikannya untuk menyimpan terjemahan SMP model untuk digunakan nanti. Menerjemahkan model kembali ke format pos pemeriksaan model Hugging Face Transformers diperlukan saat Anda menyelesaikan pelatihan model Anda dan menggunakannya untuk inferensi.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Perhatikan bahwa pilihannya FullStateDictConfig(rank0_only=True, offload_to_cpu=True) adalah mengumpulkan model pada CPU perangkat peringkat 0 untuk menghemat memori saat melatih model besar.

Untuk memuat kembali model untuk inferensi, Anda melakukannya seperti yang ditunjukkan pada contoh kode berikut. Perhatikan bahwa kelas AutoModelForCausalLM mungkin berubah ke kelas pembuat faktor lain di Hugging Face Transformers, AutoModelForSeq2SeqLM seperti, tergantung pada model Anda. Untuk informasi selengkapnya, lihat dokumentasi Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)