SageMaker Spark for Scala examples

Amazon SageMaker provides an Apache Spark library (SageMaker Spark) that you can use to integrate your Apache Spark applications with SageMaker. For example, you might use Apache Spark for data preprocessing and SageMaker for model training and hosting. For information about the SageMaker Apache Spark library, see Use Apache Spark with Amazon SageMaker.

Download Spark for Scala

You can download the source code and examples for both Python Spark (PySpark) and Scala libraries from the SageMaker Spark GitHub repository.

For detailed instructions on installing the SageMaker Spark library, see SageMaker Spark.

The SageMaker Spark SDK for Scala is available in the Maven Central Repository. To add the library to your project, add the dependency that matches your Spark version to your pom.xml file:

  • If your project depends on Spark 2.2, add the following to your pom.xml file:

    <dependency>
        <groupId>com.amazonaws</groupId>
        <artifactId>sagemaker-spark_2.11</artifactId>
        <version>spark_2.2.0-1.0</version>
    </dependency>
  • If your project depends on Spark 2.1, add the following to your pom.xml file:

    <dependency>
        <groupId>com.amazonaws</groupId>
        <artifactId>sagemaker-spark_2.11</artifactId>
        <version>spark_2.1.1-1.0</version>
    </dependency>

Spark for Scala example

This section provides example code that uses the Apache Spark Scala library provided by SageMaker to train a model in SageMaker using DataFrames in your Spark cluster. It is followed by examples that show how to Use Custom Algorithms for Model Training and Hosting on Amazon SageMaker with Apache Spark and how to Use the SageMakerEstimator in a Spark Pipeline.

The following example hosts the resulting model artifacts using SageMaker hosting services. For more details on this example, see Getting Started: K-Means Clustering on SageMaker with SageMaker Spark SDK. Specifically, this example does the following:

  • Uses the KMeansSageMakerEstimator to fit (or train) a model on data

    Because the example uses the k-means algorithm provided by SageMaker to train a model, you use the KMeansSageMakerEstimator. You train the model using images of handwritten single-digit numbers (from the MNIST dataset). You provide the images as an input DataFrame. For your convenience, SageMaker provides this dataset in an Amazon S3 bucket.

    In response, the estimator returns a SageMakerModel object.

  • Obtains inferences using the trained SageMakerModel

    To get inferences from a model hosted in SageMaker, you call the SageMakerModel.transform method. You pass a DataFrame as input. The method transforms the input DataFrame to another DataFrame containing inferences obtained from the model.

    For a given input image of a handwritten single-digit number, the inference identifies a cluster that the image belongs to. For more information, see K-Means Algorithm.
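To make the two steps above concrete, here is a minimal plain-Scala sketch of what "fit" and "transform" mean for k-means in general. It is an illustration under stated assumptions, not SageMaker's implementation: it runs textbook Lloyd's iterations on small in-memory arrays, seeds the centroids with the first k points, and assumes Euclidean distance. The names KMeansSketch, fit, and closest are hypothetical and are not part of the SageMaker Spark API.

```scala
// Hypothetical, minimal k-means sketch (textbook Lloyd's iterations).
// This is NOT how SageMaker's built-in k-means algorithm is implemented;
// it only illustrates what fit (learn k centroids) and transform
// (assign each input to its closest centroid) mean for k-means.
object KMeansSketch {
  type Point = Array[Double]

  // Euclidean distance between two points of the same dimension (assumption).
  def distance(a: Point, b: Point): Double =
    math.sqrt(a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum)

  // "Inference": index of the closest centroid and the distance to it,
  // loosely analogous to the closest_cluster and distance_to_cluster columns.
  // On ties, the lower centroid index wins.
  def closest(p: Point, centroids: IndexedSeq[Point]): (Int, Double) =
    centroids.indices
      .map(i => (i, distance(p, centroids(i))))
      .reduce((best, next) => if (next._2 < best._2) next else best)

  // "Training": seed with the first k points (a simplification), then
  // repeatedly assign points to their closest centroid and move each
  // centroid to the mean of the points assigned to it.
  def fit(points: IndexedSeq[Point], k: Int, iterations: Int): IndexedSeq[Point] = {
    require(k > 0 && points.size >= k, "need at least k points")
    var centroids: IndexedSeq[Point] = points.take(k)
    for (_ <- 1 to iterations) {
      val groups = points.groupBy(p => closest(p, centroids)._1)
      centroids = centroids.indices.map { i =>
        groups.get(i) match {
          case Some(members) =>
            Array.tabulate(members.head.length)(d => members.map(_(d)).sum / members.size)
          case None => centroids(i) // an empty cluster keeps its old centroid
        }
      }
    }
    centroids
  }
}
```

For example, fitting k = 2 on the points (0, 0), (0, 1), (10, 10), and (10, 11) settles on the centroids (0, 0.5) and (10, 10.5), after which closest(Array(9.0, 9.0), centroids) returns cluster 1.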

import org.apache.spark.sql.SparkSession
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.algorithms
import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator

val spark = SparkSession.builder.getOrCreate

// load mnist data as a dataframe from libsvm
val region = "us-east-1"
val trainingData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
val testData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

// replace with the ARN of an IAM role that SageMaker can assume
val roleArn = "arn:aws:iam::account-id:role/rolename"

// configure the training and hosting instances, then the k-means hyperparameters
val estimator = new KMeansSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1)
  .setK(10).setFeatureDim(784)

// train (fit also deploys the model to a SageMaker endpoint)
val model = estimator.fit(trainingData)

// get inferences from the endpoint
val transformedData = model.transform(testData)
transformedData.show
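The example reads MNIST with Spark's libsvm data source. In the LIBSVM text format, each line holds a label followed by index:value pairs with one-based indices; Spark converts them to zero-based indices and stores the features as sparse vectors. Setting numFeatures to 784 fixes the vector size instead of having Spark infer it with an extra pass over the data. The following sketch mimics that conversion for a single line in plain Scala. LibSvmSketch and SparseRow are hypothetical names used only for illustration; in real code, spark.read.format("libsvm") does this parsing for you.

```scala
// Hypothetical helper (not part of Spark or SageMaker Spark): parses one
// LIBSVM-format line, "label index:value index:value ...", whose indices
// are one-based, into a label plus a zero-based sparse representation.
object LibSvmSketch {
  final case class SparseRow(label: Double, size: Int, indices: Array[Int], values: Array[Double]) {
    // Render as (size,[indices],[values]), similar to a Spark sparse vector.
    override def toString: String =
      s"($size,[${indices.mkString(",")}],[${values.mkString(",")}])"
  }

  def parse(line: String, numFeatures: Int): SparseRow = {
    val tokens = line.trim.split("\\s+")
    val label = tokens.head.toDouble
    val pairs = tokens.tail.map { tok =>
      val parts = tok.split(":")
      (parts(0).toInt - 1, parts(1).toDouble) // one-based -> zero-based index
    }
    require(pairs.forall { case (i, _) => i >= 0 && i < numFeatures }, "feature index out of range")
    SparseRow(label, numFeatures, pairs.map(_._1), pairs.map(_._2))
  }
}
```

For example, parsing the illustrative line "5 153:3 154:18" (made-up input, not taken from the dataset) with numFeatures = 784 yields label 5.0 and the zero-based indices 152 and 153, rendered as (784,[152,153],[3.0,18.0]). That is the same shape as the (784,[152,153,154... entries in the show output.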

The example code does the following:

  • Loads the MNIST training and test datasets from an S3 bucket provided by SageMaker (sagemaker-sample-data-us-east-1 in this example) into Spark DataFrames (trainingData and testData):

    // Get a Spark session.
    val spark = SparkSession.builder.getOrCreate

    // load mnist data as a dataframe from libsvm
    val region = "us-east-1"
    val trainingData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
    val testData = spark.read.format("libsvm")
      .option("numFeatures", "784")
      .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")

    val roleArn = "arn:aws:iam::account-id:role/rolename"
    trainingData.show()

    The show method displays the first 20 rows of the DataFrame:

    +-----+--------------------+
    |label|            features|
    +-----+--------------------+
    |  5.0|(784,[152,153,154...|
    |  0.0|(784,[127,128,129...|
    |  4.0|(784,[160,161,162...|
    |  1.0|(784,[158,159,160...|
    |  9.0|(784,[208,209,210...|
    |  2.0|(784,[155,156,157...|
    |  1.0|(784,[124,125,126...|
    |  3.0|(784,[151,152,153...|
    |  1.0|(784,[152,153,154...|
    |  4.0|(784,[134,135,161...|
    |  3.0|(784,[123,124,125...|
    |  5.0|(784,[216,217,218...|
    |  3.0|(784,[143,144,145...|
    |  6.0|(784,[72,73,74,99...|
    |  1.0|(784,[151,152,153...|
    |  7.0|(784,[211,212,213...|
    |  2.0|(784,[151,152,153...|
    |  8.0|(784,[159,160,161...|
    |  6.0|(784,[100,101,102...|
    |  9.0|(784,[209,210,211...|
    +-----+--------------------+
    only showing top 20 rows

    In each row:

    • The label column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.

    • The features column stores a vector (org.apache.spark.ml.linalg.Vector) of Double values. These are the 784 features of the handwritten number. (Each handwritten number is a 28 x 28-pixel image, making 784 features.)

  • Creates a SageMaker estimator (KMeansSageMakerEstimator)

    The fit method of this estimator uses the k-means algorithm provided by SageMaker to train models using an input DataFrame. In response, it returns a SageMakerModel object that you can use to get inferences.

    Note

    The KMeansSageMakerEstimator extends the SageMaker SageMakerEstimator, which extends the Apache Spark Estimator.

    val estimator = new KMeansSageMakerEstimator(
      sagemakerRole = IAMRole(roleArn),
      trainingInstanceType = "ml.p2.xlarge",
      trainingInstanceCount = 1,
      endpointInstanceType = "ml.c4.xlarge",
      endpointInitialInstanceCount = 1)
      .setK(10).setFeatureDim(784)

    The constructor parameters provide information that is used for training a model and deploying it on SageMaker:

    • trainingInstanceType and trainingInstanceCount—Identify the type and number of ML compute instances to use for model training.

    • endpointInstanceType—Identifies the ML compute instance type to use when hosting the model in SageMaker. By default, one ML compute instance is assumed.

    • endpointInitialInstanceCount—Identifies the number of ML compute instances initially backing the endpoint hosting the model in SageMaker.

    • sagemakerRole—SageMaker assumes this IAM role to perform tasks on your behalf. For example, for model training, it reads data from S3 and writes training results (model artifacts) to S3.

      Note

      This example implicitly creates a SageMaker client, which requires your AWS credentials. The API uses these credentials to authenticate requests to SageMaker, such as the request that creates a training job and the API calls that deploy the model to SageMaker hosting services.

    • After creating the KMeansSageMakerEstimator object, you set the following parameters, which are used in model training:

      • setK(10): The number of clusters that the k-means algorithm should create during model training. You specify 10 clusters because there are 10 digits, 0 through 9.

      • setFeatureDim(784): Specifies that each input image has 784 features (each handwritten number is a 28 x 28-pixel image, making 784 features).

  • Calls the estimator fit method

    // train
    val model = estimator.fit(trainingData)

    You pass the input DataFrame as a parameter. The fit method does all the work of training the model and deploying it to SageMaker. For more information, see Integrate Your Apache Spark Application with SageMaker. In response, you get a SageMakerModel object, which you can use to get inferences from your model deployed in SageMaker.

    You provide only the input DataFrame. You don't need to specify the registry path to the k-means algorithm used for model training because the KMeansSageMakerEstimator knows it.

  • Calls the SageMakerModel.transform method to get inferences from the model deployed in SageMaker.

    The transform method takes a DataFrame as input, transforms it, and returns another DataFrame containing inferences obtained from the model.

    val transformedData = model.transform(testData)
    transformedData.show

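    The example code passes the test DataFrame (testData) to the transform method. For simplicity, the sample output that follows was generated by passing the same DataFrame that was used for training (trainingData), so its label and features columns match the trainingData rows shown earlier. In either case, the transform method does the following: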

    • Serializes the features column in the input DataFrame to protobuf and sends it to the SageMaker endpoint for inference.

    • Deserializes the protobuf response into the two additional columns (distance_to_cluster and closest_cluster) in the transformed DataFrame.

    The show method displays the first 20 rows of the transformed DataFrame, including the inferences:

    +-----+--------------------+-------------------+---------------+
    |label|            features|distance_to_cluster|closest_cluster|
    +-----+--------------------+-------------------+---------------+
    |  5.0|(784,[152,153,154...|  1767.897705078125|            4.0|
    |  0.0|(784,[127,128,129...|  1392.157470703125|            5.0|
    |  4.0|(784,[160,161,162...| 1671.5711669921875|            9.0|
    |  1.0|(784,[158,159,160...| 1182.6082763671875|            6.0|
    |  9.0|(784,[208,209,210...| 1390.4002685546875|            0.0|
    |  2.0|(784,[155,156,157...|  1713.988037109375|            1.0|
    |  1.0|(784,[124,125,126...| 1246.3016357421875|            2.0|
    |  3.0|(784,[151,152,153...|  1753.229248046875|            4.0|
    |  1.0|(784,[152,153,154...|  978.8394165039062|            2.0|
    |  4.0|(784,[134,135,161...|  1623.176513671875|            3.0|
    |  3.0|(784,[123,124,125...|  1533.863525390625|            4.0|
    |  5.0|(784,[216,217,218...|  1469.357177734375|            6.0|
    |  3.0|(784,[143,144,145...|  1736.765869140625|            4.0|
    |  6.0|(784,[72,73,74,99...|   1473.69384765625|            8.0|
    |  1.0|(784,[151,152,153...|    944.88720703125|            2.0|
    |  7.0|(784,[211,212,213...| 1285.9071044921875|            3.0|
    |  2.0|(784,[151,152,153...| 1635.0125732421875|            1.0|
    |  8.0|(784,[159,160,161...| 1436.3162841796875|            6.0|
    |  6.0|(784,[100,101,102...| 1499.7366943359375|            7.0|
    |  9.0|(784,[209,210,211...| 1364.6319580078125|            6.0|
    +-----+--------------------+-------------------+---------------+

    You can interpret the data as follows:

    • A handwritten number with the label 5 belongs to cluster 4 (closest_cluster).

    • A handwritten number with the label 0 belongs to cluster 5.

    • A handwritten number with the label 4 belongs to cluster 9.

    • A handwritten number with the label 1 belongs to cluster 6.
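Because k-means is unsupervised, cluster IDs are arbitrary, and a single row says little about how a digit maps to a cluster. The following sketch tallies, for each label, the cluster that occurs most often among the 20 (label, closest_cluster) pairs copied from the output above. With only 20 rows the result is anecdotal, but it already shows that the first row with label 1 falls in cluster 6 while the other three label-1 rows fall in cluster 2. ClusterTally and majorityCluster are hypothetical names; on a full transformed DataFrame you would compute a similar tally with a Spark groupBy instead.

```scala
// Hypothetical helper: for each label, find the closest_cluster value that
// occurs most often, using the 20 (label, closest_cluster) pairs shown above.
object ClusterTally {
  val rows: Seq[(Double, Double)] = Seq(
    5.0 -> 4.0, 0.0 -> 5.0, 4.0 -> 9.0, 1.0 -> 6.0, 9.0 -> 0.0,
    2.0 -> 1.0, 1.0 -> 2.0, 3.0 -> 4.0, 1.0 -> 2.0, 4.0 -> 3.0,
    3.0 -> 4.0, 5.0 -> 6.0, 3.0 -> 4.0, 6.0 -> 8.0, 1.0 -> 2.0,
    7.0 -> 3.0, 2.0 -> 1.0, 8.0 -> 6.0, 6.0 -> 7.0, 9.0 -> 6.0)

  // Maps each label to (most frequent cluster, number of rows in that cluster).
  // Ties go to the lower cluster ID.
  def majorityCluster(pairs: Seq[(Double, Double)]): Map[Double, (Double, Int)] =
    pairs.groupBy(_._1).map { case (label, group) =>
      val counts = group.groupBy(_._2).map { case (cluster, g) => (cluster, g.size) }
      val best = counts.maxBy { case (cluster, n) => (n, -cluster.toInt) }
      label -> best
    }
}
```

On these rows, label 1 maps to (2.0, 3), label 3 maps to (4.0, 3), and label 2 maps to (1.0, 2).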