Amazon SageMaker
Developer Guide

Using Custom Algorithms for Model Training and Hosting on Amazon SageMaker with Apache Spark

In Example 1: Using Amazon SageMaker for Training and Inference with Apache Spark, you use the KMeansSageMakerEstimator because the example trains with the k-means algorithm provided by Amazon SageMaker. You might choose to train with your own custom algorithm instead. Assuming that you have already created a Docker image, you can create your own SageMakerEstimator and specify the Amazon Elastic Container Registry (Amazon ECR) path for your custom image.
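Amazon ECR image paths in standard commercial AWS Regions follow the pattern <account-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>. The following minimal sketch assembles such a path. The helper name ecrImagePath and the account ID, Region, repository, and tag values are hypothetical placeholders, not values from this guide.

```scala
// Hypothetical helper: builds an Amazon ECR image path of the form
// <account-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
// (standard commercial Regions; other AWS partitions use different domains).
object EcrPaths {
  def ecrImagePath(accountId: String, region: String, repository: String, tag: String): String = {
    // AWS account IDs are 12-digit numbers.
    require(accountId.length == 12 && accountId.forall(_.isDigit), "AWS account IDs are 12 digits")
    s"$accountId.dkr.ecr.$region.amazonaws.com/$repository:$tag"
  }

  def main(args: Array[String]): Unit = {
    // Placeholder values for illustration only.
    val trainingImage = ecrImagePath("123456789012", "us-west-2", "my-custom-algo", "latest")
    println(trainingImage) // 123456789012.dkr.ecr.us-west-2.amazonaws.com/my-custom-algo:latest
  }
}
```

A string built this way is what you would pass as trainingImage or modelImage when both images live in your own ECR repositories.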

The following code sample shows how to use the generic SageMakerEstimator to create an estimator equivalent to KMeansSageMakerEstimator. In the new estimator, you explicitly specify the Docker registry paths to your training and inference code images.

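```scala
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer

// roleArn: ARN of an IAM role that SageMaker can assume (assumed to be defined elsewhere)
val estimator = new SageMakerEstimator(
  trainingImage = "", // Amazon ECR path to your training image
  modelImage = "",    // Amazon ECR path to your inference image
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(),
  hyperParameters = Map("k" -> "10", "feature_dim" -> "784"),
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")
```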

In the code, the parameters in the SageMakerEstimator constructor include:

  • trainingImage: Identifies the Docker registry path to the training image containing your custom code.

  • modelImage: Identifies the Docker registry path to the image containing your inference code.

  • requestRowSerializer: An implementation of RequestRowSerializer.

    This parameter serializes rows in the input DataFrame to send them to the model hosted in Amazon SageMaker for inference.

  • responseRowDeserializer: An implementation of ResponseRowDeserializer.

    This parameter deserializes responses from the model hosted in Amazon SageMaker back into a DataFrame.

  • trainingSparkDataFormat: Specifies the data format that Spark uses when uploading training data from a DataFrame to Amazon S3. For example, "sagemaker" for protobuf format, "csv" for comma-separated values, and "libsvm" for LibSVM format.
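To make the "csv" and "libsvm" options of trainingSparkDataFormat more concrete, here is a dependency-free sketch of how one labeled example might look in each text format. It is illustrative only: Spark's DataFrame writers produce the actual files, the exact column layout depends on your DataFrame, and the label-first CSV layout is an assumption about what a custom training container expects. The object and function names are hypothetical, and the "sagemaker" protobuf format is not covered.

```scala
// Illustrative only: renders one labeled feature vector in two text formats
// that trainingSparkDataFormat can select. Spark's DataFrame writers do the
// real work; these hypothetical helpers just show the line layouts.
object TrainingFormats {
  // CSV: label followed by feature values, comma separated, no header
  // (assumes the container reads the label from the first column).
  def toCsvLine(label: Double, features: Array[Double]): String =
    (label +: features).mkString(",")

  // LibSVM: "<label> <index>:<value> ..." with 1-based, ascending indices.
  // Zero-valued features are omitted here; absent entries are read as 0.
  def toLibSvmLine(label: Double, features: Array[Double]): String = {
    val entries = features.zipWithIndex.collect {
      case (v, i) if v != 0.0 => s"${i + 1}:$v"
    }
    (label.toString +: entries).mkString(" ")
  }

  def main(args: Array[String]): Unit = {
    val features = Array(0.0, 2.5, 0.0, 1.0)
    println(toCsvLine(1.0, features))    // 1.0,0.0,2.5,0.0,1.0
    println(toLibSvmLine(1.0, features)) // 1.0 2:2.5 4:1.0
  }
}
```

The sparse LibSVM line skips zeros, which is why it is often preferred for high-dimensional, mostly-zero features.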

You can implement your own RequestRowSerializer and ResponseRowDeserializer to serialize rows to, and deserialize responses from, a data format that your inference code supports, such as LibSVM or CSV.
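A custom serializer and deserializer pair boils down to two conversions: a row into request bytes with a matching content type, and response bytes back into values. The real RequestRowSerializer and ResponseRowDeserializer traits from the SageMaker Spark library work with Spark Rows; to keep this sketch runnable without Spark, it defines its own hypothetical RowSerializer and ResponseDeserializer traits that mirror those two responsibilities. The trait and object names, and the assumption that the inference container returns one numeric prediction per line, are illustrative, not the library's API.

```scala
import java.nio.charset.StandardCharsets

// Hypothetical stand-ins for the two responsibilities. The real SageMaker Spark
// traits work with Spark Rows; these simplified versions use Seq[Double] so the
// sketch runs without Spark.
trait RowSerializer {
  def contentType: String
  def serialize(features: Seq[Double]): Array[Byte]
}

trait ResponseDeserializer {
  def accepts: String
  def deserialize(body: Array[Byte]): Iterator[Double]
}

// Serializes one row as a single CSV line, for example "0.5,1.0,3.0\n".
object CsvRowSerializer extends RowSerializer {
  val contentType = "text/csv"
  def serialize(features: Seq[Double]): Array[Byte] =
    (features.mkString(",") + "\n").getBytes(StandardCharsets.UTF_8)
}

// Assumes the (hypothetical) inference container returns one prediction per line.
object CsvResponseDeserializer extends ResponseDeserializer {
  val accepts = "text/csv"
  def deserialize(body: Array[Byte]): Iterator[Double] =
    new String(body, StandardCharsets.UTF_8)
      .split("\n")
      .iterator
      .map(_.trim)          // also strips a trailing '\r' from CRLF line endings
      .filter(_.nonEmpty)   // ignore blank lines
      .map(_.toDouble)
}

object CsvRoundTrip {
  def main(args: Array[String]): Unit = {
    val request = CsvRowSerializer.serialize(Seq(0.5, 1.0, 3.0))
    println(new String(request, StandardCharsets.UTF_8).trim) // 0.5,1.0,3.0
    val predictions = CsvResponseDeserializer
      .deserialize("2.0\n7.0\n".getBytes(StandardCharsets.UTF_8))
      .toList
    println(predictions) // List(2.0, 7.0)
  }
}
```

Keeping contentType and accepts next to the conversion logic makes it harder for the request format and the HTTP headers sent to the endpoint to drift apart.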