활성화 체크포인트 - 아마존 SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

활성화 체크포인트

활성화 체크포인트(또는 그라디언트 체크포인트)는 특정 레이어의 활성화를 지우고 역방향 패스 중에 이를 다시 계산하여 메모리 사용량을 줄이는 기법입니다. 이렇게 하면 추가 계산 시간이 줄어들어 메모리 사용량이 줄어듭니다. 모듈이 체크포인트로 지정된 경우 순방향 패스가 끝날 때 모듈의 입력과 출력은 메모리에 남습니다. 해당 모듈 내 계산의 일부를 구성한 모든 중간 텐서는 순방향 패스 중에 비워집니다. 체크포인트 모듈을 역방향으로 패스하는 동안 이러한 텐서는 다시 계산됩니다. 이 시점에서 이 체크포인트 모듈 뒤의 레이어는 역방향 패스를 완료했으므로 체크포인트의 최대 메모리 사용량을 줄일 수 있습니다.

참고

이 기능은 SageMaker 모델 병렬 처리 라이브러리 PyTorch v1.6.0 이상에서 사용할 수 있습니다.

활성화 체크포인트 사용 방법

smdistributed.modelparallel을 사용하면 모듈의 세부 수준에서 활성화 체크포인트를 사용할 수 있습니다. torch.nn.Sequential을 제외한 모든 torch.nn 모듈의 경우 파이프라인 병렬 처리 관점에서 볼 때 모듈 트리가 한 파티션 내에 있는 경우에만 모듈 트리를 체크포인트할 수 있습니다. torch.nn.Sequential 모듈의 경우 활성화 체크포인트가 작동하려면 순차 모듈 내의 각 모듈 트리가 완전히 한 파티션 내에 있어야 합니다. 수동 분할을 사용할 때는 이러한 제한 사항에 유의하세요.

자동 모델 분할을 사용하는 경우 훈련 작업 로그에서 Partition assignments:로 시작하는 분할 할당 로그를 확인할 수 있습니다. 모듈이 여러 랭크로 분할된 경우(예: 한 랭크에 하위 항목 하나가 있고 다른 랭크에 또 다른 하위 항목 하나) 라이브러리는 모듈을 체크포인트하려는 시도를 무시하고 모듈을 체크포인트하지 않을 것이라는 경고 메시지를 표시합니다.

참고

SageMaker 모델 병렬화 라이브러리는 체크포인트와 함께 중복 및 비중복 작업을 모두 지원합니다. allreduce

참고

PyTorch의 네이티브 체크포인트 API는 호환되지 않습니다. smdistributed.modelparallel

예제 1: 다음 샘플 코드는 스크립트에 모델 정의가 있을 때 활성화 체크포인트를 사용하는 방법을 보여줍니다.

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

예제 2: 다음 샘플 코드는 스크립트에 순차적 모델이 있을 때 활성화 체크포인트를 사용하는 방법을 보여줍니다.

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

예제 3: 다음 샘플 코드는 PyTorch Hugging Face Transformers와 같은 라이브러리에서 사전 빌드된 모델을 가져올 때 활성화 체크포인트를 사용하는 방법을 보여줍니다. 순차 모듈 체크포인트 여부에 관계없이 다음을 수행합니다.

  1. smp.DistributedModel()로 모델을 래핑합니다.

  2. 순차 계층용 객체를 정의합니다.

  3. smp.set_activation_checkpointig()로 순차 계층 객체를 래핑합니다.

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')