Modifier un script d'entraînement PyTorch - Amazon SageMaker

Modifier un script d'entraînement PyTorch

Les étapes suivantes vous expliquent comment convertir un script d'entraînement PyTorch pour utiliser la bibliothèque de données parallèles distribuées de SageMaker.

La conception des API de la bibliothèque est similaire à celle des API PyTorch Distributed Data Parallel (DDP). Pour plus de détails sur chaque API de données parallèles proposée pour PyTorch, consultez la documentation sur l'API PyTorch de la bibliothèque de données parallèles distribuées SageMaker.

Note

La bibliothèque d'entraînement distribué pour le parallélisme des données de SageMaker prend en charge la précision mixte automatique (AMP). Pour activer l'AMP, il vous suffit de modifier le cadre de votre script d'entraînement. Si les gradients sont en FP16, la bibliothèque de parallélisme de données SageMaker exécute son opération AllReduce en FP16. Pour plus d'informations sur la mise en œuvre des API AMP dans votre script d'entraînement, consultez les ressources suivantes :

  1. Importez le client PyTorch de la bibliothèque et initialisez-le, puis importez le module pour l'entraînement distribué.

    import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP dist.init_process_group()
  2. Après avoir analysé les arguments et défini un paramètre de taille de lot (batch_size=args.batch_size, par exemple), ajoutez une ligne de code de 2 pour redimensionner la taille du lot par employé (GPU). L'opération DataLoader de PyTorch ne gère pas automatiquement le redimensionnement du lot pour l'entraînement distribué.

    batch_size //= dist.get_world_size() batch_size = max(batch_size, 1)
  3. Épinglez chaque GPU à un seul processus de bibliothèque de données parallèles SageMaker avec local_rank : cela fait référence au rang relatif du processus au sein d'un nœud donné.

    L'API smdistributed.dataparallel.torch.get_local_rank() vous indique le rang local du périphérique. Le nœud principal est le rang 0, et les nœuds des employés sont les rangs 1, 2, 3, etc. Cela est appelé sous dist.get_local_rank() dans le bloc de code suivant.

    torch.cuda.set_device(dist.get_local_rank())
  4. Enveloppez le modèle PyTorch avec le DDP de la bibliothèque.

    model = ... # Wrap model with the library's DistributedDataParallel model = DDP(model)
  5. Modifiez le torch.utils.data.distributed.DistributedSampler pour inclure les informations du cluster. Définissez num_replicas au nombre total de GPU participant à l'entraînement sur tous les nœuds du cluster. C'est ce qu'on appelle world_size. Vous pouvez obtenir world_size avec l'API smdistributed.dataparallel.torch.get_world_size(). Cela est appelé sous dist.get_world_size() dans le code suivant. Fournissez également le rang de noeud en utilisant smdistributed.dataparallel.torch.get_rank(). Cela est appelé sous dist.get_rank().

    train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
  6. Modifiez votre script de sorte à enregistrer les points de contrôle sur le nœud principal uniquement. Le nœud principal a un modèle synchronisé. Cela évite également que les nœuds d'employés écrasent les points de contrôle et les endommagent éventuellement.

Voici un exemple de script d'entraînement PyTorch pour l'entraînement distribué avec la bibliothèque :

# SageMaker data parallel: Import the library PyTorch API import smdistributed.dataparallel.torch.distributed as dist # SageMaker data parallel: Import the library PyTorch DDP from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP # SageMaker data parallel: Initialize the library dist.init_process_group() class Net(nn.Module):     ...     # Define model def train(...):     ...     # Model training def test(...):     ...     # Model evaluation def main():          # SageMaker data parallel: Scale batch size by world size     batch_size //= dist.get_world_size()     batch_size = max(batch_size, 1)     # Prepare dataset     train_dataset = torchvision.datasets.MNIST(...)       # SageMaker data parallel: Set num_replicas and rank in DistributedSampler     train_sampler = torch.utils.data.distributed.DistributedSampler(             train_dataset,             num_replicas=dist.get_world_size(),             rank=dist.get_rank())       train_loader = torch.utils.data.DataLoader(..)       # SageMaker data parallel: Wrap the PyTorch model with the library's DDP     model = DDP(Net().to(device))          # SageMaker data parallel: Pin each GPU to a single library process.     torch.cuda.set_device(local_rank)     model.cuda(local_rank)          # Train     optimizer = optim.Adadelta(...)     scheduler = StepLR(...)     for epoch in range(1, args.epochs + 1):         train(...)         if rank == 0:             test(...)         scheduler.step()     # SageMaker data parallel: Save model on master node.     if dist.get_rank() == 0:         torch.save(...) if __name__ == '__main__':     main()

Pour une utilisation plus avancée, consultez la documentation sur l'API PyTorch de la bibliothèque de données parallèles distribuées SageMaker.

Une fois l'adaptation de votre script d'entraînement terminée, passez à la rubrique suivante : Exécuter une tâche d'entraînement de données parallèles distribuées SageMaker à l'aide du kit SDK Python SageMaker.