Inisialisasi parameter tertunda - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Inisialisasi parameter tertunda

Inisialisasi model besar untuk pelatihan tidak selalu dimungkinkan dengan memori GPU yang terbatas. Untuk mengatasi masalah memori GPU yang tidak mencukupi ini, Anda dapat menginisialisasi model pada memori CPU. Namun, untuk model yang lebih besar dengan lebih dari 20 atau 40 miliar parameter, bahkan memori CPU mungkin tidak cukup. Untuk kasus seperti itu, kami menyarankan Anda menginisialisasi model pada apa yang PyTorch disebut perangkat meta, yang memungkinkan pembuatan tensor tanpa data apa pun yang melekat padanya. Tensor pada perangkat meta hanya membutuhkan informasi bentuk, dan ini memungkinkan untuk membuat model besar dengan parameternya pada perangkat meta. Hugging Face Accelerate menyediakan init_empty_weights manajer konteks untuk membantu membuat model seperti itu pada perangkat meta sambil menginisialisasi buffer pada perangkat biasa. Sebelum pelatihan dimulai, PyTorch FSDP menginisialisasi parameter model. Fitur inisialisasi parameter tertunda SMP v2 ini menunda pembuatan parameter model ini terjadi setelah PyTorch FSDP melakukan sharding parameter. PyTorch FSDP menerima fungsi inisialisasi parameter (param_init_fn) saat sharding modul, dan memanggil setiap modul. param_init_fn param_init_fnAPI mengambil modul sebagai argumen dan menginisialisasi semua parameter di dalamnya, tidak termasuk parameter modul anak mana pun. Perhatikan bahwa perilaku ini berbeda dari PyTorch v2.0.1 asli yang memiliki bug yang menyebabkan parameter diinisialisasi beberapa kali.

SMP v2 menyediakan torch.sagemaker.delayed_param.DelayedParamIniter API untuk menerapkan inisialisasi parameter tertunda.

Cuplikan kode berikut menunjukkan cara menerapkan torch.sagemaker.delayed_param.DelayedParamIniter API ke skrip pelatihan Anda.

Asumsikan bahwa Anda memiliki skrip pelatihan PyTorch FSDP sebagai berikut.

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

Perhatikan bahwa pendekatan inisialisasi parameter tertunda bukanlah model agnostik. Untuk mengatasi masalah ini, Anda perlu menulis init_weights fungsi seperti yang ditunjukkan pada contoh sebelumnya agar sesuai dengan inisialisasi dalam definisi model asli, dan itu harus mencakup semua parameter model. Untuk menyederhanakan proses persiapan init_weights fungsi tersebut, SMP v2 mengimplementasikan fungsi inisialisasi ini untuk model berikut: GPT-2, GPT-J, GPT-Neox, dan Llama dari Hugging Face Transformers. torch.sagemaker.delayed_param.DelayedParamIniterAPI juga berfungsi dengan implementasi paralel tensor SMP, torch.sagemaker.tensor_parallel.transformer.TransformerLMHead model, yang dapat Anda panggil setelah panggilan torch.sagemaker.transform API.

Menggunakan torch.sagemaker.delayed_param.DelayedParamIniter API, Anda dapat menyesuaikan skrip PyTorch FSDP Anda sebagai berikut. Setelah membuat model dengan bobot kosong, daftarkan torch.sagemaker.delayed_param.DelayedParamIniter API ke model, dan tentukan objeknya. Lewati objek ke param_init_fn kelas PyTorch FSDP.

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

Catatan tentang bobot yang diikat

Saat melatih model dengan bobot terikat, kita perlu berhati-hati untuk mengikat bobot setelah menginisialisasi bobot dengan inisialisasi parameter yang tertunda. PyTorchFSDP tidak memiliki mekanisme untuk mengikat bobot setelah menginisialisasi mereka menggunakan seperti di atas. param_init_fn Untuk mengatasi kasus seperti itu, kami menambahkan API untuk mengizinkan apost_init_hook_fn, yang dapat digunakan untuk mengikat bobot. Anda dapat meneruskan fungsi apa pun di sana yang menerima modul sebagai argumen, tetapi kami juga memiliki standar yang post_param_init_fn ditentukan di DelayedParamIniter mana tie_weights metode panggilan modul jika ada. Perhatikan bahwa aman untuk selalu masuk post_param_init_fn meskipun tidak ada tie_weights metode untuk modul.

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )