Checkpoint e ottimizzazione di un modello con il parallelismo dei modelli - Amazon SageMaker

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

Checkpoint e ottimizzazione di un modello con il parallelismo dei modelli

La libreria di parallelismo dei SageMaker modelli fornisce API di checkpoint per salvare lo stato del modello e lo stato dell'ottimizzatore suddivisi in base alle varie strategie di parallelismo del modello e per caricare i checkpoint per l'addestramento continuo da cui si desidera riavviare l'allenamento e perfezionarlo. Le API supportano anche opzioni per salvare parzialmente o completamente gli stati del modello e dell'ottimizzatore.

Checkpoint di un modello distribuito

Scegliete uno dei seguenti argomenti in base al framework tra PyTorch e TensorFlow alla versione della libreria di parallelismo dei modelli che utilizzate. SageMaker

Verifica di un PyTorch modello distribuito (per la libreria di parallelismo dei SageMaker modelli v1.10.0 e successive)

La libreria di parallelismo dei SageMaker modelli fornisce API di checkpoint per salvare e caricare checkpoint completi o parziali dello stato del modello distribuito e del relativo stato di ottimizzazione.

Nota

Questo metodo di checkpoint è consigliato se si utilizza la libreria di parallelismo del modello v1.10.0 o PyTorch successiva. SageMaker

Checkpoint parziale

Per salvare i checkpoint di un modello addestrato con il parallelismo del modello, utilizza l'API smdistributed.modelparallel.torch.save_checkpoint con l'opzione checkpointing parziale impostata su true (partial=True). Ciò salva ogni partizione del modello singolarmente. Oltre al modello e allo stato dell'ottimizzatore, puoi anche salvare eventuali dati personalizzati aggiuntivi tramite l'argomento user_content. Il modello Checkpoint, l'ottimizzatore e il contenuto dell'utente vengono salvati come file separati. La chiamata API save_checkpoint crea cartelle di checkpoint nella seguente struttura.

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

Per riprendere l'addestramento da checkpoint parziali, utilizza l'API smdistributed.modelparallel.torch.resume_from_checkpointcon partial=True e specifica la directory dei checkpoint e il tag utilizzati durante il salvataggio dei checkpoint parziali. Notate che il caricamento effettivo dei pesi del modello avviene dopo il partizionamento del modello, durante la prima esecuzione della step function di addestramento smdistributed.modelparallel.torch.step decorata.

Quando si salva un checkpoint parziale, la libreria salva anche la decisione sulla partizione del modello come file con estensione di file .pt. Al contrario, quando si riprende dal checkpoint parziale, la libreria carica insieme i file di decisione sulla partizione. Una volta caricata la decisione sulla partizione, non è possibile modificarla.

Il seguente frammento di codice mostra come impostare le API di checkpoint in uno script di formazione. PyTorch

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

Checkpoint completo

Per salvare l'artefatto finale del modello a scopo di inferenza, utilizza l'API smdistributed.modelparallel.torch.save_checkpoint con partial=False, che combina le partizioni del modello per creare un singolo artefatto del modello. Nota che questo non combina gli stati dell'ottimizzatore.

Per inizializzare l'addestramento con pesi particolari, con un checkpoint completo del modello, puoi utilizzare l'API smdistributed.modelparallel.torch.resume_from_checkpoint con partial=False. Nota che questo non carica gli stati dell'ottimizzatore.

Nota

Con il parallelismo tensoriale, in generale, state_dict deve essere tradotto tra l'implementazione del modello originale e l'implementazione DistributedModel. Facoltativamente, puoi fornire la funzione di traduzione state_dict come argomento per smdistributed.modelparallel.torch.resume_from_checkpoint. Tuttavia, per Modelli supportati pronti all'uso, la libreria si occupa di questa traduzione automaticamente.

Il codice seguente mostra un esempio di come utilizzare le API checkpoint per il checkpoint completo di un modello addestrato con il parallelismo dei modelli. PyTorch

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

Controllo di un PyTorch modello distribuito (per la libreria di parallelismo dei modelli tra v1.6.0 e v1.9.0 SageMaker )

La libreria di parallelismo dei SageMaker modelli fornisce funzioni Python per il salvataggio di checkpoint parziali o completi per i lavori di formazione con parallelismo tensoriale. La procedura seguente mostra come utilizzare smp.save() e smp.load() per salvare e caricare un checkpoint quando si utilizza il parallelismo tensoriale.

Nota

