Points de contrôle d'activation - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Points de contrôle d'activation

Les points de contrôle d'activation (ou points de contrôle de gradient) sont une technique permettant de réduire l'utilisation de la mémoire en effaçant les activations de certaines couches et en les recalculant lors d'une transmission vers l'arrière. Concrètement, cela augmente le temps de calcul pour réduire l'utilisation de la mémoire. Si un module dispose de points de contrôle, à la fin d'une transmission vers l'avant, les entrées et les sorties du module restent en mémoire. Tous les tenseurs intermédiaires qui auraient fait partie du calcul à l'intérieur de ce module sont libérés pendant la transmission vers l'avant. Au cours de la transmission vers l'arrière des modules avec points de contrôle, ces tenseurs sont recalculés. À ce stade, les couches situées au-delà de ce module avec points de contrôle ont terminé leur transmission vers l'arrière. Ainsi, grâce aux points de contrôle, l'utilisation maximale de la mémoire peut être plus faible.

Note

Cette fonctionnalité est disponible PyTorch dans la bibliothèque de parallélisme des SageMaker modèles v1.6.0 et versions ultérieures.

Utilisation des points de contrôle d'activation

Avec smdistributed.modelparallel, vous pouvez utiliser les points de contrôle d'activation au niveau de détails d'un module. Pour tous les modules torch.nn à l'exception de torch.nn.Sequential, vous ne pouvez créer des points de contrôle pour une arborescence de modules que si celle-ci se trouve dans une seule partition du point de vue du parallélisme de pipeline. Dans le cas du module torch.nn.Sequential, chaque arborescence de modules à l'intérieur du module séquentiel doit se trouver complètement dans une partition pour que les points de contrôle d'activation fonctionnent. Lorsque vous utilisez le partitionnement manuel, soyez conscient de ces restrictions.

Lorsque vous utilisez le partitionnement automatisé des modèles, vous pouvez trouver les journaux d'affectation de partitionnement commençant parPartition assignments: dans les journaux de tâches d'entraînement. Si un module est partitionné sur plusieurs rangs (par exemple, avec un descendant sur un rang et un autre descendant sur un autre rang), la bibliothèque ignore la tentative de création de points de contrôles pour le module et génère un message d'avertissement indiquant qu'aucun point de contrôle ne sera créé pour le module.

Note

La bibliothèque de parallélisme du SageMaker modèle prend en charge les allreduce opérations avec ou sans chevauchement en combinaison avec le point de contrôle.

Note

PyTorchl'API de point de contrôle native n'est pas compatible avecsmdistributed.modelparallel.

Exemple 1 : l'exemple de code suivant montre comment utiliser les points de contrôle d'activation lorsque le script contient une définition de modèle.

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

Exemple 2 : l'exemple de code suivant montre comment utiliser les points de contrôle d'activation lorsque le script contient un modèle séquentiel.

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

Exemple 3 : L'exemple de code suivant montre comment utiliser le point de contrôle d'activation lorsque vous importez un modèle prédéfini à partir d'une bibliothèque, telle que Hugging Face PyTorch Transformers. Que vous créiez ou non des points de contrôle pour des modules séquentiels, procédez comme suit :

  1. Enveloppez le modèle par smp.DistributedModel().

  2. Définissez un objet pour les couches séquentielles.

  3. Encapsulez l'objet de couche séquentielle par smp.set_activation_checkpointig().

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')