Points de contrôle d'activation - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Points de contrôle d'activation

Le point de contrôle d'activation est une technique qui permet de réduire l'utilisation de la mémoire en effaçant les activations de certaines couches et en les recalculant lors du retour en arrière. En fait, cela permet d'échanger du temps de calcul supplémentaire contre une réduction de l'utilisation de la mémoire. Si un module est contrôlé, à la fin d'une passe directe, seules les entrées initiales du module et les sorties finales du module restent en mémoire. PyTorch libère tous les tenseurs intermédiaires qui font partie du calcul à l'intérieur de ce module lors de la passe directe. Lors du passage en arrière des modules pointés de contrôle, PyTorch recalcule ces tenseurs. À ce stade, les couches situées au-delà de ce module de point de contrôle ont terminé leur retour en arrière, de sorte que l'utilisation maximale de la mémoire avec le point de contrôle diminue.

SMP v2 prend en charge le module de point de contrôle PyTorch d'activation,. apply_activation_checkpointing Vous trouverez ci-dessous des exemples de points de contrôle d'activation du modèle Hugging Face GPT-Neox.

Couches de transformation Checkpointing du modèle Hugging Face GPT-Neox

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Vérifiez toutes les autres couches de transformation du modèle Hugging Face GPT-Neox

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

Il possède PyTorch également le torch.utils.checkpoint module de point de contrôle, qui est utilisé par un sous-ensemble de modèles Hugging Face Transformers. Ce module fonctionne également avec SMP v2. Cependant, vous devez avoir accès à la définition du modèle pour ajouter le wrapper de point de contrôle. Nous vous recommandons donc d'utiliser apply_activation_checkpointing cette méthode.