Checkpointing dan Fine-Tuning Model dengan Paralelisme Model - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Checkpointing dan Fine-Tuning Model dengan Paralelisme Model

Pustaka paralelisme SageMaker model menyediakan API pos pemeriksaan untuk menyimpan status model dan status pengoptimal yang dibagi oleh berbagai strategi paralelisme model, dan untuk memuat pos pemeriksaan untuk pelatihan berkelanjutan dari tempat Anda ingin memulai ulang pelatihan dan menyempurnakan. API juga mendukung opsi untuk menyimpan status model dan pengoptimal sebagian atau seluruhnya.

Checkpointing model terdistribusi

Pilih salah satu topik berikut tergantung pada kerangka kerja antara PyTorch dan TensorFlow dan versi pustaka paralelisme SageMaker model yang Anda gunakan.

Checkpointing PyTorch model terdistribusi (untuk pustaka paralelisme SageMaker model v1.10.0 dan yang lebih baru)

Pustaka paralelisme SageMaker model menyediakan API pos pemeriksaan untuk menyimpan dan memuat pos pemeriksaan penuh atau sebagian dari status model terdistribusi dan status pengoptimalnya.

catatan

Metode checkpointing ini direkomendasikan jika Anda menggunakan PyTorch dan pustaka paralelisme SageMaker model v1.10.0 atau yang lebih baru.

Pos pemeriksaan sebagian

Untuk menyimpan pos pemeriksaan model yang dilatih dengan paralelisme model, gunakan smdistributed.modelparallel.torch.save_checkpointAPI dengan opsi checkpointing sebagian yang disetel ke true (). partial=True Ini menghemat setiap partisi model satu per satu. Selain model dan status pengoptimal, Anda juga dapat menyimpan data kustom tambahan melalui user_content argumen. Model checkpoint, optimizer, dan konten pengguna disimpan sebagai file terpisah. Panggilan save_checkpoint API membuat folder pos pemeriksaan dalam struktur berikut.

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

Untuk melanjutkan pelatihan dari pos pemeriksaan sebagian, gunakan smdistributed.modelparallel.torch.resume_from_checkpointAPI denganpartial=True, dan tentukan direktori pos pemeriksaan dan tag yang digunakan saat menyimpan pos pemeriksaan sebagian. Perhatikan bahwa pemuatan bobot model yang sebenarnya terjadi setelah partisi model, selama menjalankan pertama fungsi langkah pelatihan smdistributed.modelparallel.torch.step yang didekorasi.

Saat menyimpan pos pemeriksaan sebagian, perpustakaan juga menyimpan keputusan partisi model sebagai file dengan ekstensi .pt file. Sebaliknya, ketika melanjutkan dari pos pemeriksaan sebagian, perpustakaan memuat file keputusan partisi bersama-sama. Setelah keputusan partisi dimuat, Anda tidak dapat mengubah partisi.

Cuplikan kode berikut menunjukkan cara mengatur API pos pemeriksaan dalam skrip pelatihan. PyTorch

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

Checkpointing penuh

Untuk menyimpan artefak model akhir untuk tujuan inferensi, gunakan smdistributed.modelparallel.torch.save_checkpoint API denganpartial=False, yang menggabungkan partisi model untuk membuat artefak model tunggal. Perhatikan bahwa ini tidak menggabungkan status pengoptimal.

Untuk menginisialisasi pelatihan dengan bobot tertentu, dengan pos pemeriksaan model lengkap, Anda dapat menggunakan API dengansmdistributed.modelparallel.torch.resume_from_checkpoint. partial=False Perhatikan bahwa ini tidak memuat status pengoptimal.

catatan

Dengan paralelisme tensor, secara umum, state_dict harus diterjemahkan antara implementasi model asli dan implementasi. DistributedModel Secara opsional, Anda dapat memberikan fungsi state_dict terjemahan sebagai argumen untuk. smdistributed.modelparallel.torch.resume_from_checkpoint Namun, untukModel yang Didukung Di Luar Kotak, perpustakaan menangani terjemahan ini secara otomatis.

Kode berikut menunjukkan contoh cara menggunakan API pos pemeriksaan untuk memeriksa sepenuhnya model yang dilatih dengan PyTorch paralelisme model.

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

Checkpointing PyTorch model terdistribusi (untuk pustaka paralelisme SageMaker model antara v1.6.0 dan v1.9.0)

Pustaka paralelisme SageMaker model menyediakan fungsi Python untuk menyimpan pos pemeriksaan sebagian atau penuh untuk pekerjaan pelatihan dengan paralelisme tensor. Prosedur berikut menunjukkan cara menggunakan smp.save()dan menyimpan dan smp.load()memuat pos pemeriksaan saat Anda menggunakan paralelisme tensor.

catatan

