Inicialização atrasada de parâmetros - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Inicialização atrasada de parâmetros

A inicialização de um modelo grande para treinamento nem sempre é possível com a GPU memória limitada. Para resolver esse problema de GPU memória insuficiente, você pode inicializar o modelo na CPU memória. No entanto, para modelos maiores com mais de 20 ou 40 bilhões de parâmetros, até mesmo a CPU memória pode não ser suficiente. Nesse caso, recomendamos que você inicialize o modelo no que PyTorch chama um meta-dispositivo, o que permite a criação de tensores sem nenhum dado anexado a eles. Um tensor em um meta-dispositivo precisa apenas das informações de forma, e isso permite criar um modelo grande com seus parâmetros em meta-dispositivos. O Hugging Face Accelerate fornece o gerenciador de contexto init_empty_weights para ajudar a criar esse modelo em meta-dispositivos enquanto inicializa os buffers em um dispositivo comum. Antes do início do treinamento, PyTorch FSDP inicializa os parâmetros do modelo. Esse recurso de inicialização retardada de parâmetros da SMP v2 atrasa a criação dos parâmetros do modelo após a execução da fragmentação de parâmetros PyTorch FSDP. PyTorch FSDPaceita uma função de inicialização de parâmetros (param_init_fn) ao fragmentar os módulos e chama param_init_fn cada módulo. O param_init_fn API usa um módulo como argumento e inicializa todos os parâmetros nele, sem incluir os parâmetros de nenhum módulo filho. Observe que esse comportamento difere da PyTorch versão 2.0.1 nativa, que tem um bug que faz com que os parâmetros sejam inicializados várias vezes.

SMPA v2 fornece a torch.sagemaker.delayed_param.DelayedParamIniter API aplicação da inicialização retardada de parâmetros.

Os trechos de código a seguir mostram como aplicar o torch.sagemaker.delayed_param.DelayedParamIniter API ao seu script de treinamento.

Suponha que você tenha um script PyTorch FSDP de treinamento da seguinte forma.

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

Observe que a abordagem de inicialização retardada de parâmetros não é independente do modelo. Para resolver esse problema, você precisa escrever uma init_weights função conforme mostrado no exemplo anterior para corresponder à inicialização na definição do modelo original e ela deve abranger todos os parâmetros do modelo. Para simplificar esse processo de preparação dessa init_weights função, a SMP v2 implementa essa função de inicialização para os seguintes modelos: GPT -2, GPT -J, GPT -NeoX e Llama da Hugging Face Transformers. torch.sagemaker.delayed_param.DelayedParamIniterAPITambém funciona com a implementação paralela do SMP tensor, torch.sagemaker.tensor_parallel.transformer.TransformerLMHead modelo, que você pode chamar após a torch.sagemaker.transform API chamada.

Usando o torch.sagemaker.delayed_param.DelayedParamIniterAPI, você pode adaptar seu PyTorch FSDP script da seguinte forma. Depois de criar um modelo com pesos vazios, torch.sagemaker.delayed_param.DelayedParamIniter API registre-os no modelo e defina um objeto dele. Passe o objeto para o param_init_fn da PyTorch FSDP classe.

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

Notas sobre pesos empatados

Ao treinar modelos com pesos empatados, precisamos tomar cuidado especial ao amarrar os pesos após inicializar os pesos com a inicialização atrasada dos parâmetros. PyTorchFSDPnão tem um mecanismo para amarrar os pesos após inicializá-los usando param_init_fn as instruções acima. Para resolver esses casos, adicionamos API a permissão apost_init_hook_fn, que pode ser usada para amarrar os pesos. Você pode passar qualquer função que aceite o módulo como argumento, mas também temos um método post_param_init_fn predefinido no DelayedParamIniter qual chama o tie_weights método do módulo, se ele existir. Observe que é seguro sempre passar, post_param_init_fn mesmo que não haja um tie_weights método para o módulo.

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )