기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.
PyTorch ONNX에서 MXNet으로 전환하는 튜토리얼
ONNX 개요
Open Neural Network Exchange
이 자습서는 ONNX 지원 Conda를 사용하는 DLAMI 사용법을 보여줍니다. 이 단계를 따라함으로써 하나의 프레임워크에서 모델을 교육하거나 사전 교육된 모델을 로드하고, 이 모델을 ONNX로 내보내고, 다른 프레임워크로 이 모델을 가져올 수 있습니다.
ONNX 사전 조건
ONNX 자습서를 사용하려면 Conda를 사용하는 DLAMI 버전 12 이상에 액세스해야 합니다. Conda를 사용하는 DLAMI 시작에 대한 자세한 내용은 Conda를 사용하는 Deep Learning AMI를 참조하세요.
중요
이 예제는 최대 8GB의 메모리(또는 그 이상)가 필요할 수 있는 기능을 사용합니다. 메모리가 충분한 인스턴스 유형을 선택해야 합니다.
Conda를 사용하는 DLAMI로 터미널 세션을 시작하고 다음 자습서를 시작합니다.
PyTorch 모델을 ONNX로 변환한 다음 모델을 MXNet에 로드합니다.
먼저 환경을 활성화하십시오. PyTorch
$
source activate pytorch_p36
텍스트 편집기로 새 파일을 만들고 스크립트의 다음 프로그램을 사용하여 모의 모델을 학습시킨 다음 ONNX 형식으로 내보냅니다. PyTorch
# Build a Mock Model in PyTorch with a convolution and a reduceMean layer import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable import torch.onnx as torch_onnx class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False) def forward(self, inputs): x = self.conv(inputs) #x = x.view(x.size()[0], x.size()[1], -1) return torch.mean(x, dim=2) # Use this an input trace to serialize the model input_shape = (3, 100, 100) model_onnx_path = "torch_model.onnx" model = Model() model.train(False) # Export the model to an ONNX file dummy_input = Variable(torch.randn(1, *input_shape)) output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=False) print("Export of torch_model.onnx complete!")
이 스크립트를 실행한 이후 동일한 디렉터리에 새로 생성된 .onnx 파일이 표시됩니다. 이제 MXNet Conda 환경으로 전환하여 MXNet으로 모델을 로드합니다.
다음으로 MXNet 환경을 활성화합니다.
$
source deactivate$
source activate mxnet_p36
텍스트 편집기로 새 파일을 생성하고, 스크립트에서 다음 프로그램을 사용하여 MXNet에서 ONNX 형식 파일을 엽니다.
import mxnet as mx from mxnet.contrib import onnx as onnx_mxnet import numpy as np # Import the ONNX model into MXNet's symbolic interface sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx") print("Loaded torch_model.onnx!") print(sym.get_internals())
이 스크립트를 실행한 이후 MXNet에 모델이 로드되고, 일부 기본 모델 정보를 출력합니다.