Amazon SageMaker
Developer Guide

Example 1: Using Amazon SageMaker for Training and Inference with Apache Spark

Amazon SageMaker provides an Apache Spark library (in both Python and Scala) that you can use to integrate your Apache Spark applications with Amazon SageMaker. For example, you might use Apache Spark for data preprocessing and Amazon SageMaker for model training and hosting. For more information, see Using Apache Spark with Amazon SageMaker. This section provides example code that uses the Apache Spark Scala library provided by Amazon SageMaker to train a model in Amazon SageMaker using DataFrames in your Spark cluster. The example also hosts the resulting model artifacts using Amazon SageMaker hosting services. Specifically, this example does the following:

  • Uses the KMeansSageMakerEstimator to fit (or train) a model on data

     

    Because the example uses the k-means algorithm provided by Amazon SageMaker to train a model, you use the KMeansSageMakerEstimator. You train the model using images of handwritten single-digit numbers (from the MNIST dataset). You provide the images as an input DataFrame. For your convenience, Amazon SageMaker provides this dataset in an S3 bucket.

     

    In response, the estimator returns a SageMakerModel object.

     

  • Obtains inferences using the trained SageMakerModel

     

    To get inferences from a model hosted in Amazon SageMaker, you call the SageMakerModel.transform method. You pass a DataFrame as input. The method transforms the input DataFrame to another DataFrame containing inferences obtained from the model.

     

    For a given input image of a handwritten single-digit number, the inference identifies a cluster that the image belongs to. For more information, see K-Means Algorithm.

This is the example code:

import org.apache.spark.sql.SparkSession
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.algorithms
import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator

val spark = SparkSession.builder.getOrCreate

// load mnist data as a dataframe from libsvm
val region = "us-east-1"
val trainingData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
val testData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

val roleArn = "arn:aws:iam::account-id:role/rolename"

val estimator = new KMeansSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1)
  .setK(10).setFeatureDim(784)

// train
val model = estimator.fit(trainingData)

val transformedData = model.transform(testData)
transformedData.show

The code does the following:

  • Loads the MNIST dataset from an S3 bucket provided by Amazon SageMaker (sagemaker-sample-data-us-east-1 in this example) into Spark DataFrames (trainingData and testData):

    // Get a Spark session.
    val spark = SparkSession.builder.getOrCreate

    // load mnist data as a dataframe from libsvm
    val region = "us-east-1"
    val trainingData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
    val testData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

    val roleArn = "arn:aws:iam::account-id:role/rolename"

    trainingData.show()

    The show method displays the first 20 rows in the data frame:

    +-----+--------------------+
    |label|            features|
    +-----+--------------------+
    |  5.0|(784,[152,153,154...|
    |  0.0|(784,[127,128,129...|
    |  4.0|(784,[160,161,162...|
    |  1.0|(784,[158,159,160...|
    |  9.0|(784,[208,209,210...|
    |  2.0|(784,[155,156,157...|
    |  1.0|(784,[124,125,126...|
    |  3.0|(784,[151,152,153...|
    |  1.0|(784,[152,153,154...|
    |  4.0|(784,[134,135,161...|
    |  3.0|(784,[123,124,125...|
    |  5.0|(784,[216,217,218...|
    |  3.0|(784,[143,144,145...|
    |  6.0|(784,[72,73,74,99...|
    |  1.0|(784,[151,152,153...|
    |  7.0|(784,[211,212,213...|
    |  2.0|(784,[151,152,153...|
    |  8.0|(784,[159,160,161...|
    |  6.0|(784,[100,101,102...|
    |  9.0|(784,[209,210,211...|
    +-----+--------------------+
    only showing top 20 rows

    In each row:

    • The label column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.

    • The features column stores a vector (org.apache.spark.ml.linalg.Vector) of Double values. These are the 784 features of the handwritten number. (Each handwritten number is a 28 x 28-pixel image, making 784 features.)
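
To make the row layout above concrete, the following plain-Scala sketch (no Spark required) shows how one libsvm-format text record maps to a label and a sparse set of pixel values. It is an illustration only, not the Spark reader's implementation: the object name LibSvmSketch, its helper methods, and the sample record are hypothetical. It assumes the usual libsvm convention of 1-based feature indices, which the Spark libsvm data source is understood to convert to 0-based positions.

```scala
// Hypothetical helper for illustration; not part of Spark or the SageMaker Spark library.
object LibSvmSketch {
  val NumFeatures = 784 // 28 x 28 pixels

  // Parses "label index:value index:value ..." where indices are 1-based.
  // Returns the label and the entries as 0-based (index, value) pairs.
  def parseLine(line: String): (Double, Seq[(Int, Double)]) = {
    val tokens = line.trim.split("\\s+")
    val label = tokens.head.toDouble
    val entries = tokens.tail.toSeq.map { token =>
      token.split(":") match {
        case Array(index, value) => (index.toInt - 1, value.toDouble)
      }
    }
    (label, entries)
  }

  // Expands the sparse entries into a dense array; pixels not listed stay 0.0.
  def toDense(entries: Seq[(Int, Double)]): Array[Double] = {
    val dense = Array.fill(NumFeatures)(0.0)
    entries.foreach { case (i, v) => dense(i) = v }
    dense
  }
}
```

For example, LibSvmSketch.parseLine("5 153:3.0 154:18.0 155:18.0") yields the label 5.0 and entries at positions 152, 153, and 154. The pixel values in this record are made up for the illustration.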

     

  • Creates an Amazon SageMaker estimator (KMeansSageMakerEstimator)

    The fit method of this estimator uses the k-means algorithm provided by Amazon SageMaker to train models using an input DataFrame. In response, it returns a SageMakerModel object that you can use to get inferences.

    Note

    The KMeansSageMakerEstimator extends the Amazon SageMaker SageMakerEstimator, which extends the Apache Spark Estimator.

    val estimator = new KMeansSageMakerEstimator(
      sagemakerRole = IAMRole(roleArn),
      trainingInstanceType = "ml.p2.xlarge",
      trainingInstanceCount = 1,
      endpointInstanceType = "ml.c4.xlarge",
      endpointInitialInstanceCount = 1)
      .setK(10).setFeatureDim(784)

    The constructor parameters provide information that is used for training a model and deploying it on Amazon SageMaker:

    • trainingInstanceType and trainingInstanceCount—Identify the type and number of ML compute instances to use for model training.

       

    • endpointInstanceType—Identifies the ML compute instance type to use when hosting the model in Amazon SageMaker. By default, one ML compute instance is assumed.

       

    • endpointInitialInstanceCount—Identifies the number of ML compute instances initially backing the endpoint hosting the model in Amazon SageMaker.

       

    • sagemakerRole—Amazon SageMaker assumes this IAM role to perform tasks on your behalf. For example, for model training, it reads data from S3 and writes training results (model artifacts) to S3.

      Note

      This example implicitly creates an Amazon SageMaker client. To create this client, you must provide your credentials. The API uses these credentials to authenticate requests to Amazon SageMaker, such as requests to create a training job and API calls for deploying the model using Amazon SageMaker hosting services.

    • After the KMeansSageMakerEstimator object has been created, you set the following parameters, which are used in model training:

      • setK(10): Sets the number of clusters that the k-means algorithm should create during model training. You specify 10 clusters, one for each digit, 0-9.

      • setFeatureDim(784): Identifies that each input image has 784 features (each handwritten number is a 28 x 28-pixel image, making 784 features).

       

  • Calls the estimator fit method

    // train
    val model = estimator.fit(trainingData)

    You pass the input DataFrame as a parameter. The estimator does all the work of training the model and deploying it to Amazon SageMaker. For more information, see Integrating Your Apache Spark Application with Amazon SageMaker. In response, you get a SageMakerModel object, which you can use to get inferences from your model deployed in Amazon SageMaker.

     

    You provide only the input DataFrame. You don't need to specify the registry path to the k-means algorithm used for model training because the KMeansSageMakerEstimator knows it.

     

  • Calls the SageMakerModel.transform method to get inferences from the model deployed in Amazon SageMaker.

    The transform method takes a DataFrame as input, transforms it, and returns another DataFrame containing inferences obtained from the model.

    val transformedData = model.transform(testData)
    transformedData.show

    In this example, the code passes the testData DataFrame to the transform method. The sample output below, however, shows the same rows that appear in the training data listing above. The transform method does the following:

    • Serializes the features column in the input DataFrame to protobuf and sends it to the Amazon SageMaker endpoint for inference.

    • Deserializes the protobuf response into the two additional columns (distance_to_cluster and closest_cluster) in the transformed DataFrame.
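
As a rough illustration of what the two added columns mean, the sketch below assigns a feature vector to its nearest centroid in plain Scala. It is not the code that runs on the Amazon SageMaker endpoint: the object NearestClusterSketch and its methods are hypothetical, and the use of Euclidean distance is an assumption made for the example.

```scala
// Hypothetical illustration of nearest-centroid assignment; not the SageMaker implementation.
object NearestClusterSketch {
  // Euclidean distance between two vectors of the same dimension (assumed metric).
  def distance(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "vectors must have the same dimension")
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)
  }

  // Returns (distance_to_cluster, closest_cluster) for one feature vector,
  // where closest_cluster is the index of the nearest centroid.
  def assign(features: Array[Double], centroids: Seq[Array[Double]]): (Double, Int) = {
    require(centroids.nonEmpty, "at least one centroid is required")
    centroids.zipWithIndex
      .map { case (centroid, index) => (distance(features, centroid), index) }
      .minBy(_._1)
  }
}
```

With 10 centroids of 784 values each, assign would return one distance and one cluster index per image, which corresponds in spirit to the distance_to_cluster and closest_cluster columns.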

    The show method displays the inferences for the first 20 rows in the input DataFrame:

    +-----+--------------------+-------------------+---------------+
    |label|            features|distance_to_cluster|closest_cluster|
    +-----+--------------------+-------------------+---------------+
    |  5.0|(784,[152,153,154...|  1767.897705078125|            4.0|
    |  0.0|(784,[127,128,129...|  1392.157470703125|            5.0|
    |  4.0|(784,[160,161,162...| 1671.5711669921875|            9.0|
    |  1.0|(784,[158,159,160...| 1182.6082763671875|            6.0|
    |  9.0|(784,[208,209,210...| 1390.4002685546875|            0.0|
    |  2.0|(784,[155,156,157...|  1713.988037109375|            1.0|
    |  1.0|(784,[124,125,126...| 1246.3016357421875|            2.0|
    |  3.0|(784,[151,152,153...|  1753.229248046875|            4.0|
    |  1.0|(784,[152,153,154...|  978.8394165039062|            2.0|
    |  4.0|(784,[134,135,161...|  1623.176513671875|            3.0|
    |  3.0|(784,[123,124,125...|  1533.863525390625|            4.0|
    |  5.0|(784,[216,217,218...|  1469.357177734375|            6.0|
    |  3.0|(784,[143,144,145...|  1736.765869140625|            4.0|
    |  6.0|(784,[72,73,74,99...|   1473.69384765625|            8.0|
    |  1.0|(784,[151,152,153...|    944.88720703125|            2.0|
    |  7.0|(784,[211,212,213...| 1285.9071044921875|            3.0|
    |  2.0|(784,[151,152,153...| 1635.0125732421875|            1.0|
    |  8.0|(784,[159,160,161...| 1436.3162841796875|            6.0|
    |  6.0|(784,[100,101,102...| 1499.7366943359375|            7.0|
    |  9.0|(784,[209,210,211...| 1364.6319580078125|            6.0|
    +-----+--------------------+-------------------+---------------+

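    You can interpret the data as follows. Cluster IDs are arbitrary indexes assigned during training, so they do not need to match the digit labels:

    • The handwritten number in the first row, with the label 5, belongs to cluster 4 (closest_cluster).

    • The handwritten number in the second row, with the label 0, belongs to cluster 5.

    • The handwritten number in the third row, with the label 4, belongs to cluster 9.

    • The handwritten number in the fourth row, with the label 1, belongs to cluster 6.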
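
One way to read the label and closest_cluster columns together is to tally which labels land in each cluster. The plain-Scala sketch below does this for the 20 rows of sample output shown above. The object name ClusterTallySketch is hypothetical, and the (label, closest_cluster) pairs are copied from that output with the values written as integers.

```scala
// Hypothetical illustration; the data is copied from the sample transform output.
object ClusterTallySketch {
  // (label, closest_cluster) for each of the 20 displayed rows, in order.
  val rows: Seq[(Int, Int)] = Seq(
    (5, 4), (0, 5), (4, 9), (1, 6), (9, 0),
    (2, 1), (1, 2), (3, 4), (1, 2), (4, 3),
    (3, 4), (5, 6), (3, 4), (6, 8), (1, 2),
    (7, 3), (2, 1), (8, 6), (6, 7), (9, 6)
  )

  // For each cluster ID, the sorted labels of the rows assigned to it.
  val labelsByCluster: Map[Int, Seq[Int]] =
    rows.groupBy(_._2).map { case (cluster, members) =>
      cluster -> members.map(_._1).sorted
    }
}
```

In these rows, cluster 2 holds only images labeled 1 and cluster 4 holds mostly images labeled 3, while cluster 6 mixes several labels. On the full DataFrame, a comparable tally could likely be obtained with transformedData.groupBy("closest_cluster", "label").count().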

The SageMaker Spark GitHub README provides information on how to run these examples. For more information, see https://github.com/aws/sagemaker-spark/blob/master/README.md.