参数初始化延迟 - Amazon SageMaker

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

参数初始化延迟

在有限的 GPU 内存下,并非总是可以初始化用于训练的大型模型。要解决 GPU 内存不足的问题,可以在 CPU 内存上初始化模型。但是,对于参数超过 200 或 400 亿的大型号,即使是 CPU 内存也可能不够。在这种情况下,我们建议您在所 PyTorch 谓的元设备上初始化模型,这样就可以在不附加任何数据的情况下创建张量。元设备上的张量只需要形状信息,这允许在元设备上创建带有其参数的大型模型。Hugging Fac e Accelerate 提供了上下文init_empty_weights管理器,可帮助在元设备上创建此类模型,同时在普通设备上初始化缓冲区。在训练开始之前, PyTorch FSDP 会初始化模型参数。SMP v2 的延迟参数初始化功能延迟了模型参数的创建,使其在 PyTorch FSDP 执行参数分片之后发生。 PyTorch FSDP 在对模块进行分片时接受参数初始化函数 (param_init_fn),它会调param_init_fn用每个模块。param_init_fnAPI 将模块作为参数并初始化其中的所有参数,不包括任何子模块的参数。请注意,此行为与原生 PyTorch v2.0.1 不同,后者存在导致参数多次初始化的错误。

SMP v2 提供了用于应用延迟参数初始化torch.sagemaker.delayed_param.DelayedParamIniter的 API。

以下代码片段展示了如何将 torch.sagemaker.delayed_param.DelayedParamIniter API 应用于您的训练脚本。

假设你有一个 PyTorch FSDP 训练脚本,如下所示。

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

请注意,延迟参数初始化方法与模型无关。要解决此问题,您需要编写一个init_weights函数,如前面的示例所示,以匹配原始模型定义中的初始化,并且该函数应涵盖模型的所有参数。为了简化准备此类init_weights函数的过程,SMP v2 为以下模型实现了此初始化函数:GPT-2、GPT-J、GPT-Neox 和 Hugging Face Transformers 中的 Llama。该 torch.sagemaker.delayed_param.DelayedParamIniter API 还可与 SMP 张量并行实现(torch.sagemaker.tensor_parallel.transformer.TransformerLMHead模型)配合使用,您可以在 torch.sagemaker.transform API 调用后调用该实现模型。

使用该 torch.sagemaker.delayed_param.DelayedParamIniter API,您可以按如下方式调整您的 PyTorch FSDP 脚本。创建权重为空的模型后,将 torch.sagemaker.delayed_param.DelayedParamIniter API 注册到该模型,然后定义其对象。将对象传递给 PyTorch FSDP 类的。param_init_fn

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

关于绑定砝码的注意事项

在训练带有绑定权重的模型时,我们需要特别注意在使用延迟参数初始化权重后绑定权重。 PyTorchFSDP 没有在使用上述方法初始化权重后绑定权重param_init_fn的机制。为了解决此类情况,我们添加了 API 以允许 apost_init_hook_fn,它可用于绑定权重。你可以在其中传递任何接受模块作为参数的函数,但我们也有一个预post_param_init_fn定义的,如果模块存在则调用DelayedParamIniter该模块tie_weights的方法。请注意,post_param_init_fn即使该模块没有tie_weights方法,也始终可以安全地传入。

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )