Apache MXNet on AWS
Developer Guide

Step 6: Test the Image Classification Model

Test the model by first using it to predict the classes of images in the CIFAR-10 validation dataset. Then verify the predictions against the images' actual classes. You use a Jupyter notebook to write and run your code.

Step 6.1: Create a Jupyter Notebook

In this step, you use a Jupyter notebook to write Apache MXNet code in Python for testing the model.

Create the notebook

  1. Set up the Jupyter notebook. For instructions, see Set up a Jupyter Notebook. Follow the steps to configure both the server (the EC2 instance) and your client.

  2. Connect your client to the Jupyter notebook server. For more information, see Step 4: Test by Logging in to the Jupyter Server.

    1. In a browser window, type the URL in the address bar.

      • On Windows clients, use the public DNS name of the EC2 instance followed by the port number, which is typically 8888.

        For example:

      • On macOS and Linux clients, use the following URL:

    2. If the connection is successful, the homepage of the Jupyter notebook server appears. Type the password that you created when you configured the Jupyter server.

  3. Create a Jupyter notebook, choosing the Python 2 option.

    Now you are ready to write code.

Step 6.2: Test the Image Classification Model

Now test the image classification model that you trained in the preceding section, as follows:

  • Using the model, perform image class predictions on all of the images provided in the CIFAR-10 validation dataset.

  • Validate the image classification predictions. Compare the class that the model predicts for one of the images with that image's actual class.
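
The procedure below spot-checks a single image. To validate every prediction, you can compare the highest-scoring class in each row of the model output with the true label and compute the fraction that match (top-1 accuracy). The following is a minimal pure-Python sketch, not part of the MXNet API: the names `argmax` and `top1_accuracy` and the toy scores are hypothetical, and it assumes you have already converted the model output and the true labels to plain Python lists (for example, with `predictions.asnumpy().tolist()`).

```python
def argmax(row):
    """Return the index of the largest score in a row (first one on ties)."""
    best = 0
    for i, score in enumerate(row):
        if score > row[best]:
            best = i
    return best


def top1_accuracy(scores, labels):
    """Fraction of rows whose highest-scoring class equals the true label.

    scores: list of per-image class-score lists (one row per image)
    labels: list of true class indices (ints or floats such as 4.0)
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = sum(1 for row, label in zip(scores, labels)
                  if argmax(row) == int(label))
    return float(correct) / len(labels)


# Toy example: 3 images, 3 classes (made-up numbers, not model output).
toy_scores = [[0.1, 0.7, 0.2],
              [0.8, 0.1, 0.1],
              [0.3, 0.3, 0.4]]
toy_labels = [1, 0, 1]
print(top1_accuracy(toy_scores, toy_labels))
```

If you collect the true labels from the iterator's batches yourself, leave out the padded entries of the last batch (reported by `batch.pad`) so that the labels line up one-to-one with the rows of `predictions`.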

In each step in the following procedure, copy and paste the code into your Jupyter notebook, and then run the code.

Test the model

  1. Load the model that you trained in the preceding section.

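    import mxnet as mx

    # Load the network definition and the weights saved at the end of epoch 300.
    symbol, arg_params, aux_params = mx.model.load_checkpoint('model-path/cifar', 300)

    # Wrap the network in a Module and bind it for inference only.
    # data_shapes is (batch_size, channels, height, width).
    cifar_model = mx.mod.Module(symbol=symbol)
    cifar_model.bind(for_training=False, data_shapes=[('data', (128, 3, 28, 28))])
    cifar_model.set_params(arg_params, aux_params)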

    In load_checkpoint(), the second argument, 300, is the number of epochs you used to train the model in the preceding section, so this call loads the checkpoint saved at the end of the 300th epoch.

    If you trained your model for a different number of epochs, use that number instead.

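  2. To download the CIFAR-10 validation dataset, run the following code.

    import urllib

    # Download the CIFAR-10 validation set (RecordIO format) to /tmp.
    # This is Python 2; in Python 3 the function is urllib.request.urlretrieve.
    urllib.urlretrieve('', '/tmp/cifar10_val.rec')
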
  3. Prepare a data iterator for the CIFAR-10 validation dataset (which has 10,000 images). You use this iterator in the next step to perform image classification predictions on all of these images.

    You also get the true label of the image at index 4 (the fifth image, because indexing starts at 0) in the first batch of the validation dataset. In the next steps, you perform the predictions and then compare the predicted label for this image against its true label.

    rgb_mean = [123.68, 116.779, 103.939]

    # Read the validation images in batches of 128, subtracting the
    # per-channel mean and applying no random cropping or mirroring.
    validation_data_iter = mx.io.ImageRecordIter(
        path_imgrec="/tmp/cifar10_val.rec",
        label_width=1,
        mean_r=rgb_mean[0],
        mean_g=rgb_mean[1],
        mean_b=rgb_mean[2],
        data_name='data',
        label_name='softmax_label',
        batch_size=128,
        data_shape=(3, 28, 28),
        rand_crop=False,
        rand_mirror=False)

    # Get the actual label for the image at index 4 of the first batch.
    first_batch = validation_data_iter.first_batch
    true_label = first_batch.label[0][4].asnumpy()[0]

  4. Perform image classification prediction on all of the images in the validation dataset.

    predictions = cifar_model.predict(validation_data_iter)

    In this step, you use the model to identify the image class of each of the 10,000 images in the dataset. The result, an NDArray with one row of class scores per image, is saved in predictions.

  5. Now validate the image class predictions by comparing the actual image class with the image class predicted by the model.

    # Print the actual label for the image at index 4.
    print("Actual label for the image - %d" % int(true_label))

    # Print the label predicted by the model for the same image:
    # the index of the highest-scoring class in its row of predictions.
    predicted_label = predictions[4].asnumpy().argmax()
    print("Predicted label for the image - %d" % predicted_label)
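
Step 1 passes a prefix and an epoch number to mx.model.load_checkpoint(). As we understand MXNet's checkpoint convention, the network definition is stored as `<prefix>-symbol.json` and the weights as `<prefix>-<epoch>.params`, with the epoch zero-padded to four digits. The hypothetical helpers below only build those file names so that you can confirm the files exist before loading them; they do not call MXNet.

```python
import os


def checkpoint_files(prefix, epoch):
    """Return the (symbol, params) file names for an MXNet-style checkpoint.

    Assumes the naming convention <prefix>-symbol.json and
    <prefix>-<epoch as 4 digits>.params.
    """
    symbol_file = "%s-symbol.json" % prefix
    params_file = "%s-%04d.params" % (prefix, epoch)
    return symbol_file, params_file


def missing_checkpoint_files(prefix, epoch):
    """List the expected checkpoint files that are not present on disk."""
    return [f for f in checkpoint_files(prefix, epoch) if not os.path.exists(f)]


print(checkpoint_files("model-path/cifar", 300))
```

An empty list from missing_checkpoint_files('model-path/cifar', 300) suggests that both files are in place.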
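
Steps 3 and 4 run the 10,000 validation images through the model in batches of 128. Because 10,000 is not a multiple of 128, the last batch is only partly filled; as we understand it, MXNet iterators report the unused slots as padding (DataBatch.pad), and Module.predict() drops the padded rows, so predictions ends up with one row per image. The arithmetic sketch below (the function name is hypothetical) shows the numbers involved.

```python
def batch_layout(num_images, batch_size):
    """Return (num_batches, images_in_last_batch, padded_slots).

    Every batch has batch_size slots; the last one is topped up with padding.
    """
    num_batches = (num_images + batch_size - 1) // batch_size  # ceiling division
    in_last = num_images - (num_batches - 1) * batch_size
    padded = num_batches * batch_size - num_images
    return num_batches, in_last, padded


print(batch_layout(10000, 128))
```

For this dataset the result is 79 batches, with 16 real images and 112 padded slots in the last batch.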
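
Step 5 prints numeric class indices. To make the output easier to read, you can map each index to a CIFAR-10 class name. The sketch below assumes that the labels in the .rec file follow the standard CIFAR-10 ordering (0 = airplane through 9 = truck); check this against how your dataset was packed. The list CIFAR10_CLASSES and the helper class_name are hypothetical names.

```python
# Standard CIFAR-10 class order (assumed to match the labels in the .rec file).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def class_name(label):
    """Map a numeric label (an int, or a float such as 4.0) to a class name."""
    index = int(label)
    if not 0 <= index < len(CIFAR10_CLASSES):
        raise ValueError("label out of range: %r" % (label,))
    return CIFAR10_CLASSES[index]


print("Actual class - %s" % class_name(4.0))
```

In the notebook, you might call class_name(true_label) and class_name(predicted_label) after step 5 to print both results as names.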

Next Step

Step 7: Clean Up