Enregistrez et chargez des points de contrôle lors de l'utilisation du SMP - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Enregistrez et chargez des points de contrôle lors de l'utilisation du SMP

La bibliothèque SMP prend en charge les PyTorch API pour les points de contrôle et fournit des API qui aident à contrôler correctement les points de contrôle lors de l'utilisation de la bibliothèque SMP.

PyTorch Le FSDP prend en charge trois types de points de contrôle : complets, partitionnés et locaux. Ils répondent à des objectifs différents. Idéalement, le point de contrôle complet ne doit être utilisé que lors de l'exportation du modèle une fois l'entraînement terminé, car il est coûteux de générer un point de contrôle complet. Le point de contrôle fragmenté est l'approche recommandée pour enregistrer et charger les points de contrôle pendant l'entraînement. À l'aide de points de contrôle fragmentés, vous pouvez également modifier la taille du cluster lors de la reprise de l'entraînement. Les points de contrôle locaux sont plus restrictifs. Avec les points de contrôle locaux, vous devez reprendre l'entraînement avec le même nombre de GPU. Actuellement, cela n'est pas pris en charge lors de l'utilisation du parallélisme des tenseurs avec le SMP. Notez que les points de contrôle FSDP nécessitent d'écrire dans un système de fichiers réseau partagé, tel que FSx.

Points de contrôle fragmentés

La procédure suivante décrit ce que vous devez faire pour adapter votre script d'entraînement afin d'enregistrer et de charger des points de contrôle fragmentés avec ou sans la fonction de parallélisme des tenseurs SMP.

  1. Importez le torch.sagemaker package SMP.

    import torch.sagemaker as tsm
  2. Configurez des variables auxiliaires pour enregistrer et charger les points de contrôle.

    1. Définissez un grade de coordinateur pour effectuer des opérations collectives de communication telles queAllReduce.

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. À l'aide des torch.sagemaker.state énumérations, configurez le rang d'action pour déterminer s'il convient de laisser les grades participer au point de contrôle. Et ajoutez une instruction if pour enregistrer les points de contrôle en fonction de l'utilisation du parallélisme des tenseurs SMP v2.

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = ""
  3. Continuez à utiliser les API de point de contrôle PyTorch FSDP telles quelles.

L'exemple de code suivant montre un script d'entraînement PyTorch FSDP complet avec les API de point de contrôle FSDP.

import torch.distributed as dist from torch.distributed.checkpoint.optimizer import ( load_sharded_optimizer_state_dict ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, StateDictType ) import torch.sagemaker as tsm sharding_strategy, state_dict_type = ..., ... global_rank = dist.get_rank() # 0. Auxiliary variables to save and load checkpoints. # Used when performing comm collectives such as allreduce. coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group)) # To determine whether to take part in checkpointing. action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = "" # 1. Save checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # Save from one single replication group. if action_rank: dist.checkpoint.save_state_dict( state_dict=state_dict, storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) # 2. Load checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 2.1 Load model and everything else except the optimizer. state_dict = { # All states except optimizer state can be passed here. "model": model.state_dict() } dist.checkpoint.load_state_dict( state_dict=state_dict, storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # Potentially process more customized and non-optimizer dict states. # 2.2 Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group, ) optimizer.load_state_dict(flattened_optimizer_state)

Modèles complets de points de contrôle

À la fin de la formation, vous pouvez enregistrer un point de contrôle complet qui combine tous les fragments d'un modèle dans un seul fichier de point de contrôle du modèle. La bibliothèque SMP prend entièrement en charge l'API des points de contrôle du modèle PyTorch complet, vous n'avez donc pas besoin d'apporter de modifications.

Notez que si vous utilisez le SMPParallélisme de tenseur, la bibliothèque SMP transforme le modèle. Dans ce cas, lorsque vous vérifiez le modèle complet, la bibliothèque SMP retraduit le modèle au format de point de contrôle Hugging Face Transformers par défaut.

Dans les cas où vous vous entraînez avec le parallélisme des tenseurs SMP et que vous désactivez le processus de traduction SMP, vous pouvez utiliser l'translate_on_saveargument de l' PyTorch FullStateDictConfigAPI pour activer ou désactiver la traduction automatique SMP selon vos besoins. Par exemple, si vous vous concentrez sur la formation d'un modèle, vous n'avez pas besoin d'ajouter le processus de traduction, ce qui entraîne des frais supplémentaires. Dans ce cas, nous vous recommandons de définirtranslate_on_save=False. De plus, si vous prévoyez de continuer à utiliser la traduction SMP du modèle pour une formation continue à l'avenir, vous pouvez la désactiver pour enregistrer la traduction SMP du modèle pour une utilisation ultérieure. Il est nécessaire de retraduire le modèle au format de point de contrôle du modèle Hugging Face Transformers lorsque vous terminez l'entraînement de votre modèle et que vous l'utilisez à des fins d'inférence.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Notez que l'option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) consiste à rassembler le modèle sur le processeur du périphérique de 0e rang pour économiser de la mémoire lors de l'entraînement de grands modèles.

Pour recharger le modèle à des fins d'inférence, procédez comme indiqué dans l'exemple de code suivant. Notez que la classe AutoModelForCausalLM peut être remplacée par d'autres classes de création de facteurs dans Hugging Face Transformers, par exemple AutoModelForSeq2SeqLM en fonction de votre modèle. Pour plus d'informations, consultez la documentation de Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)