Pos pemeriksaan aktivasi - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Pos pemeriksaan aktivasi

Checkpointing aktivasi adalah teknik untuk mengurangi penggunaan memori dengan membersihkan aktivasi lapisan tertentu dan mengkomputernya kembali selama backward pass. Secara efektif, ini memperdagangkan waktu komputasi ekstra untuk mengurangi penggunaan memori. Jika modul diperiksa, di akhir pass maju, hanya input awal ke modul dan output akhir dari modul yang tetap berada di memori. PyTorch melepaskan tensor perantara apa pun yang merupakan bagian dari perhitungan di dalam modul itu selama pass maju. Selama lintasan mundur modul checkpoint, PyTorch hitung ulang tensor ini. Pada titik ini, lapisan di luar modul checkpointed ini telah menyelesaikan backward pass mereka, sehingga penggunaan memori puncak dengan checkpointing menjadi lebih rendah.

SMP v2 mendukung modul checkpointing PyTorch aktivasi,. apply_activation_checkpointing Berikut ini adalah contoh checkpointing aktivasi model Hugging Face GPT-Neox.

Lapisan Checkpointing Transformer dari model Hugging Face GPT-Neox

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Checkpointing setiap lapisan Transformer lainnya dari model Hugging Face GPT-Neox

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

Atau, PyTorch juga memiliki torch.utils.checkpoint modul untuk checkpointing, yang digunakan oleh subset model Hugging Face Transformers. Modul ini juga bekerja dengan SMP v2. Namun, ini mengharuskan Anda untuk memiliki akses ke definisi model untuk menambahkan pembungkus pos pemeriksaan. Karena itu, kami sarankan Anda untuk menggunakan apply_activation_checkpointing metode ini.