Aktivasi Checkpointing - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Aktivasi Checkpointing

Activation checkpointing (atau gradient checkpointing) adalah teknik untuk mengurangi penggunaan memori dengan membersihkan aktivasi lapisan tertentu dan mengkomputernya kembali selama backward pass. Secara efektif, ini memperdagangkan waktu komputasi ekstra untuk mengurangi penggunaan memori. Jika modul diperiksa, di akhir pass maju, input ke dan output dari modul tetap berada di memori. Tensor perantara apa pun yang akan menjadi bagian dari perhitungan di dalam modul itu dibebaskan selama pass maju. Selama lintasan mundur modul checkpoint, tensor ini dihitung ulang. Pada titik ini, lapisan di luar modul checkpointed ini telah menyelesaikan backward pass mereka, sehingga penggunaan memori puncak dengan checkpointing bisa lebih rendah.

catatan

Fitur ini tersedia untuk PyTorch di pustaka paralelisme SageMaker model v1.6.0 dan yang lebih baru.

Cara Menggunakan Checkpointing Aktivasi

Dengansmdistributed.modelparallel, Anda dapat menggunakan pos pemeriksaan aktivasi pada perincian modul. Untuk semua torch.nn modul kecualitorch.nn.Sequential, Anda hanya dapat memeriksa pohon modul jika terletak dalam satu partisi dari perspektif paralelisme pipa. Dalam kasus torch.nn.Sequential modul, setiap pohon modul di dalam modul sekuensial harus terletak sepenuhnya dalam satu partisi agar pos pemeriksaan aktivasi berfungsi. Saat Anda menggunakan partisi manual, perhatikan batasan ini.

Saat Anda menggunakan partisi model otomatis, Anda dapat menemukan log tugas partisi yang dimulai dengan Partition assignments: di log pekerjaan pelatihan. Jika modul dipartisi di beberapa peringkat (misalnya, dengan satu keturunan pada satu peringkat dan keturunan lain pada peringkat yang berbeda), perpustakaan mengabaikan upaya untuk memeriksa modul dan memunculkan pesan peringatan bahwa modul tidak akan diperiksa.

catatan

Pustaka paralelisme SageMaker model mendukung operasi yang tumpang tindih dan tidak tumpang tindih dalam kombinasi dengan pos allreduce pemeriksaan.

catatan

PyTorchAPI checkpointing asli tidak kompatibel dengan. smdistributed.modelparallel

Contoh 1: Kode contoh berikut menunjukkan cara menggunakan checkpointing aktivasi ketika Anda memiliki definisi model dalam skrip Anda.

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

Contoh 2: Kode contoh berikut menunjukkan cara menggunakan checkpointing aktivasi ketika Anda memiliki model sekuensial dalam skrip Anda.

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

Contoh 3: Contoh kode berikut menunjukkan cara menggunakan checkpointing aktivasi saat Anda mengimpor model bawaan dari pustaka, seperti dan PyTorch Hugging Face Transformers. Apakah Anda memeriksa modul sekuensial atau tidak, lakukan hal berikut:

  1. Bungkus model dengansmp.DistributedModel().

  2. Tentukan objek untuk lapisan berurutan.

  3. Bungkus objek layer sekuensial dengansmp.set_activation_checkpointig().

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')