Questo metodo di checkpoint è consigliato se si utilizza la libreria di parallelismo dei modelli tra la v1.6.0 e la PyTorch Parallelismo tensoriale v1.9.0. SageMaker

  1. Prepara un oggetto modello ed esegui il wrapping con la funzione wrapper della libreria smp.DistributedModel().

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Prepara un ottimizzatore per il modello. Un set di parametri del modello è un argomento iterabile richiesto dalle funzioni di ottimizzazione. Per preparare un set di parametri del modello, è necessario elaborare model.parameters() per assegnare ID univoci ai singoli parametri del modello.

    Se sono presenti parametri con ID duplicati nell'iterabile dei parametri del modello, il caricamento dello stato dell'ottimizzatore di checkpoint non riesce. Per creare un iterabile di parametri del modello con ID univoci per l'ottimizzatore, consulta quanto segue:

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Esegui il wrapping dell'ottimizzatore usando la funzione wrapper della libreria smp.DistributedOptimizer().

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Salva il modello e lo stato dell'ottimizzatore utilizzando smp.save(). A seconda di come desideri salvare i checkpoint, scegli una delle seguenti due opzioni:

    • Opzione 1: salva un modello parziale su ciascuno mp_rank per uno singolo MP_GROUP.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Con il parallelismo tensoriale, la libreria salva i file con checkpoint denominati nel seguente formato: checkpoint.pt_{pp_rank}_{tp_rank}.

      Nota

      Con il parallelismo tensoriale, assicurati di impostare l'istruzione if come if smp.rdp_rank() == 0 anziché if smp.dp_rank() == 0. Quando lo stato dell'ottimizzatore è suddiviso con il parallelismo tensoriale, tutte le classificazioni parallele con dati ridotti devono salvare la propria partizione dello stato dell'ottimizzatore. L'utilizzo di un'istruzione if errata per il checkpoint potrebbe comportare l'interruzione del processo di addestramento. Per ulteriori informazioni sull'utilizzo if smp.dp_rank() == 0 senza parallelismo tensoriale, vedere Istruzioni generali per il salvataggio e il caricamento nella documentazione di Python SageMaker SDK.

    • Opzione 2: salva il modello completo.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      Nota

      Per un checkpoint completo:

      • Se imposti gather_to_rank0=True, tutte le classificazioni diverse da 0 restituiscono dizionari vuoti.

      • Per il checkpoint completo, puoi solo controllare il modello. Il checkpoint completo degli stati dell'ottimizzatore non è attualmente supportato.

      • Il modello completo deve solo essere salvato in smp.rank() == 0.

  5. Carica i checkpoint usando smp.load(). A seconda di come hai fatto il checkpoint del passaggio precedente, scegli una delle seguenti due opzioni:

    • Opzione 1: carica i checkpoint parziali.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Puoi impostare same_partition_load=True in model.load_state_dict() per un caricamento più veloce, se sai che la partizione non cambierà.

    • Opzione 2: carica i checkpoint completi.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      La condizione if smp.rdp_rank() == 0 non è richiesta, ma può aiutare a evitare caricamenti ridondanti tra diversi MP_GROUP. Il dizionario completo degli stati dell'ottimizzatore del checkpoint non è attualmente supportato con il parallelismo del tensore.

TensorFlow Checkpointing di un modello distribuito

Per salvare un TensorFlow modello durante l'addestramento con il parallelismo dei modelli, utilizzate le seguenti funzioni fornite dalla libreria di parallelismo dei SageMaker modelli.

Ottimizzazione di un modello distribuito

L’ottimizzazione deve essere configurata nello script di addestramento. Il seguente frammento di codice mostra una struttura di esempio di uno script di formazione che utilizza la classe AutoModelForCausalLM di Hugging Face Transformers con modifiche per la registrazione dei moduli e delle impostazioni per la regolazione fine. smdistributed.model.parallel.torch

Nota

L’ottimizzazione di un trasformatore distribuito (un modello di trasformatore incluso in smp.DistributedModel()) con la funzione smp.delayed_param_initialization attivata richiede la configurazione del processo di ottimizzazione con un file system FSx for Lustre. Nei casi in cui si desidera ottimizzare un modello su larga scala con l'opzione di inizializzazione ritardata dei parametri, è necessario configurare un file system FSx for Lustre.

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

Per un esempio completo di script di addestramento e notebook Jupyter, consultate gli esempi GPT-2 presenti nell'archivio Examples. PyTorch SageMaker GitHub