Puntos de control de activación - Amazon SageMaker

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

Puntos de control de activación

Los puntos de control de activación son una técnica para reducir el uso de memoria al borrar las activaciones de determinadas capas y volver a calcularlas durante la pasada hacia atrás. De hecho, esto cambia el tiempo de cómputo adicional por reducir el uso de memoria. Si se comprueba un módulo, al final de una transferencia hacia adelante, solo permanecen en la memoria las entradas iniciales del módulo y las salidas finales del módulo. PyTorch libera todos los tensores intermedios que formen parte del cálculo dentro de ese módulo durante la pasada hacia adelante. Al pasar hacia atrás los módulos con puntos de control, PyTorch vuelve a calcular estos tensores. En este punto, las capas situadas más allá de este módulo de puntos de control han terminado su recorrido hacia atrás, por lo que el uso máximo de memoria con los puntos de control disminuye.

SMP v2 es compatible con el módulo de puntos de control PyTorch de activación,. apply_activation_checkpointing Los siguientes son ejemplos de puntos de control de activación del modelo Hugging Face GPT-Neox.

Capas Checkpointing Transformer del modelo Hugging Face GPT-Neox

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Controle todas las demás capas de Transformer del modelo Hugging Face GPT-Neox

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

Como alternativa, PyTorch también cuenta con el torch.utils.checkpoint módulo de puntos de control, que es utilizado por un subconjunto de modelos de Hugging Face Transformers. Este módulo también funciona con SMP v2. Sin embargo, requiere que tenga acceso a la definición del modelo para añadir el contenedor de puntos de control. Por lo tanto, le recomendamos que utilice este método. apply_activation_checkpointing