Metode checkpointing ini direkomendasikan jika Anda menggunakan PyTorch,Paralelisme Tensor, dan pustaka paralelisme SageMaker model antara v1.6.0 dan v1.9.0.

  1. Siapkan objek model dan bungkus dengan fungsi smp.DistributedModel() pembungkus perpustakaan.

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Siapkan pengoptimal untuk model. Satu set parameter model adalah argumen iterable yang diperlukan oleh fungsi pengoptimal. Untuk menyiapkan satu set parameter model, Anda harus memproses model.parameters() untuk menetapkan ID unik ke parameter model individual.

    Jika ada parameter dengan ID duplikat dalam parameter model yang dapat diulang, memuat status pengoptimal checkpoint gagal. Untuk membuat parameter model yang dapat diulang dengan ID unik untuk pengoptimal Anda, lihat yang berikut ini:

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Bungkus pengoptimal menggunakan fungsi pembungkus perpustakaan. smp.DistributedOptimizer()

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Simpan model dan status pengoptimal menggunakan smp.save(). Bergantung pada bagaimana Anda ingin menyimpan pos pemeriksaan, pilih salah satu dari dua opsi berikut:

    • Opsi 1: Simpan sebagian model pada masing-masing mp_rank untuk satuMP_GROUP.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Dengan paralelisme tensor, perpustakaan menyimpan file yang ditunjuk periksa yang dinamai dalam format berikut:. checkpoint.pt_{pp_rank}_{tp_rank}

      catatan

      Dengan paralelisme tensor, pastikan Anda mengatur pernyataan if sebagai if smp.rdp_rank() == 0 pengganti. if smp.dp_rank() == 0 Saat status pengoptimal dibagi dengan paralelisme tensor, semua peringkat paralel data tereduksi harus menyimpan partisi mereka sendiri dari status pengoptimal. Menggunakan pernyataan if yang salah untuk checkpointing dapat mengakibatkan pekerjaan pelatihan yang terhenti. Untuk informasi selengkapnya tentang penggunaan paralelisme if smp.dp_rank() == 0 tanpa tensor, lihat Instruksi Umum untuk Menyimpan dan Memuat dalam dokumentasi Python SageMaker SDK.

    • Opsi 2: Simpan model lengkap.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      catatan

      Pertimbangkan hal berikut untuk pemeriksaan lengkap:

      • Jika Anda mengaturgather_to_rank0=True, semua peringkat selain 0 mengembalikan kamus kosong.

      • Untuk pos pemeriksaan penuh, Anda hanya dapat memeriksa model. Pemeriksaan penuh status pengoptimal saat ini tidak didukung.

      • Model lengkap hanya perlu disimpan dismp.rank() == 0.

  5. Muat pos pemeriksaan menggunakan smp.load(). Bergantung pada bagaimana Anda memeriksa pada langkah sebelumnya, pilih salah satu dari dua opsi berikut:

    • Opsi 1: Muat pos pemeriksaan sebagian.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Anda dapat mengatur same_partition_load=True model.load_state_dict() untuk beban yang lebih cepat, jika Anda tahu bahwa partisi tidak akan berubah.

    • Opsi 2: Muat pos pemeriksaan penuh.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      if smp.rdp_rank() == 0Kondisi ini tidak diperlukan, tetapi dapat membantu menghindari pemuatan berlebihan di antara s yang berbedaMP_GROUP. Dict status pengoptimal pos pemeriksaan penuh saat ini tidak didukung dengan paralelisme tensor.

Checkpointing model terdistribusi TensorFlow

Untuk menyimpan TensorFlow model saat berlatih dengan paralelisme model, gunakan fungsi berikut yang disediakan oleh perpustakaan paralelisme SageMaker model.

Menyetel model terdistribusi

Penyetelan halus perlu dikonfigurasi dalam skrip pelatihan Anda. Cuplikan kode berikut menunjukkan contoh struktur skrip pelatihan menggunakan kelas AutoModelForCausalLM dari Hugging Face Transformers dengan modifikasi untuk mendaftarkan smdistributed.model.parallel.torch modul dan pengaturan untuk fine-tuning.

catatan

Menyetel transformator terdistribusi (model Transformer yang dibungkus olehsmp.DistributedModel()) dengan fungsi smp.delayed_param_initialization diaktifkan memerlukan pekerjaan fine-tuning untuk dikonfigurasi dengan sistem file FSx for Lustre. Dalam kasus di mana Anda ingin menyempurnakan model skala besar dengan opsi inisialisasi parameter tertunda, Anda harus menyiapkan sistem file FSx for Lustre.

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

Untuk contoh lengkap skrip pelatihan dan buku catatan Jupyter, lihat contoh GPT-2 di repositori Contoh. PyTorch SageMaker GitHub