텐서 병렬화 - 아마존 SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

텐서 병렬화

텐서 병렬 처리는 특정 모델 가중치, 그라디언트, 옵티마이저 상태가 디바이스 간에 분할되는 일종의 모델 병렬 처리입니다. 개별 가중치는 그대로 유지하면서 가중치, 그래디언트 또는 옵티마이저 세트를 기기 간에 분할하는 파이프라인 병렬 처리와 달리 텐서 병렬 처리는 개별 가중치를 분할합니다. 여기에는 일반적으로 해당 모델의 특정 연산, 모듈 또는 계층에 대한 분산 계산이 포함됩니다.

단일 파라미터가 대부분의 GPU 메모리를 소비하는 경우(예: 어휘 크기가 큰 대형 임베딩 테이블 또는 클래스 개수가 많은 대규모 소프트맥스 레이어) 텐서 병렬 처리가 필요합니다. 이 경우 이렇게 큰 텐서 또는 연산을 원자 단위로 처리하는 것은 비효율적이며 메모리 부하의 균형을 맞추는 데 방해가 됩니다.

SMP v2는 텐서 병렬 처리를 구현하기 위해 트랜스포머 엔진과 통합되며 FSDP API 위에서 실행됩니다. PyTorch PyTorch FSDP와 SMP 텐서 병렬 처리를 동시에 활성화하고 최상의 성능을 위한 최상의 모델 병렬 처리를 결정할 수 있습니다.

실제로 텐서 병렬화는 다음 시나리오에서 특히 유용합니다.

  • 컨텍스트 길이가 길면 FSDP만으로도 활성화 메모리가 많아지므로 학습할 때.

  • 전체 배치 크기가 원하는 한도를 초과하는 매우 큰 클러스터를 사용하여 훈련하는 경우.

SMP 텐서 병렬 처리와 호환되는 Hugging Face Transformer 모델

SMP v2는 현재 다음과 같은 Hugging Face 트랜스포머 모델에 텐서 병렬 처리를 지원합니다.

  • GPT-NeoX

  • 라마 2

이러한 모델에 텐서 병렬화를 적용하기 위한 참조 구성은 을 참조하십시오. 구성 팁

텐서 병렬화 설정하기

tensor_parallel_degree 경우 텐서 병렬도 값을 선택합니다. 값은 클러스터의 GPU 수를 균등하게 나누어야 합니다. 예를 들어, GPU가 8개인 인스턴스를 사용하면서 모델을 샤딩하려면 2, 4 또는 8을 선택합니다. 처음에는 적은 수로 시작해서 모델이 GPU 메모리에 들어갈 때까지 점진적으로 늘리는 것이 좋습니다.

다음 코드 스니펫은 에서 소개한 2단계 프로세스를 따르면서 교육 스크립트에 SMP 초기화 모듈을 torch.sagemaker.init() 추가하고 교육 작업 실행기에 사용할 SMP 구성 사전을 JSON 형식으로 설정하는 방법을 보여줍니다. SageMaker 모델 병렬화 라이브러리 v2로 시작하기 모델이나 FSDP 구성을 변경할 필요가 없습니다. PyTorch PyTorch tensor_parallel_degreerandom_seed 파라미터에 대한 자세한 내용은 SMP v2 핵심 기능 구성 매개변수 단원을 참조하세요.

SMP 컨피그레이션

{ "tensor_parallel_degree": 8, "random_seed": 0 }

교육 스크립트에서

torch.sagemaker.init()로 초기화하여 SMP v2를 활성화하고 API로 모델을 래핑합니다torch.sagemaker.transform.

import torch.sagemaker as tsm tsm.init() from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(..) model = tsm.transform(model)

Hugging Face Transformer 체크포인트 저장 및 불러오기

SMP 라이브러리는 모델을 변환한 후 모델의 상태 사전 (state_dict) 을 변경합니다. 즉, 이 모델은 오리지널 Hugging Face Transformer 체크포인트 기능과 호환되지 않게 됩니다. 이를 처리하기 위해 SMP 라이브러리는 변환된 모델의 체크포인트를 Hugging Face Transformer 표현으로 저장하는 API와 torch.sagemaker.transform 미세 조정을 위해 Hugging Face Transformer 모델 체크포인트를 로드하는 API를 제공합니다.

SMP v2의 텐서 병렬 처리 기능을 사용하는 동안 체크포인트를 저장하는 방법에 대한 자세한 내용은 을 참조하십시오. SMP를 사용하는 동안 체크포인트를 저장하고 로드하세요.

SMP v2의 텐서 병렬 처리 기능을 적용하는 모델 미세 조정에 대한 자세한 내용은 을 참조하십시오. 미세 조정