Ponto de verificação de ativação - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Ponto de verificação de ativação

O ponto de verificação de ativação é uma técnica para reduzir o uso de memória limpando as ativações de determinadas camadas e recomputando-as durante a passagem para trás. Efetivamente, isso troca tempo extra de computação pela redução do uso de memória. Se um módulo for verificado, no final de uma passagem direta, somente as entradas iniciais do módulo e as saídas finais do módulo permanecerão na memória. PyTorch libera quaisquer tensores intermediários que façam parte da computação dentro desse módulo durante a passagem para frente. Durante a passagem para trás dos módulos de ponto de verificação, PyTorch recalcula esses tensores. Nesse ponto, as camadas além desse módulo de ponto de verificação terminaram sua passagem para trás, então o pico de uso da memória com o ponto de verificação se torna menor.

O SMP v2 suporta o módulo de ponto de verificação de PyTorch ativação,. apply_activation_checkpointing A seguir estão exemplos de pontos de verificação de ativação do modelo Hugging Face GPT-Neox.

Camadas de transformação Checkpointing do modelo Hugging Face GPT-Neox

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Verificando todas as outras camadas de Transformer do modelo Hugging Face GPT-Neox

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

Como alternativa, PyTorch também tem o torch.utils.checkpoint módulo para checkpoint, que é usado por um subconjunto dos modelos Hugging Face Transformers. Este módulo também funciona com o SMP v2. No entanto, isso requer que você tenha acesso à definição do modelo para adicionar o invólucro do ponto de verificação. Portanto, recomendamos que você use o apply_activation_checkpointing método.