지연된 파라미터 초기화 - 아마존 SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

지연된 파라미터 초기화

제한된 GPU 메모리로는 학습용 대형 모델을 초기화하는 것이 항상 가능한 것은 아닙니다. 이 GPU 메모리 부족 문제를 해결하려면 CPU 메모리에서 모델을 초기화하면 됩니다. 그러나 파라미터가 200억 또는 400억 개가 넘는 대형 모델의 경우 CPU 메모리로도 충분하지 않을 수 있습니다. 이러한 경우에는 데이터를 첨부하지 않고도 텐서를 생성할 수 있도록 메타 장치를 PyTorch 호출하는 것으로 모델을 초기화하는 것이 좋습니다. 메타 기기의 텐서에는 모양 정보만 필요하며, 이렇게 하면 메타 기기에서 해당 파라미터를 사용하여 대형 모델을 만들 수 있습니다. Hugging Face Accelerate는 일반 장치에서 버퍼를 초기화하는 동시에 메타 장치에서 이러한 모델을 생성하는 데 도움이 되는 컨텍스트 init_empty_weights 관리자를 제공합니다. 훈련이 시작되기 전에 PyTorch FSDP는 모델 파라미터를 초기화합니다. SMP v2의 이러한 지연된 파라미터 초기화 기능은 FSDP가 파라미터 샤딩을 수행한 후에 이러한 모델 파라미터 생성이 지연됩니다. PyTorch PyTorch FSDP는 모듈을 샤딩할 때 매개변수 초기화 함수 (param_init_fn) 를 받아들이고 각 모듈을 호출합니다. param_init_fn param_init_fnAPI는 모듈을 인수로 사용하여 하위 모듈의 파라미터를 제외한 모듈의 모든 파라미터를 초기화합니다. 참고로 이 동작은 파라미터가 여러 번 초기화되는 버그가 있는 네이티브 PyTorch v2.0.1과는 다릅니다.

SMP v2는 지연된 파라미터 초기화를 적용하기 위한 torch.sagemaker.delayed_param.DelayedParamIniter API를 제공합니다.

다음 코드 스니펫은 torch.sagemaker.delayed_param.DelayedParamIniter API를 교육 스크립트에 적용하는 방법을 보여줍니다.

다음과 같은 PyTorch FSDP 교육 스크립트가 있다고 가정해 보겠습니다.

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

참고로 지연된 파라미터 초기화 접근 방식은 모델에 구애받지 않습니다. 이 문제를 해결하려면 이전 예제와 같이 원래 모델 정의의 초기화와 일치하도록 init_weights 함수를 작성해야 합니다. 이 함수는 모델의 모든 매개변수를 포함해야 합니다. 이러한 init_weights 기능을 준비하는 과정을 단순화하기 위해 SMP v2는 Hugging Face Transformers의 GPT-2, GPT-J, GPT-NEOx 및 Llama와 같은 모델에 이 초기화 기능을 구현합니다. torch.sagemaker.delayed_param.DelayedParamIniterAPI는 API 호출 후 호출할 수 있는 SMP 텐서 병렬 구현 torch.sagemaker.tensor_parallel.transformer.TransformerLMHead 모델과도 작동합니다. torch.sagemaker.transform

torch.sagemaker.delayed_param.DelayedParamIniterAPI를 사용하여 다음과 같이 PyTorch FSDP 스크립트를 조정할 수 있습니다. 가중치가 비어 있는 모델을 생성한 후 모델에 torch.sagemaker.delayed_param.DelayedParamIniter API를 등록하고 해당 모델의 객체를 정의합니다. 객체를 PyTorch FSDP param_init_fn 클래스의 에 전달합니다.

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

동점 가중치에 대한 참고 사항

동일한 가중치를 사용하여 모델을 훈련할 때는 매개변수 초기화가 지연된 상태에서 가중치를 초기화한 후 가중치를 연결하도록 특별히 주의해야 합니다. PyTorchFSDP에는 위와 같이 가중치를 초기화한 후 가중치를 연결하는 메커니즘이 없습니다. param_init_fn 이러한 경우를 해결하기 위해 가중치를 연결하는 데 사용할 수 있는 post_init_hook_fn a를 허용하는 API를 추가했습니다. 모듈을 인수로 받아들이는 모든 함수를 전달할 수 있지만, 모듈이 있는 경우 해당 모듈의 tie_weights 메서드를 DelayedParamIniter 호출하는 사전 post_param_init_fn 정의된 함수도 있습니다. 참고로 모듈에 사용할 tie_weights 메서드가 post_param_init_fn 없더라도 항상 전달하는 것이 안전합니다.

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )