Using PyTorch Models with Elastic Inference
This release of Elastic Inference-enabled PyTorch has been tested to perform well and provide cost-saving benefits with the following deep learning use cases and network architectures (and similar variants).
Note
Elastic Inference-enabled PyTorch is only available with Amazon Deep Learning Containers v27 and later.
Use Case | Example Network Topology |
---|---|
Image Recognition | Inception, ResNet, VGG |
Semantic Segmentation | UNet |
Text Embeddings | BERT |
Transformers | GPT |
Compile Elastic Inference-enabled PyTorch models
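Elastic Inference-enabled PyTorch only supports TorchScript models. You can compile a PyTorch model into TorchScript using either scripting or tracing. Both produce a computation graph, but they build it differently.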
Scripting a model is the preferred way of compiling to TorchScript because it preserves all model logic. However, the set of models that can be scripted is smaller than the set of traceable models. Your model might be traceable, but not scriptable, or not traceable at all. You may need to modify your model code to make it TorchScript compatible.
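One way to act on this advice is a small helper that tries scripting first and falls back to tracing only when scripting fails. This is a sketch with basic PyTorch, not part of the Elastic Inference API; the helper name to_torchscript and the two example modules are hypothetical. The NotScriptable module uses a try/except block only because TorchScript cannot compile try statements, while tracing simply runs it.

```python
import torch
import torch.nn as nn


def to_torchscript(model: nn.Module, example_input: torch.Tensor) -> torch.jit.ScriptModule:
    """Hypothetical helper: prefer scripting, fall back to tracing."""
    model = model.eval()
    try:
        # Scripting compiles the source and keeps all control flow.
        return torch.jit.script(model)
    except Exception as err:  # compilation errors come in several exception types
        print(f"Scripting failed ({type(err).__name__}); falling back to tracing.")
        # Tracing records only the code path taken for example_input.
        with torch.no_grad():
            return torch.jit.trace(model, example_input)


class Scriptable(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


class NotScriptable(nn.Module):
    def forward(self, x):
        # TorchScript cannot compile try/except, but tracing just executes it.
        try:
            return x * 2
        except RuntimeError:
            return x


x = torch.rand(1, 4)
a = to_torchscript(Scriptable(), x)     # compiled by scripting
b = to_torchscript(NotScriptable(), x)  # compiled by tracing after the fallback
```

If the fallback is taken, check the traced model against inputs that exercise every branch you care about, since tracing may have dropped some of them.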
Because of the way that Elastic Inference handles control-flow operations in PyTorch, scripted models that contain many conditional branches might show noticeably higher inference latency. Try both tracing and scripting to see how your model performs with Elastic Inference. A traced model is likely to perform better than its scripted version.
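A rough way to compare the two compiled versions is to time each on the same input. This sketch runs on the local CPU with basic PyTorch, so its numbers say nothing about latency on an Elastic Inference accelerator; SmallNet and mean_latency_ms are hypothetical names, and with Elastic Inference you would repeat the measurement against the accelerator.

```python
import time

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical stand-in model; replace it with your own network."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def mean_latency_ms(module, x, warmup=5, iters=50):
    """Average wall-clock time per forward pass, in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):  # give the JIT a chance to profile and optimize
            module(x)
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
    return (time.perf_counter() - start) * 1000.0 / iters


model = SmallNet().eval()
x = torch.rand(8, 64)
scripted = torch.jit.script(model)
with torch.no_grad():
    traced = torch.jit.trace(model, x)

results = {
    "scripted": mean_latency_ms(scripted, x),
    "traced": mean_latency_ms(traced, x),
}
for name, ms in results.items():
    print(f"{name}: {ms:.3f} ms per batch")
```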
Scripting
Scripting performs direct analysis of the source code to construct a computation graph and preserve control flow.
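As a minimal illustration with basic PyTorch, the hypothetical module below picks a branch based on its input. Because scripting compiles the source rather than recording one execution, both branches survive in the scripted model, and scripted.code shows the compiled if/else.

```python
import torch
import torch.nn as nn


class SignDependent(nn.Module):
    """Hypothetical module whose result depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return x - 1


scripted = torch.jit.script(SignDependent())
print(scripted(torch.ones(3)))   # tensor([2., 2., 2.])
print(scripted(-torch.ones(3)))  # tensor([-2., -2., -2.])
print(scripted.code)             # compiled source, including the if/else
```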
The following example code shows how to compile a model using scripting. It uses the TorchVision pretrained weights for ResNet18. The resulting scripted model can still be saved to a file, then loaded with torch.jit.load using Elastic Inference-enabled PyTorch.
import torchvision, torch

# ImageNet pretrained models take inputs of this size.
x = torch.rand(1, 3, 224, 224)
# Call eval() to set model to inference mode
model = torchvision.models.resnet18(pretrained=True).eval()
scripted_model = torch.jit.script(model)
Tracing
Tracing takes a sample input and records the operations performed when executing the model on that particular input. This means that control flow may be erased because the graph is compiled by tracing the code with just one input. For example, a model definition might have code to pad images of a particular size x. If the model is traced with an image of a different size y, then future inputs of size x fed to the traced model will not be padded. This happens because the code path was never run while tracing with the sample input.
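The padding scenario can be reproduced in a few lines. This is a sketch with basic PyTorch and a hypothetical PadIfSmall module: tracing with an input that is already large enough records only the path without padding, while the scripted version keeps the size check. PyTorch may emit a TracerWarning while tracing the size comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PadIfSmall(nn.Module):
    """Hypothetical module: pads the last dimension up to `target` elements."""

    def __init__(self, target: int = 4):
        super().__init__()
        self.target = target

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] < self.target:
            # Pad on the right side of the last dimension only.
            x = F.pad(x, [0, self.target - x.shape[-1]])
        return torch.relu(x)


model = PadIfSmall().eval()
# Trace with an input that is already large enough: the padding branch never runs.
traced = torch.jit.trace(model, torch.rand(1, 4))
scripted = torch.jit.script(model)

small = torch.rand(1, 2)
print(traced(small).shape)    # torch.Size([1, 2]): padding was erased by tracing
print(scripted(small).shape)  # torch.Size([1, 4]): scripting kept the size check
```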
The following example shows how to compile a model using tracing. It uses the TorchVision pretrained weights for ResNet18. The torch.jit.optimized_execution context block is required to use traced models with Elastic Inference. The two-argument form of this function, which takes the target_device option, is only available through the Elastic Inference-enabled PyTorch framework.

If you are tracing your model with the basic PyTorch framework, don't include the torch.jit.optimized_execution context. The resulting traced model can still be saved to a file, then loaded with torch.jit.load using Elastic Inference-enabled PyTorch.
import torchvision, torch

# ImageNet pretrained models take inputs of this size.
x = torch.rand(1, 3, 224, 224)
# Call eval() to set model to inference mode
model = torchvision.models.resnet18(pretrained=True).eval()

# Required when using Elastic Inference
with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
    traced_model = torch.jit.trace(model, x)
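Because the two-argument optimized_execution call only works in Elastic Inference-enabled PyTorch, code that must also run under basic PyTorch can enter that context conditionally. The helper trace_for_target and its use_eia flag are hypothetical; the context arguments mirror the example above, and contextlib.nullcontext stands in when Elastic Inference is not used.

```python
import contextlib

import torch


def trace_for_target(model, example_input, use_eia=False, device="eia:0"):
    """Hypothetical helper: trace inside the Elastic Inference context only when asked.

    The two-argument torch.jit.optimized_execution form exists only in the
    Elastic Inference-enabled PyTorch build, so basic PyTorch traces without it.
    """
    if use_eia:
        ctx = torch.jit.optimized_execution(True, {"target_device": device})
    else:
        ctx = contextlib.nullcontext()  # no-op context for basic PyTorch
    with torch.no_grad(), ctx:
        return torch.jit.trace(model.eval(), example_input)
```

On an Elastic Inference client you would call trace_for_target(model, x, use_eia=True); elsewhere, the default traces with basic PyTorch.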
Saving and loading a compiled model
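The output of tracing and scripting is a ScriptModule, the TorchScript analog of a basic PyTorch nn.Module. Save and load a TorchScript module with torch.jit.save() and torch.jit.load(), the TorchScript counterparts of torch.save() and torch.load().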
torch.jit.save(traced_model, 'resnet18_traced.pt')
torch.jit.save(scripted_model, 'resnet18_scripted.pt')

traced_model = torch.jit.load('resnet18_traced.pt')
scripted_model = torch.jit.load('resnet18_scripted.pt')
Saved TorchScript models are not bound to specific classes and code directories, unlike basic PyTorch models. You can directly load saved TorchScript models without instantiating the model class first.
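The following sketch, using basic PyTorch and a hypothetical TinyNet module, saves a scripted model to an in-memory buffer and loads it back after the Python class has been deleted. This shows that the saved archive carries the compiled model code along with the weights; a file path works the same way as the buffer.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model used only to demonstrate serialization."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


x = torch.rand(1, 4)
scripted = torch.jit.script(TinyNet().eval())

# Serialize to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Loading does not need the TinyNet class at all.
del TinyNet
restored = torch.jit.load(buffer)

with torch.no_grad():
    print(torch.allclose(restored(x), scripted(x)))  # True
```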
CPU training requirement
PyTorch does not save models in a device-agnostic way. Model training frequently happens in a CUDA context on a GPU. However, the Elastic Inference-enabled PyTorch framework is CPU-only on the client side, even though your model runs in a CUDA context on the server.
Tracing models may lead to tensor creation on a specific device. When this happens, you may get errors when loading the model onto a different device. To avoid device-related errors, load your model by explicitly specifying the CPU device using torch.jit.load(model_path, map_location=torch.device('cpu')), where model_path is the path to the saved model file. This forces all model tensors to CPU. If you still get an error, cast your model to CPU before saving it. This can be done on any instance type, including GPU instances. For more information, see TorchScript’s Frequently Asked Questions.
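A minimal sketch of both remedies with basic PyTorch, assuming a hypothetical two-layer model and a temporary file path: the model is moved to the CPU before tracing and saving, and map_location forces every tensor onto the CPU when loading. The same code runs on GPU and CPU-only machines.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model; in practice this is your trained network, possibly on a GPU.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Cast the model to CPU before tracing and saving.
cpu_model = model.cpu()
example = torch.rand(1, 8)
with torch.no_grad():
    traced = torch.jit.trace(cpu_model, example)

path = os.path.join(tempfile.mkdtemp(), "model_cpu.pt")
torch.jit.save(traced, path)

# Force every tensor onto the CPU at load time as well.
loaded = torch.jit.load(path, map_location=torch.device("cpu"))
print({p.device.type for p in loaded.parameters()})  # {'cpu'}
```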
Additional Requirements and Considerations
Framework Paradigms: Dynamic versus Static Computational Graphs
All deep learning frameworks view models as directed acyclic graphs. However, the frameworks differ in how they allow you to specify models. TensorFlow and MXNet use static computation graphs, meaning that the computation graph must be defined and built before it's run. In contrast, PyTorch uses dynamic computational graphs. This means that models are imperatively specified by using idiomatic Python code, and then the computation graph is built at execution time. Rather than being predetermined, the graph’s structure can change during execution.
Productionizing PyTorch with TorchScript
TorchScript is a way to create serializable and optimizable models from PyTorch code.
To use Elastic Inference enabled PyTorch, you must convert your models to the TorchScript format.
Model Format
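Basic PyTorch uses dynamic computational graphs. This means that models are specified with idiomatic Python code and the computation graph is built at execution time. Elastic Inference supports only TorchScript models, so a model must be compiled by scripting or tracing before you can use it with Elastic Inference.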
Additional Resources
For more information about using TorchScript, see the TorchScript tutorial.
The following pretrained PyTorch models can be used with Elastic Inference: