Apache MXNet on AWS
Developer Guide

Step 5: Train a Sample Image Classification Model Using the Cluster

In this step, you build and train a model that classifies input images into categories such as airplane, automobile, bird, cat, and dog, using the CIFAR-10 dataset. The step also shows how to run Apache MXNet for distributed training. For more information about the dataset, see The CIFAR-10 dataset.

Download the code and dependencies from the awslabs/deeplearning-cfn GitHub repository.

Train the model

  1. Clone the required GitHub repositories and update the submodules.

    Clone the deeplearning-cfn repository from awslabs/deeplearning-cfn. This repository includes MXNet, along with other dependencies, as a submodule, and the MXNet submodule in turn contains the dmlc-core submodule. The following command clones the repository and updates both of these submodules.

    $ git clone --recursive https://github.com/awslabs/deeplearning-cfn $EFS_MOUNT/deeplearning-cfn
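
    If the clone is interrupted, or if you clone without the --recursive flag, you can fetch the submodules afterward with standard Git commands; a minimal sketch:

    $ cd $EFS_MOUNT/deeplearning-cfn
    $ git submodule update --init --recursive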

    The MXNet submodule contains the Python code file, train_cifar10.py, which you will run to train an image classification model.

    Note

    The AWS CloudFormation template sets the $EFS_MOUNT environment variable when it creates the stack. The variable points to the directory where the Amazon EFS file system is mounted. The file system is shared across EC2 instances in the cluster.
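
    Before you start training, you can confirm that the variable is set and that the file system is mounted; a quick check from the master instance:

    $ echo $EFS_MOUNT       # prints the EFS mount point set by the stack
    $ df -h "$EFS_MOUNT"    # shows the file system mounted at that path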

  2. Train the image classification model using the CIFAR-10 training dataset on the cluster of EC2 instances. First, create a directory where the resulting model is saved, change to the example directory, and then launch the distributed training job.

    $ mkdir $EFS_MOUNT/cifar_model/
    $ cd $EFS_MOUNT/deeplearning-cfn/examples/mxnet/example/image-classification/
    $ ../../tools/launch.py -n $DEEPLEARNING_WORKERS_COUNT \
    -H $DEEPLEARNING_WORKERS_PATH python train_cifar10.py \
    --network resnet --num-layers 50 --kv-store dist_device_sync \
    --model-prefix $EFS_MOUNT/cifar_model/cifar --num-epochs 10

    On the two c4.4xlarge CPU instances that you specified when creating the cluster, this training runs for about 15 minutes for the 10 epochs specified and achieves a training accuracy of 78%. You can increase the number of epochs for higher training accuracy; however, doing so also increases the time it takes to train the model.

    In the command:

    • train_cifar10.py contains the MXNet Python code to train a model on the CIFAR-10 dataset. The dataset is a collection of 60,000 images grouped into 10 classes. After training, the model can infer the class of an input image.

    • launch.py is a utility script that launches distributed training of the model. It uses the following environment variables to provide MXNet with information about the AWS CloudFormation stack. The AWS CloudFormation template set these variables when you created the MXNet stack; you can inspect them as shown after this list.

      • $DEEPLEARNING_WORKERS_COUNT provides the count of the workers in this stack.

      • $DEEPLEARNING_WORKERS_PATH provides the path to a file containing a list of worker instances.

      • $DEEPLEARNING_WORKER_GPU_COUNT provides the count of GPUs on an instance.
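
      To inspect these values from the master instance, a quick sketch (assuming the stack set the variables as described; the worker file is expected to list one host per line):

      $ echo $DEEPLEARNING_WORKERS_COUNT
      $ cat $DEEPLEARNING_WORKERS_PATH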

    Note

    For this exercise, you launched a cluster of CPU machines. If you launch a cluster of GPU machines instead (for example, p2.xlarge or g2.xlarge instance types), use the following command to train the model:

    $ ../../tools/launch.py -n $DEEPLEARNING_WORKERS_COUNT -H $DEEPLEARNING_WORKERS_PATH \
    python train_cifar10.py --gpus $(seq -s , 0 1 $(($DEEPLEARNING_WORKER_GPU_COUNT - 1))) \
    --network resnet --num-layers 50 --kv-store dist_device_sync \
    --model-prefix $EFS_MOUNT/cifar_model/cifar --num-epochs 10

    This training takes only about 2 minutes on the two EC2 p2.8xlarge (GPU) instances and achieves 79% accuracy. If you increase the number of epochs to 100, training completes in about 25 minutes and achieves an accuracy of 92%.
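
    The seq expression in this command builds the comma-separated device list that --gpus expects. For example, with $DEEPLEARNING_WORKER_GPU_COUNT equal to 8, it expands as follows:

    $ seq -s , 0 1 $((8 - 1))
    0,1,2,3,4,5,6,7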

  3. Get the path where the model is saved. You will need this path when you test the model later.

    $ echo $EFS_MOUNT/cifar_model/
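
    With a --model-prefix, MXNet checkpointing typically writes a symbol file (the network definition) and one parameter file per epoch under that prefix. Listing the directory is a quick sanity check; the exact file names can vary with the MXNet version and the number of workers, but you should see files similar to the following:

    $ ls $EFS_MOUNT/cifar_model/
    cifar-0001.params  ...  cifar-0010.params  cifar-symbol.json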

If you need to terminate the MXNet processes across the EC2 worker instances in the cluster, use the following command. Be aware that it terminates all Python processes, not just the MXNet processes.

$ while read -u 10 host; do ssh -o "StrictHostKeyChecking no" $host "pkill -f python" ; \
done 10<$DEEPLEARNING_WORKERS_PATH

The $DEEPLEARNING_WORKERS_PATH environment variable, which the AWS CloudFormation template set when you created the stack, provides the path to a file that lists all worker instances in the stack.
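
If you prefer not to kill unrelated Python processes, one option is to match on the training script name instead. A sketch using the same loop (note that it may miss auxiliary processes that the launcher starts):

$ while read -u 10 host; do ssh -o "StrictHostKeyChecking no" $host "pkill -f train_cifar10" ; \
done 10<$DEEPLEARNING_WORKERS_PATH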

Next Step

Step 6: Test the Image Classification Model