使用自訂演算法進行模型訓練和託管在 Amazon SageMaker 與 Apache Spark - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

使用自訂演算法進行模型訓練和託管在 Amazon SageMaker 與 Apache Spark

在中SageMaker 斯卡拉的例子火花,您可以使用kMeansSageMakerEstimator因為範例使用 Amazon 提供的 k 均值演算法 SageMaker 進行模型訓練。不過,您也可以選擇使用專屬的自訂演算法來訓練模型。假設您已建立 Docker 影像,就可以建立您專屬的 SageMakerEstimator,並指定自訂影像的 Amazon Elastic Container Registry 路徑。

以下範例會說明從 SageMakerEstimator 建立 KMeansSageMakerEstimator 的方式。請在新的估算器中明確地指定 Docker 登錄檔路徑,以便訓練和推論程式碼影像。

import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer val estimator = new SageMakerEstimator( trainingImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", modelImage = "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1", requestRowSerializer = new ProtobufRequestRowSerializer(), responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(), hyperParameters = Map("k" -> "10", "feature_dim" -> "784"), sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1, trainingSparkDataFormat = "sagemaker")

SageMakerEstimator 建構函式中的參數會包含以下程式碼:

  • trainingImage - 可識別訓練影像的 Docker 登錄檔路徑,該訓練影像包含自訂程式碼。

  • modelImage - 可識別影像的 Docker 登錄檔路徑,該影像包含推論程式碼。

  • requestRowSerializer - 實作 com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer

    此參數序列化輸入中的資料列,以便將其傳送DataFrame至主控於中的模型 SageMaker 進行推論。

  • responseRowDeserializer - 實作

    com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer.

    此參數會反序列化來自模型 (以中 SageMaker為主體) 的回應。DataFrame

  • trainingSparkDataFormat - 可指定 DataFrame 訓練資料上傳至 S3 期間,Spark 會使用的資料格式。例如,"sagemaker" 適用於 protobuf 格式、"csv" 適用於逗號分隔值,而 "libsvm" 適用於 LibSVM 格式。

您可以實作專屬的 RequestRowSerializerResponseRowDeserializer,將使用您推論程式碼支援之資料格式 (如 libsvm 或 .csv) 的資料列序列化及還原序列化。