MXNet Elastic Inference with Java - Amazon Elastic Inference

MXNet Elastic Inference with Java

Starting with Apache MXNet version 1.4, the Java API can integrate with Amazon Elastic Inference. You can use Elastic Inference with the following MXNet Java API:

  • MXNet Java Infer API

Install Amazon EI Enabled Apache MXNet

Amazon Elastic Inference enabled Apache MXNet is available in the AWS Deep Learning AMI. A Maven repository is also available on Amazon S3. You can use this repository to build Elastic Inference enabled MXNet into your own Amazon Linux or Ubuntu AMIs or Docker containers.

For Maven projects, you can include Elastic Inference Java by adding the following repository to your project's pom.xml:

```xml
<repositories>
    <repository>
        <id>Amazon Elastic Inference</id>
        <url></url>
    </repository>
</repositories>
```

In addition, add the Elastic Inference flavor of MXNet as a dependency using:

```xml
<dependency>
    <groupId></groupId>
    <artifactId>mxnet-full_2.11-linux-x86_64-eia</artifactId>
    <version>[1.4.0,)</version>
</dependency>
```

Check MXNet for Java Version

You can use the commit hash to determine which release of the Java-specific version of MXNet is installed. Run the following code:

```java
// Imports
import org.apache.mxnet.javaapi.*;

// Lines to run
Version$ version$ = Version$.MODULE$;
System.out.println(version$.getCommitHash());
```

You can then compare the commit hash with the Release Notes to find specific information about the version you have.
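Release notes sometimes show an abbreviated commit hash rather than the full hash that getCommitHash() prints. The following is a minimal sketch, not part of MXNet, of a hypothetical helper, matchesRelease, that treats a hash copied from the release notes as a match when the installed hash starts with it. The 7-character minimum and the placeholder hash in main are assumptions for illustration.

```java
import java.util.Locale;

public class CommitHashCheck {
    /**
     * Hypothetical helper: returns true if the installed commit hash matches a hash
     * copied from the release notes. An abbreviated release-notes hash matches when
     * the installed hash starts with it (case-insensitive).
     */
    static boolean matchesRelease(String installedHash, String releaseHash) {
        if (installedHash == null || releaseHash == null) {
            return false;
        }
        String installed = installedHash.trim().toLowerCase(Locale.ROOT);
        String release = releaseHash.trim().toLowerCase(Locale.ROOT);
        // Assumed minimum of 7 hex characters to avoid accidental prefix matches.
        if (release.length() < 7 || !release.matches("[0-9a-f]+")) {
            return false;
        }
        return installed.startsWith(release);
    }

    public static void main(String[] args) {
        // Placeholder value; in practice, pass the string printed by getCommitHash().
        String installed = "1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b";
        System.out.println(matchesRelease(installed, "1A2B3C4"));
    }
}
```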

Use Amazon Elastic Inference with the MXNet Java Infer API

To use Amazon Elastic Inference with the MXNet Java Infer API, pass Context.eia() as the context when you create the Infer Predictor object. For more information, see the MXNet Infer Reference. The following example runs the pre-trained ResNet-152 image classification model:

```java
package mxnet;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.IntStream;

import org.apache.commons.io.FileUtils;
import org.apache.mxnet.infer.javaapi.ObjectDetector;
import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;

public class Example {
    public static void main(String[] args) throws IOException {
        String urlPath = "";
        String filePath = System.getProperty("");

        // Download the model (parameters, symbol, class labels) and a test image
        FileUtils.copyURLToFile(new URL(urlPath + "/resnet/152-layers/resnet-152-0000.params"),
                new File(filePath + "resnet-152/resnet-152-0000.params"));
        FileUtils.copyURLToFile(new URL(urlPath + "/resnet/152-layers/resnet-152-symbol.json"),
                new File(filePath + "resnet-152/resnet-152-symbol.json"));
        FileUtils.copyURLToFile(new URL(urlPath + "/synset.txt"),
                new File(filePath + "resnet-152/synset.txt"));
        FileUtils.copyURLToFile(new URL(""),
                new File(filePath + "cat.jpg"));

        // Run inference on the Elastic Inference accelerator
        List<Context> contexts = Arrays.asList(Context.eia());

        // One 3-channel 224x224 image in NCHW layout
        Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
        List<DataDesc> inputDesc = Arrays.asList(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
        Predictor predictor = new Predictor(filePath + "resnet-152/resnet-152", inputDesc, contexts, 0);

        // Load the image, resize it to the model's input size, and convert it to an NDArray
        BufferedImage originalImg = ObjectDetector.loadImageFromFile(filePath + "cat.jpg");
        BufferedImage resizedImg = ObjectDetector.reshapeImage(originalImg, 224, 224);
        NDArray img = ObjectDetector.bufferedImageToPixels(resizedImg, new Shape(new int[]{1, 3, 224, 224}));

        List<NDArray> predictResults = predictor.predictWithNDArray(Arrays.asList(img));
        float[] results = predictResults.get(0).toArray();

        // Sort class indices by descending score and print the top 5 labels with their scores
        List<String> synsetLines = FileUtils.readLines(new File(filePath + "resnet-152/synset.txt"));
        int[] best = IntStream.range(0, results.length)
                .boxed().sorted(Comparator.comparing(i -> -results[i]))
                .mapToInt(ele -> ele).toArray();
        for (int i = 0; i < 5; i++) {
            int ind = best[i];
            System.out.println(i + ": " + synsetLines.get(ind) + " - " + results[ind]);
        }
    }
}
```
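The ranking step at the end of the example does not depend on the accelerator, so it can be checked on its own. The following sketch uses only the Java standard library: it pulls that step into a hypothetical topK helper and applies it to made-up scores and labels that stand in for the predictor output and synset.txt. It prints each label together with its score.

```java
import java.util.Comparator;
import java.util.stream.IntStream;

public class TopK {
    /** Hypothetical helper: returns the indices of the k largest scores, highest first. */
    static int[] topK(float[] scores, int k) {
        int n = Math.min(k, scores.length);
        return IntStream.range(0, scores.length)
                .boxed()
                // Descending order by score; ties keep their original index order.
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .limit(n)
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        // Made-up scores and labels standing in for predictor output and synset.txt.
        float[] scores = {0.05f, 0.60f, 0.10f, 0.25f};
        String[] labels = {"dog", "tabby cat", "fox", "tiger cat"};
        int[] best = topK(scores, 3);
        for (int rank = 0; rank < best.length; rank++) {
            int ind = best[rank];
            // Look up both the label and the score by class index.
            System.out.println(rank + ": " + labels[ind] + " - " + scores[ind]);
        }
    }
}
```

Sorting with comparingDouble(...).reversed() yields the same descending order as the example's comparing(i -> -results[i]). The limit(n) call keeps only the top indices rather than the whole ranking.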

More Models and Resources

For more tutorials and examples, see:


Troubleshooting

  • MXNet EI is built with MKL-DNN. All operations using Context.cpu() are supported and run with the same performance as the standard release. MXNet EI does not support Context.gpu(); any operation that uses that context throws an error.

  • You cannot allocate memory for an NDArray on the remote accelerator by writing something like this:

    x = NDArray.array(Array(1,2,3), ctx=Context.eia())

    This throws an error. Instead, use Context.cpu(). MXNet automatically transfers your data to the accelerator as necessary.

  • Elastic Inference is only for production inference use cases and does not support model training. When you use either the Symbol API or the Module API, do not call the backward() method or call bind() with forTraining=true; either one throws an error. Because forTraining defaults to true, make sure you set forTraining=false explicitly in cases such as the example in Use Elastic Inference with the MXNet Module API.

  • Because training is not allowed, there is no point in initializing an optimizer for inference.

  • A model trained on an earlier version of MXNet works on a later version of MXNet EI because MXNet EI is backward compatible. For example, you can train a model on MXNet 1.3 and run it on MXNet EI 1.4. However, you may run into undefined behavior if you train on a later version of MXNet, for example if you train a model on MXNet master and run it on MXNet EI 1.4.

  • Different sizes of EI accelerators have different amounts of GPU memory. If your model requires more GPU memory than your accelerator provides, you get an error message. In that case, stop your instance and restart it with a larger accelerator size that has more memory.

  • Calling reshape explicitly through either the Module API or the Symbol API can lead to out-of-memory (OOM) errors. Implicitly using different shapes for input NDArrays in different forward passes can also lead to OOM errors. This happens because the version of the model from before the reshape is not cleaned up on the accelerator until the session is destroyed.
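Because the accelerator keeps the pre-reshape model until the session ends, one way to avoid these OOM errors is to give every forward pass the same input shape. The following is a minimal sketch, assuming you assemble batches yourself. The hypothetical helper padToFixedBatch zero-pads a partial batch up to a fixed batch size, so the flat buffer always has the same length. Wrapping that buffer in an NDArray with a constant Shape is omitted, and predictions for the padded rows must be discarded. Padding is a general technique shown here for illustration; it is not something this guide prescribes.

```java
import java.util.Arrays;

public class FixedShapeBatch {
    /**
     * Hypothetical helper: packs up to fixedBatch samples of length sampleLen into a
     * flat row-major buffer of exactly fixedBatch * sampleLen floats. Unused rows stay
     * zero, so the input shape is identical on every forward pass.
     */
    static float[] padToFixedBatch(float[][] samples, int fixedBatch, int sampleLen) {
        if (samples.length > fixedBatch) {
            throw new IllegalArgumentException(
                    "batch of " + samples.length + " exceeds fixed size " + fixedBatch);
        }
        float[] flat = new float[fixedBatch * sampleLen]; // Java zero-initializes new arrays
        for (int row = 0; row < samples.length; row++) {
            if (samples[row].length != sampleLen) {
                throw new IllegalArgumentException("sample " + row + " has the wrong length");
            }
            System.arraycopy(samples[row], 0, flat, row * sampleLen, sampleLen);
        }
        return flat;
    }

    public static void main(String[] args) {
        // Two real samples padded to a fixed batch of four.
        float[][] twoSamples = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        float[] batch = padToFixedBatch(twoSamples, 4, 3);
        System.out.println(Arrays.toString(batch));
    }
}
```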