Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.
Allenamento misto di precisione
La libreria di parallelismo dei SageMaker modelli (SMP) v2 supporta l'addestramento di precisione misto pronto all'uso grazie all'integrazione con framework open source come FSDP e Transformer Engine. PyTorch Per ulteriori informazioni, consulta i seguenti argomenti.
Argomenti
Addestramento di precisione misto con nessuna istanza P5 utilizzando Transformer Engine FP8
A partire dalla libreria SageMaker Model Parallelism (SMP) v2.2.0, la libreria SMP si integra con Transformer EngineMixedPrecision
Nota
SMP v2 offre FP8 supporto per i seguenti modelli Hugging Face Transformer:
-
GPT-Neox (disponibile in SMP v2.2.0 e versioni successive)
-
Llama 2 (disponibile in SMP v2.2.0 e versioni successive)
-
Mixtral 8x7b e Mixtral 8x22b (disponibili in SMP v2.5.0 e versioni successive)
Nota
Questo FP8 corso di formazione sulla funzionalità P5 è disponibile nella seguente combinazione di librerie di e libreria: SageMaker PyTorch
-
SageMaker Python SDK v2.212.0 e versioni successive
-
PyTorch v2.2.0 e versioni successive
FP8(precisione in virgola mobile a 8 bit) è un tipo di dati che è emerso come un altro paradigma per accelerare l'addestramento del deep learning dei modelli LLM. Con il rilascio di NVIDIA H100 GPUs che supporta FP8 i tipi di dati, puoi sfruttare i vantaggi derivanti dai miglioramenti delle prestazioni sulle istanze P5 dotate di H100 GPUs, accelerando al contempo l'addestramento distribuito con un addestramento di precisione misto. FP8
Il tipo di FP8 dati si estende ulteriormente ai formati E4M3 ed E5M2. L'E4M3 offre una maggiore precisione, ha una gamma dinamica limitata ed è ideale per l'avanzamento nell'addestramento dei modelli. E5M2 ha una gamma dinamica più ampia, ma una precisione ridotta, ed è più adatto per il passaggio all'indietro, dove la precisione è meno critica e una gamma dinamica più ampia diventa vantaggiosa. Pertanto, ti consigliamo di utilizzare la ricetta della FP8 strategia ibrida
Per i tipi di dati a mezza precisione (FP16 e BF16), le tecniche globali di scalabilità delle perdite come la scalabilità statica delle perdite o la scalabilità dinamica delle perdite gestiscono i problemi di convergenza derivanti dalla perdita di informazioni dovuta all'arrotondamento dei gradienti a semiprecisione. Tuttavia, l'intervallo dinamico di è ancora più ristretto e le tecniche di scala globale delle perdite non sono sufficienti. FP8 A questo punto, abbiamo bisogno di una tecnica di ridimensionamento per tensore a grana più fine. Il ridimensionamento ritardato è una strategia che seleziona un fattore di scala basato sui valori massimi assoluti osservati in una serie di tensori nelle iterazioni precedenti. Questa strategia presenta un compromesso: sfrutta tutti i vantaggi prestazionali del FP8 calcolo, ma richiede memoria per conservare la cronologia dei valori massimi dei tensori. Per saperne di più sulla strategia di scalabilità ritardata in generale, consulta il paper FP8 Formats for Deep
In pratica, l'utilizzo FP8 è utile in tutti gli scenari di addestramento sulle istanze P5. Ti consigliamo vivamente di abilitarlo FP8 quando possibile per migliorare le prestazioni dell'allenamento.
SMP v2 supporta Transformer Engine fin dall'inizio. Pertanto, quando si esegue l' FP8allenamento con SMP v2 su istanze P5 di SageMaker AI (ml.p5.48xlarge
), l'unica cosa che devi fare è importare torch.sagemaker
lo script di addestramento e continuare a utilizzare il pacchetto Python Transformer Engine nativo. Per ulteriori informazioni sull'uso di Transformer Engine per la FP8 formazione in generale, consulta Using FP8 with Transformer Engine nella documentazione di NVIDIA Transformer
import torch.sagemaker as tsm import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format # Initialize the SMP torch.sagemaker API. tsm.init() # Define a transformer model and wrap it with the torch.sagemaker.transform API. from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(
ModelConfig
) model = tsm.transform(model) # Enable E4M3 during forward pass, E5M2 during backward pass. fp8_format = Format.HYBRID # Create an FP8 recipe. fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Enable FP8 autocasting. with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group): out = model(inp) loss = out.sum() loss.backward()
Per trovare un esempio pratico di FP8 formazione con SMP v2 su istanze P5, consultate il taccuino di esempio su Accelerate SageMaker PyTorch FSDP Training of
Addestramento di precisione misto con PyTorch tipi di dati a semiprecisione utilizzando FSDP
SMP v2 supporta PyTorch FSDP MixedPrecision
Nota
Questo addestramento di precisione misto con la funzionalità PyTorch FSDP è disponibile nella seguente combinazione di librerie di e libreria. SageMaker PyTorch
-
SMP v2.0.0 e versioni successive
-
SageMaker Python SDK v2.200.0 e versioni successive
-
PyTorch v2.0.1 e versioni successive
Il modo standard per configurare un modello a precisione mista consiste nel float32
creare il modello e quindi consentire a FSDP di trasmettere i parametri bfloat16
su float16
o al volo passando una MixedPrecision
policy, come mostrato nel seguente frammento di codice. Per ulteriori informazioni sulle opzioni per modificare i parametri dtype
for, la riduzione o i buffer per la precisione mista PyTorch, consultate l'API PyTorch MixedPrecision
FSDP
# Native PyTorch API from torch.distributed.fsdp import MixedPrecision dtype = torch.bfloat16 mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype ) model = FSDP( model, ..., mixed_precision=mixed_precision_policy )
Nota che alcuni modelli (come il modello Hugging Face Transformers Llama) prevedono buffer come. float32
Per utilizzarlofloat32
, sostituitelo torch.bfloat16
con torch.float32
nella riga che definisce l'oggetto. dtype