SageMaker Spark for Scala examples
Amazon SageMaker provides an Apache Spark library (SageMaker Spark) that you can use to integrate your Apache Spark applications with SageMaker.
Download Spark for Scala
You can download the source code and examples for both the Python Spark (PySpark) and Scala libraries from the SageMaker Spark GitHub repository.
For detailed instructions on installing the SageMaker Spark library, see SageMaker Spark.
The SageMaker Spark SDK for Scala is available in the Maven Central repository. Add the SageMaker Spark library to your project by adding the appropriate dependency to your pom.xml file:
-
If your project is built with Maven and uses Spark 2.2, add the following to your pom.xml file:
<dependency>
  <groupId>com.amazonaws</groupId>
  <artifactId>sagemaker-spark_2.11</artifactId>
  <version>spark_2.2.0-1.0</version>
</dependency>
-
If your project depends on Spark 2.1, add the following to your pom.xml file:
<dependency>
  <groupId>com.amazonaws</groupId>
  <artifactId>sagemaker-spark_2.11</artifactId>
  <version>spark_2.1.1-1.0</version>
</dependency>
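If you build with sbt rather than Maven, the same Maven coordinates can be declared in build.sbt. The following is a minimal sketch for the Spark 2.2 artifact, assuming your project's scalaVersion is a 2.11.x release. Because the artifact ID already carries the Scala binary-version suffix (_2.11), the sketch uses % rather than %%, which would append a second suffix.

```scala
// build.sbt (sketch): same coordinates as the Spark 2.2 Maven dependency.
// The artifact ID already ends in _2.11, so use % (not %%).
libraryDependencies += "com.amazonaws" % "sagemaker-spark_2.11" % "spark_2.2.0-1.0"
```

For Spark 2.1, swap in the spark_2.1.1-1.0 version string.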
Spark for Scala example
This section provides example code that uses the Apache Spark Scala library provided by SageMaker to train a model in SageMaker using DataFrames in your Spark cluster. It is followed by examples that show how to Use Custom Algorithms for Model Training and Hosting on Amazon SageMaker with Apache Spark and how to Use the SageMakerEstimator in a Spark Pipeline.
The following example hosts the resulting model artifacts using SageMaker hosting services. For more details on this example, see Getting Started: K-Means Clustering on SageMaker with SageMaker Spark SDK. The example does the following:
-
Uses the KMeansSageMakerEstimator to fit (or train) a model on data
Because the example uses the k-means algorithm provided by SageMaker to train a model, you use the KMeansSageMakerEstimator. You train the model using images of handwritten single-digit numbers (from the MNIST dataset). You provide the images as an input DataFrame. For your convenience, SageMaker provides this dataset in an Amazon S3 bucket.
In response, the estimator returns a SageMakerModel object.
-
Obtains inferences using the trained SageMakerModel
To get inferences from a model hosted in SageMaker, you call the SageMakerModel.transform method. You pass a DataFrame as input. The method transforms the input DataFrame to another DataFrame containing inferences obtained from the model.
For a given input image of a handwritten single-digit number, the inference identifies a cluster that the image belongs to. For more information, see K-Means Algorithm.
import org.apache.spark.sql.SparkSession
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator
val spark = SparkSession.builder.getOrCreate
// Load the MNIST data as DataFrames from libsvm files
val region = "us-east-1"
val trainingData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
val testData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")
// Replace account-id and rolename with your own values
val roleArn = "arn:aws:iam::account-id:role/rolename"
val estimator = new KMeansSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1)
  .setK(10).setFeatureDim(784)
// Train the model and deploy it to a SageMaker endpoint
val model = estimator.fit(trainingData)
// Get inferences for the test data
val transformedData = model.transform(testData)
transformedData.show
The example code does the following:
-
Loads the MNIST dataset from an S3 bucket provided by SageMaker (sagemaker-sample-data-region, here sagemaker-sample-data-us-east-1) into Spark DataFrames (trainingData and testData):
// Get a Spark session.
val spark = SparkSession.builder.getOrCreate
// Load the MNIST data as DataFrames from libsvm files
val region = "us-east-1"
val trainingData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/")
val testData = spark.read.format("libsvm")
  .option("numFeatures", "784")
  .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/")
val roleArn = "arn:aws:iam::account-id:role/rolename"
trainingData.show()
The show method displays the first 20 rows of the DataFrame:
+-----+--------------------+
|label|            features|
+-----+--------------------+
|  5.0|(784,[152,153,154...|
|  0.0|(784,[127,128,129...|
|  4.0|(784,[160,161,162...|
|  1.0|(784,[158,159,160...|
|  9.0|(784,[208,209,210...|
|  2.0|(784,[155,156,157...|
|  1.0|(784,[124,125,126...|
|  3.0|(784,[151,152,153...|
|  1.0|(784,[152,153,154...|
|  4.0|(784,[134,135,161...|
|  3.0|(784,[123,124,125...|
|  5.0|(784,[216,217,218...|
|  3.0|(784,[143,144,145...|
|  6.0|(784,[72,73,74,99...|
|  1.0|(784,[151,152,153...|
|  7.0|(784,[211,212,213...|
|  2.0|(784,[151,152,153...|
|  8.0|(784,[159,160,161...|
|  6.0|(784,[100,101,102...|
|  9.0|(784,[209,210,211...|
+-----+--------------------+
only showing top 20 rows
In each row:
-
The label column identifies the image's label. For example, if the image of the handwritten number is the digit 5, the label value is 5.
-
The features column stores a vector (org.apache.spark.ml.linalg.Vector) of Double values. These are the 784 features of the handwritten number. (Each handwritten number is a 28 x 28-pixel image, making 784 features.)
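The trainingData rows above come from Spark's libsvm reader. Each line of a LIBSVM file has the form label index:value index:value ..., with one-based, strictly ascending indices. Spark turns each line into a sparse vector of length numFeatures with zero-based indices, which show prints as (size,[indices],[values]), for example (784,[152,153,154.... The following plain-Scala sketch needs no Spark dependency and mimics that conversion for a single line. LibSvmRow and parseLine are hypothetical names for illustration, not part of Spark or SageMaker Spark.

```scala
// Hypothetical sketch of the LIBSVM line format; this is not Spark's implementation.
// A line looks like "label idx:value idx:value ..." with one-based, ascending indices.
final case class LibSvmRow(label: Double, size: Int, indices: Array[Int], values: Array[Double])

object LibSvm {
  def parseLine(line: String, numFeatures: Int): LibSvmRow = {
    val tokens = line.trim.split("\\s+")
    val label = tokens.head.toDouble
    val pairs = tokens.tail.map { token =>
      val parts = token.split(":", 2)
      // Shift the one-based file index to the zero-based index used in the vector.
      (parts(0).toInt - 1, parts(1).toDouble)
    }
    val idx = pairs.map(_._1)
    require(idx.forall(i => i >= 0 && i < numFeatures),
      s"feature index out of range for numFeatures = $numFeatures")
    require((1 until idx.length).forall(k => idx(k - 1) < idx(k)),
      "indices must be strictly ascending")
    LibSvmRow(label, numFeatures, idx, pairs.map(_._2))
  }
}
```

For example, parsing the made-up line "5 153:3 154:18 155:18" with numFeatures = 28 * 28 yields label 5.0, size 784, and zero-based indices 152, 153, 154. That is the same shape as the entries in the features column.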
-
Creates a SageMaker estimator (KMeansSageMakerEstimator)
The fit method of this estimator uses the k-means algorithm provided by SageMaker to train models using an input DataFrame. In response, it returns a SageMakerModel object that you can use to get inferences.
Note
The KMeansSageMakerEstimator extends the SageMaker SageMakerEstimator, which extends the Apache Spark Estimator.
val estimator = new KMeansSageMakerEstimator(
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1)
  .setK(10).setFeatureDim(784)
The constructor parameters provide information that is used for training a model and deploying it on SageMaker:
-
trainingInstanceType and trainingInstanceCount: Identify the type and number of ML compute instances to use for model training.
-
endpointInstanceType: Identifies the ML compute instance type to use when hosting the model in SageMaker. By default, one ML compute instance is assumed.
-
endpointInitialInstanceCount: Identifies the number of ML compute instances that initially back the endpoint hosting the model in SageMaker.
-
sagemakerRole: SageMaker assumes this IAM role to perform tasks on your behalf. For example, for model training, it reads data from Amazon S3 and writes training results (model artifacts) to Amazon S3.
Note
This example implicitly creates a SageMaker client. To create this client, you must provide your credentials. The API uses these credentials to authenticate requests to SageMaker, for example, the request to create a training job and the API calls that deploy the model using SageMaker hosting services.
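The roleArn string in the example is a placeholder. Replace account-id with your 12-digit AWS account ID and rolename with the name of an IAM role that SageMaker can assume. The following small plain-Scala sketch assembles such an ARN and rejects obviously malformed account IDs. iamRoleArn is a hypothetical helper, and the sketch assumes the standard aws partition and a role created without a custom path.

```scala
// Hypothetical helper: builds an IAM role ARN in the standard "aws" partition
// for a role created without a custom path.
object RoleArn {
  def iamRoleArn(accountId: String, roleName: String): String = {
    require(accountId.matches("[0-9]{12}"), s"expected a 12-digit AWS account ID, got: $accountId")
    require(roleName.nonEmpty, "role name must not be empty")
    s"arn:aws:iam::${accountId}:role/${roleName}"
  }
}
```

You would pass the resulting string to IAMRole(...) in the estimator constructor, as the example does with roleArn.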
After the KMeansSageMakerEstimator object has been created, you set the following parameters, which are used in model training:
-
setK(10): The number of clusters that the k-means algorithm should create during model training. You specify 10 clusters, one for each digit, 0 through 9.
-
setFeatureDim(784): Identifies that each input image has 784 features (each handwritten number is a 28 x 28-pixel image, making 784 features).
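To make the roles of setK and setFeatureDim more concrete, here is a toy, single-machine version of k-means training. It runs standard Lloyd iterations: assign each point to its nearest center, then move each center to the mean of its assigned points, over points of a fixed dimension. This only illustrates the general technique. It is not the algorithm that SageMaker's built-in k-means training job runs. KMeansToy and its methods are hypothetical names, and seeding with the first k points is a simplification chosen to keep the sketch deterministic.

```scala
// Toy Lloyd's k-means for illustration only (not SageMaker's implementation).
// k plays the role of setK; featureDim plays the role of setFeatureDim.
object KMeansToy {
  type Point = Array[Double]

  def squaredDistance(a: Point, b: Point): Double =
    a.indices.map { i => val d = a(i) - b(i); d * d }.sum

  // Index of the center closest to p.
  def nearest(centers: Seq[Point], p: Point): Int =
    centers.indices.minBy(c => squaredDistance(centers(c), p))

  def train(points: Seq[Point], k: Int, featureDim: Int, iterations: Int = 10): Seq[Point] = {
    require(k >= 1 && points.size >= k, "need at least k points")
    require(points.forall(_.length == featureDim), s"every point must have $featureDim features")
    // Simplification: seed with the first k points so the result is deterministic.
    var centers: Seq[Point] = points.take(k).map(_.clone())
    for (_ <- 1 to iterations) {
      // Assignment step: group points by the index of their nearest center.
      val clusters = points.groupBy(p => nearest(centers, p))
      // Update step: move each center to the mean of its assigned points.
      centers = centers.indices.map { c =>
        clusters.get(c) match {
          case Some(members) =>
            Array.tabulate(featureDim)(j => members.map(_(j)).sum / members.size)
          case None => centers(c) // an empty cluster keeps its previous center
        }
      }
    }
    centers
  }
}
```

In the MNIST example, k would be 10 and each point would be a 784-element vector. The toy uses one-dimensional points only to keep it small.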
-
Calls the estimator fit method
// Train the model
val model = estimator.fit(trainingData)
You pass the input DataFrame as a parameter. The estimator does all the work of training the model and deploying it to SageMaker. For more information, see Integrate your Apache Spark application with SageMaker. In response, you get a SageMakerModel object, which you can use to get inferences from your model deployed in SageMaker.
You provide only the input DataFrame. You don't need to specify the registry path to the k-means algorithm used for model training because the KMeansSageMakerEstimator knows it.
-
Calls the SageMakerModel.transform method to get inferences from the model deployed in SageMaker
The transform method takes a DataFrame as input, transforms it, and returns another DataFrame containing inferences obtained from the model.
val transformedData = model.transform(testData)
transformedData.show
This example passes testData, the DataFrame loaded from the MNIST test data, to the transform method. The transform method does the following:
-
Serializes the features column in the input DataFrame to protobuf and sends it to the SageMaker endpoint for inference.
-
Deserializes the protobuf response into two additional columns (distance_to_cluster and closest_cluster) in the transformed DataFrame.
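For each row, the two new columns can be read as the result of a nearest-center lookup. closest_cluster is the index of the cluster center nearest to the row's features, and distance_to_cluster is the distance to that center. The following sketch computes both for a single feature vector against a given set of centers. It assumes Euclidean distance and that the centers are available locally. Assignment and ClusterAssignment are hypothetical names, and this is not how the SageMaker endpoint is implemented.

```scala
// Hypothetical sketch: assigns one feature vector to its nearest center.
// Assumes Euclidean distance; not the SageMaker endpoint's implementation.
final case class Assignment(closestCluster: Int, distanceToCluster: Double)

object ClusterAssignment {
  def euclidean(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "dimension mismatch")
    math.sqrt(a.indices.map { i => val d = a(i) - b(i); d * d }.sum)
  }

  def assign(centers: IndexedSeq[Array[Double]], features: Array[Double]): Assignment = {
    require(centers.nonEmpty, "need at least one center")
    val distances = centers.map(c => euclidean(c, features))
    val best = distances.indices.minBy(i => distances(i))
    Assignment(best, distances(best))
  }
}
```

With centers (0, 0) and (3, 4), the point (6, 8) is 10 away from the first center and 5 away from the second. It is therefore assigned to cluster 1 with distance 5.0.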
The show method displays the first 20 rows of the transformed DataFrame, including the inferences. The sample output below lists the same rows as the trainingData output shown earlier, so it reflects transforming the training data. With testData, the rows differ:
+-----+--------------------+-------------------+---------------+
|label|            features|distance_to_cluster|closest_cluster|
+-----+--------------------+-------------------+---------------+
|  5.0|(784,[152,153,154...|  1767.897705078125|            4.0|
|  0.0|(784,[127,128,129...|  1392.157470703125|            5.0|
|  4.0|(784,[160,161,162...| 1671.5711669921875|            9.0|
|  1.0|(784,[158,159,160...| 1182.6082763671875|            6.0|
|  9.0|(784,[208,209,210...| 1390.4002685546875|            0.0|
|  2.0|(784,[155,156,157...|  1713.988037109375|            1.0|
|  1.0|(784,[124,125,126...| 1246.3016357421875|            2.0|
|  3.0|(784,[151,152,153...|  1753.229248046875|            4.0|
|  1.0|(784,[152,153,154...|  978.8394165039062|            2.0|
|  4.0|(784,[134,135,161...|  1623.176513671875|            3.0|
|  3.0|(784,[123,124,125...|  1533.863525390625|            4.0|
|  5.0|(784,[216,217,218...|  1469.357177734375|            6.0|
|  3.0|(784,[143,144,145...|  1736.765869140625|            4.0|
|  6.0|(784,[72,73,74,99...|   1473.69384765625|            8.0|
|  1.0|(784,[151,152,153...|    944.88720703125|            2.0|
|  7.0|(784,[211,212,213...| 1285.9071044921875|            3.0|
|  2.0|(784,[151,152,153...| 1635.0125732421875|            1.0|
|  8.0|(784,[159,160,161...| 1436.3162841796875|            6.0|
|  6.0|(784,[100,101,102...| 1499.7366943359375|            7.0|
|  9.0|(784,[209,210,211...| 1364.6319580078125|            6.0|
+-----+--------------------+-------------------+---------------+
You can interpret the data, as follows:
-
The handwritten number in the first row (label 5) belongs to cluster 4 (closest_cluster).
-
The handwritten number in the second row (label 0) belongs to cluster 5.
-
The handwritten number in the third row (label 4) belongs to cluster 9.
-
The handwritten number in the fourth row (label 1) belongs to cluster 6.
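Cluster numbers are arbitrary. Because k-means never sees the digit labels, cluster 4 does not mean the digit 4. One way to relate clusters back to digits is to count, for each closest_cluster value, which label occurs most often among its rows. The following plain-Scala sketch does this over (label, closest_cluster) pairs like those in the output above. majorityLabels is a hypothetical helper, and ties are broken toward the smaller label. In practice, you would collect the pairs from the whole transformed DataFrame rather than from 20 rows.

```scala
// Hypothetical helper: for each cluster, the most frequent label among rows assigned to it.
object ClusterLabels {
  // rows: (label, closest_cluster) pairs, both Double as in the transformed DataFrame.
  def majorityLabels(rows: Seq[(Double, Double)]): Map[Int, Int] =
    rows
      .groupBy { case (_, cluster) => cluster.toInt }
      .map { case (cluster, members) =>
        val counts = members
          .groupBy { case (label, _) => label.toInt }
          .map { case (label, ms) => (label, ms.size) }
        // Highest count wins; ties go to the smaller label.
        val best = counts.toSeq.minBy { case (label, count) => (-count, label) }._1
        cluster -> best
      }
}
```

In the 20 rows shown above, cluster 4 holds labels 5, 3, 3, and 3, so the sketch maps it to 3. A cluster with only a few scattered rows, such as cluster 6 here (labels 1, 5, 8, and 9), gives no meaningful majority. For that reason, a real mapping should use many more rows.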