Using PyTorch Models with Elastic Inference - Amazon Elastic Inference

This release of Elastic Inference-enabled PyTorch has been tested to perform well and to provide cost savings with the following deep learning use cases and network architectures (and similar variants).

Note

Elastic Inference-enabled PyTorch is only available with Amazon Deep Learning Containers v27 and later.

Use Case | Example Network Topology
Image Recognition | Inception, ResNet, VGG
Semantic Segmentation | UNet
Text Embeddings | BERT
Transformers | GPT

Compile Elastic Inference-enabled PyTorch models

Elastic Inference-enabled PyTorch only supports TorchScript compiled models. You can compile a PyTorch model into TorchScript using either tracing or scripting. Both produce a computation graph, but differ in how they do so.

Scripting a model is the preferred way to compile to TorchScript because it preserves all model logic. However, fewer models can be scripted than traced: your model might be traceable but not scriptable, or not traceable at all. You may need to modify your model code to make it TorchScript-compatible.
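Because some models can be traced but not scripted, a common pattern is to try scripting first and fall back to tracing. The sketch below shows that pattern on stock PyTorch; compile_to_torchscript, TinyNet, and TryExceptNet are hypothetical names for illustration. TryExceptNet uses a try/except block, which the TorchScript compiler does not accept, so it takes the fallback path.

```python
import torch
import torch.nn as nn


def compile_to_torchscript(model: nn.Module, example_input: torch.Tensor) -> torch.jit.ScriptModule:
    # Hypothetical helper (not part of PyTorch): prefer scripting, fall back to tracing.
    model = model.eval()
    try:
        # Scripting compiles the Python source and keeps control flow.
        return torch.jit.script(model)
    except Exception as err:  # unsupported Python constructs raise at compile time
        print(f"Scripting failed ({type(err).__name__}); falling back to tracing.")
        # Tracing records only the operations executed for this one example input.
        with torch.no_grad():
            return torch.jit.trace(model, example_input)


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


class TryExceptNet(nn.Module):
    # TorchScript cannot compile try/except, so scripting this module fails.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        try:
            return x * 2
        except RuntimeError:
            return x


x = torch.rand(1, 4)
scripted = compile_to_torchscript(TinyNet(), x)     # scripting succeeds
traced = compile_to_torchscript(TryExceptNet(), x)  # falls back to tracing
```

If the fallback path is taken, check the traced model against the concerns about erased control flow described under Tracing.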

Because of the way Elastic Inference handles control-flow operations in PyTorch 1.3.1, scripted models that contain many conditional branches might see noticeably higher inference latency. Try both tracing and scripting to see how your model performs with Elastic Inference. With PyTorch 1.3.1, a traced model is likely to perform better than its scripted version.
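One way to compare the two approaches is to time repeated forward calls on each compiled version. The following is a minimal sketch that runs on stock CPU PyTorch, not on an Elastic Inference accelerator, so its numbers only show the method. mean_latency_ms and BranchyNet are hypothetical names, and the warm-up and iteration counts are arbitrary assumptions.

```python
import time
import warnings

import torch
import torch.nn as nn


def mean_latency_ms(module, example_input, warmup=5, iters=50):
    # Average wall-clock time per forward call, in milliseconds.
    with torch.no_grad():
        for _ in range(warmup):  # early calls include one-time JIT optimization work
            module(example_input)
        start = time.perf_counter()
        for _ in range(iters):
            module(example_input)
        elapsed = time.perf_counter() - start
    return elapsed / iters * 1000.0


class BranchyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branches: kept by scripting, fixed to one path by tracing.
        for _ in range(4):
            if bool(x.mean() > 0):
                x = torch.relu(self.fc(x))
            else:
                x = torch.tanh(self.fc(x))
        return x


model = BranchyNet().eval()
x = torch.rand(8, 16)
scripted = torch.jit.script(model)
with warnings.catch_warnings(), torch.no_grad():
    warnings.simplefilter("ignore")  # the tracer warns about tensor-to-bool conversion
    traced = torch.jit.trace(model, x)

results = {"scripted": mean_latency_ms(scripted, x), "traced": mean_latency_ms(traced, x)}
print(results)
```

To measure Elastic Inference itself, run the same timing loop with models compiled and loaded through the Elastic Inference-enabled framework.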

Scripting

Scripting performs direct analysis of the source code to construct a computation graph and preserve control flow.
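To illustrate that scripting preserves control flow, the sketch below scripts a small module with a data-dependent branch. Both branches remain reachable after compilation, and the generated TorchScript source still contains the if statement. SignGate is a hypothetical module used only for illustration.

```python
import torch
import torch.nn as nn


class SignGate(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which path runs depends on the input values, not just their shape.
        if bool(x.sum() > 0):
            return x + 1
        return x - 1


scripted = torch.jit.script(SignGate())
pos = torch.ones(3)
neg = -torch.ones(3)
print(scripted(pos))  # takes the first branch (x + 1)
print(scripted(neg))  # takes the second branch (x - 1)
print(scripted.code)  # compiled TorchScript source, still containing the if statement
```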

The following example shows how to compile a model using scripting. It uses the TorchVision pretrained weights for ResNet18. The resulting scripted model can be saved to a file and then loaded with torch.jit.load using Elastic Inference-enabled PyTorch.

import torch
import torchvision

# ImageNet pretrained models take inputs of this size.
x = torch.rand(1, 3, 224, 224)

# Call eval() to set the model to inference mode.
model = torchvision.models.resnet18(pretrained=True).eval()
scripted_model = torch.jit.script(model)

Tracing

Tracing takes a sample input and records the operations performed when the model executes on that particular input. Because the graph is compiled from a single input, control flow may be erased. For example, a model definition might contain code that pads images of a particular size x. If the model is traced with an image of a different size y, future inputs of size x fed to the traced model will not be padded, because that code path was never executed while tracing with the sample input.
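The padding case above follows the same pattern as the data-dependent branch in the sketch below. That sketch is a swapped-in illustration, not the padding code itself. The module is traced with a positive sample, so only the first branch is recorded. A negative input then gives different results in eager PyTorch and in the traced model. SignGate is a hypothetical module.

```python
import warnings

import torch
import torch.nn as nn


class SignGate(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            return x + 1
        return x - 1


model = SignGate()
with warnings.catch_warnings():
    # The tracer warns that converting a tensor to a Python bool can make the trace incorrect.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, torch.ones(3))  # positive sample: only "x + 1" is recorded

neg = -torch.ones(3)
print(model(neg))   # eager PyTorch evaluates the condition and takes the second branch
print(traced(neg))  # the trace replays the recorded first branch regardless of the input
```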

The following example shows how to compile a model using tracing. It uses the TorchVision pretrained weights for ResNet18. The torch.jit.optimized_execution context block is required to use traced models with Elastic Inference. This function is only available in the Elastic Inference-enabled PyTorch framework.

If you are tracing your model with the basic PyTorch framework, don't include the torch.jit.optimized_execution context. The resulting traced model can still be saved to a file, then loaded with torch.jit.load using Elastic Inference-enabled PyTorch.

import torch
import torchvision

# ImageNet pretrained models take inputs of this size.
x = torch.rand(1, 3, 224, 224)

# Call eval() to set the model to inference mode.
model = torchvision.models.resnet18(pretrained=True).eval()

# Required when using Elastic Inference
with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
    traced_model = torch.jit.trace(model, x)
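If the same tracing script must run under both basic PyTorch and Elastic Inference-enabled PyTorch, you can add the Elastic Inference context only when requested. The sketch below makes that assumption. trace_model and use_elastic_inference are hypothetical names. The two-argument optimized_execution call is taken from the example above and, as this page notes, is available only in the Elastic Inference-enabled framework. The basic path uses a no-op context instead.

```python
import contextlib

import torch
import torch.nn as nn


def trace_model(model: nn.Module, example_input: torch.Tensor, use_elastic_inference: bool = False):
    # Hypothetical helper: choose the tracing context based on the target framework.
    if use_elastic_inference:
        # Only valid in the Elastic Inference-enabled PyTorch framework.
        ctx = torch.jit.optimized_execution(True, {"target_device": "eia:0"})
    else:
        # Basic PyTorch: trace without the Elastic Inference context.
        ctx = contextlib.nullcontext()
    with ctx, torch.no_grad():
        return torch.jit.trace(model.eval(), example_input)


# Basic PyTorch path; pass use_elastic_inference=True on an Elastic Inference host.
traced = trace_model(nn.Linear(4, 2), torch.rand(1, 4))
```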

Saving and loading a compiled model

The output of both tracing and scripting is a ScriptModule, the TorchScript counterpart of the basic PyTorch nn.Module. To serialize and deserialize a TorchScript module, call torch.jit.save() and torch.jit.load(), respectively. These are the JIT equivalents of saving and loading a basic PyTorch model with torch.save() and torch.load().

# Save the compiled models from the previous examples.
torch.jit.save(traced_model, 'resnet18_traced.pt')
torch.jit.save(scripted_model, 'resnet18_scripted.pt')

# Load them back.
traced_model = torch.jit.load('resnet18_traced.pt')
scripted_model = torch.jit.load('resnet18_scripted.pt')
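To check that a save and load round trip preserves behavior, compare outputs before and after serialization. The minimal sketch below writes to an in-memory buffer instead of a file. torch.jit.save and torch.jit.load accept file-like objects as well as paths. The small Sequential model stands in for ResNet18 so that the sketch needs no downloaded weights.

```python
import io

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
scripted = torch.jit.script(model)

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # serialize to an in-memory file
buffer.seek(0)
restored = torch.jit.load(buffer)  # deserialize without the original Python class

x = torch.rand(3, 4)
with torch.no_grad():
    same = torch.allclose(scripted(x), restored(x))
print(same)
```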

Unlike basic PyTorch models, saved TorchScript models are not bound to specific classes and code directories. You can load a saved TorchScript model directly without first instantiating the model class.

CPU training requirement

PyTorch does not save models in a device-agnostic way. Models are frequently trained in a CUDA context on a GPU. However, the Elastic Inference-enabled PyTorch framework is CPU-only on the client side, even though your model runs in a CUDA context on the server.

Tracing a model may create tensors on a specific device. When this happens, you may get errors when loading the model onto a different device. To avoid device-related errors, load your model with the CPU device specified explicitly: torch.jit.load(path, map_location=torch.device('cpu')), where path is the saved model file. This forces all model tensors onto the CPU. If you still get an error, cast your model to the CPU before saving it. You can do this on any instance type, including GPU instances. For more information, see TorchScript's Frequently Asked Questions.
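The two safeguards above can be combined. Cast the model to the CPU before tracing and saving, then pass map_location when loading. The minimal sketch below uses a small Linear layer as a stand-in for a GPU-trained model and a temporary file path chosen for illustration. It runs on both CPU-only and GPU machines.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
if torch.cuda.is_available():
    model = model.cuda()  # stand-in for a model that was trained on a GPU

# Cast to the CPU before compiling and saving so no tensors are tied to a CUDA device.
cpu_model = model.cpu()
traced = torch.jit.trace(cpu_model, torch.rand(1, 4))

path = os.path.join(tempfile.mkdtemp(), "model_cpu.pt")
torch.jit.save(traced, path)

# map_location forces every tensor in the saved archive onto the CPU at load time.
loaded = torch.jit.load(path, map_location=torch.device("cpu"))
print({p.device.type for p in loaded.parameters()})
```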

Additional Requirements and Considerations

Framework Paradigms: Dynamic versus Static Computational Graphs

All deep learning frameworks view models as directed acyclic graphs, but they differ in how they let you specify those models. TensorFlow and MXNet use static computation graphs: the graph must be defined and built before it runs. In contrast, PyTorch uses dynamic computation graphs: models are specified imperatively with idiomatic Python code, and the graph is built at execution time. Rather than being predetermined, the graph's structure can change during execution.

Productionizing PyTorch with TorchScript

TorchScript addresses the limitations of building the computation graph at execution time. It uses a JIT (just-in-time) compiler that compiles PyTorch models and exports them to a Python-free representation, so you can run TorchScript models in production environments that do not have Python. The JIT also performs graph-level optimizations that can improve performance over basic PyTorch.
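To see the Python-free representation, inspect a compiled module. The .code attribute shows the TorchScript source, a statically typed subset of Python. The .graph attribute shows the intermediate representation that the JIT operates on. AddRelu is a hypothetical module for illustration.

```python
import torch
import torch.nn as nn


class AddRelu(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.relu(a + b)


scripted = torch.jit.script(AddRelu())
print(scripted.code)   # TorchScript source generated from the Python definition
print(scripted.graph)  # graph IR made of aten:: operators
```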

To use Elastic Inference enabled PyTorch, you must convert your models to the TorchScript format.
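Before handing a model to Elastic Inference-enabled PyTorch, you can check that it has been converted. Both scripting and tracing a module produce a torch.jit.ScriptModule subclass, while a plain nn.Module does not. is_torchscript is a hypothetical helper for this sketch.

```python
import torch
import torch.nn as nn


def is_torchscript(model) -> bool:
    # Hypothetical helper: True for scripted, traced, or torch.jit.load-ed modules.
    return isinstance(model, torch.jit.ScriptModule)


eager = nn.Linear(4, 2)
print(is_torchscript(eager))                                 # False: plain nn.Module
print(is_torchscript(torch.jit.script(eager)))               # True: scripted
print(is_torchscript(torch.jit.trace(eager, torch.rand(1, 4))))  # True: traced
```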

Model Format

Basic PyTorch uses dynamic computation graphs: models are specified with idiomatic Python code, and the computation graph is built at execution time. Elastic Inference supports TorchScript saved models. TorchScript uses torch.jit, a just-in-time compiler, to produce models from PyTorch code that can be serialized and optimized. These models can run anywhere, including environments without Python. The torch.jit module offers two ways to compile a PyTorch model: tracing and scripting. Both produce a computation graph but differ in how they do so. For more information on compiling with torch.jit, see Compile Elastic Inference-enabled PyTorch models. For more information about running inference with TorchScript, see Use Elastic Inference with PyTorch for inference.

Additional Resources

For more information about using TorchScript, see the TorchScript tutorial.

The following pretrained PyTorch models can be used with Elastic Inference: