本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
SageMaker 斯卡拉的例子火花
Amazon SageMaker 提供了一個 Apache 的星火庫(SageMaker星火
下載斯卡拉的星火
您可以從星火庫下載源代碼和示例 Python 星火(PySpark)和斯卡拉 GitHub 庫。SageMaker
如需安裝 SageMaker Spark 程式庫的詳細指示,請參閱 SageMakerSpark
SageMaker 斯卡拉星火 SDK 是在 Maven 中央存儲庫可用。在您的 pom.xml
檔案中新增以下相依性,將 Spark 程式庫新增至專案:
-
如果您的項目是使用 Maven 構建的,請將以下內容添加到您的 pom.xml 文件中:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.2.0-1.0</version> </dependency>
-
如果您的項目依賴於星火 2.1,請將以下內容添加到您的 pom.xml 文件中:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.1.1-1.0</version> </dependency>
斯卡拉的例子火花
本節提供了使用提供的 Apache 星火斯卡拉庫的示例代碼 SageMaker 來訓練在 SageMaker 使用 DataFrame
s 在你的星火集群模型。然後,其次是關於如何使用自訂演算法進行模型訓練和託管在 Amazon SageMaker 與 Apache Spark和的示例在火花管道 SageMakerEstimator中使用。
下列範例會使用主控服務來 SageMaker 託管產生的模型加工品。如需此範例的詳細資訊,請參閱入門: SageMaker 使用 SageMaker Spark SDK 進行 K-Means 叢集
-
使用
KMeansSageMakerEstimator
,擬合 (或訓練) 資料上的模型由於此範例使用提供的 k-means 演算法 SageMaker 來訓練模型,因此您可以使用
KMeansSageMakerEstimator
. 您可以善用來自 MNIST 資料集的手寫個位數字影像,加以訓練模型。請將該影像提供為輸入DataFrame
。為了方便起見,請在 Amazon S3 儲存貯體中 SageMaker 提供此資料集。估算器會在回應中傳回
SageMakerModel
物件。 -
使用訓練過的
SageMakerModel
獲取推論若要從中託管的模型取得推論 SageMaker,請呼叫方
SageMakerModel.transform
法。您可以將DataFrame
傳遞為輸入。該方法會將輸入DataFrame
轉換為另一個DataFrame
,其將包含從模型取得的推論。針對指定的手寫個位數字輸入影像,推論功能會識別該影像所屬的叢集。如需詳細資訊,請參閱 K 平均數演算法。
import org.apache.spark.sql.SparkSession import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.algorithms import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::
account-id
:role/rolename
" val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784) // train val model = estimator.fit(trainingData) val transformedData = model.transform(testData) transformedData.show
此範例程式碼可做到以下操作:
-
從提供的 S3 存儲桶加載 MNIST 數據集 SageMaker(
awsai-sparksdk-dataset
)到一個星火DataFrame
(mnistTrainingDataFrame
):// Get a Spark session. val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::
account-id
:role/rolename
" trainingData.show()show
方法會在資料框架中顯示前 20 個資料列:+-----+--------------------+ |label| features| +-----+--------------------+ | 5.0|(784,[152,153,154...| | 0.0|(784,[127,128,129...| | 4.0|(784,[160,161,162...| | 1.0|(784,[158,159,160...| | 9.0|(784,[208,209,210...| | 2.0|(784,[155,156,157...| | 1.0|(784,[124,125,126...| | 3.0|(784,[151,152,153...| | 1.0|(784,[152,153,154...| | 4.0|(784,[134,135,161...| | 3.0|(784,[123,124,125...| | 5.0|(784,[216,217,218...| | 3.0|(784,[143,144,145...| | 6.0|(784,[72,73,74,99...| | 1.0|(784,[151,152,153...| | 7.0|(784,[211,212,213...| | 2.0|(784,[151,152,153...| | 8.0|(784,[159,160,161...| | 6.0|(784,[100,101,102...| | 9.0|(784,[209,210,211...| +-----+--------------------+ only showing top 20 rows
在每個資料列中:
-
label
欄位會識別影像的標籤。例如,如果手寫數字的影像為數字 5,標籤值即為 5。 -
features
欄位會存放org.apache.spark.ml.linalg.Vector
值的向量 (Double
)。這些值即為手寫數字的 784 特徵。(每個手寫數字的影像均為 28 x 28 像素,因此稱為 784 特徵。)
-
-
創建一個 SageMaker 估計器()
KMeansSageMakerEstimator
此估算器的
fit
方法使用提供的 k 均值演算法 SageMaker 來訓練使用輸入的模型。DataFrame
該方法會在回應中傳回SageMakerModel
物件,讓您可以獲取推論。注意
的
KMeansSageMakerEstimator
擴展 SageMakerSageMakerEstimator
,這擴展了 Apache 的星火Estimator
。val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784)
建構函式參數提供用於訓練模型和部署模型的資訊 SageMaker:
-
trainingInstanceType
與trainingInstanceCount
- 可識別用來訓練模型的機器學習 (ML) 運算執行個體類型和數量。 -
endpointInstanceType
識別在 SageMaker中託管模型時要使用的 ML 計算執行個體類型。而根據預設,系統會採用一個機器學習 (ML) 運算執行個體。 -
endpointInitialInstanceCount
識別最初支援主控模型之端點的 ML 運算執行個體數目。 SageMaker -
sagemakerRole
— SageMaker 假設此 IAM 角色代表您執行任務。以模型訓練任務為例,該參數會自 S3 讀取資料並將訓練結果 (模型成品) 寫入至 S3。注意
此範例會以隱含方式建立 SageMaker 用戶端。而您必須提供登入資料,才能建立此用戶端。API 會使用這些認證來驗證要求 SageMaker。例如,它會使用認證來驗證要求,以建立訓練工作和 API 呼叫,以使用 SageMaker 主機服務部署模型。
-
KMeansSageMakerEstimator
物件建立完成後,您即可設定下列參數,以便進行模型訓練:-
訓練模型期間,K 平均數演算法應該建立的叢集數量。您可以指定 10 個叢集,並以數字 0 至 9 編號各叢集。
-
識別每個輸入影像是否皆具備 784 特徵 (每個手寫數字的影像均為 28 x 28 像素,因此稱為 784 特徵)。
-
-
-
呼叫估算器
fit
方法// train val model = estimator.fit(trainingData)
您可以將輸入
DataFrame
傳遞為參數。該模型完成訓練模型並將其部署到的所有工作 SageMaker。若要取得更多資訊,請參閱整合您的 Apache 星火應用程式 SageMaker。作為回應,您會得到一個SageMakerModel
物件,您可以使用該物件從中 SageMaker部署的模型中取得推論。您僅需提供輸入
DataFrame
。不需要為用來訓練模型的 K 平均數演算法指定登錄檔路徑,因為KMeansSageMakerEstimator
已掌握該路徑。 -
呼叫從中 SageMaker部署的模型取得推論的
SageMakerModel.transform
方法。transform
方法會採用DataFrame
做為輸入並進行轉換,接著傳回另一個DataFrame
,其將包含從模型取得的推論。val transformedData = model.transform(testData) transformedData.show
為簡化程序,做為輸入的
DataFrame
會與此範例中用來訓練模型的transform
方法相同。transform
方法會執行下列作業:-
將輸入中的
features
列序列化DataFrame
為 protobuf,並將其發送到 SageMaker 端點進行推論。 -
將 protobuf 回應還原序列化為兩個額外欄位 (
distance_to_cluster
與closest_cluster
),而這兩個欄位會位於轉換後的DataFrame
。
show
方法會取得輸入DataFrame
前 20 個資料列中的推論:+-----+--------------------+-------------------+---------------+ |label| features|distance_to_cluster|closest_cluster| +-----+--------------------+-------------------+---------------+ | 5.0|(784,[152,153,154...| 1767.897705078125| 4.0| | 0.0|(784,[127,128,129...| 1392.157470703125| 5.0| | 4.0|(784,[160,161,162...| 1671.5711669921875| 9.0| | 1.0|(784,[158,159,160...| 1182.6082763671875| 6.0| | 9.0|(784,[208,209,210...| 1390.4002685546875| 0.0| | 2.0|(784,[155,156,157...| 1713.988037109375| 1.0| | 1.0|(784,[124,125,126...| 1246.3016357421875| 2.0| | 3.0|(784,[151,152,153...| 1753.229248046875| 4.0| | 1.0|(784,[152,153,154...| 978.8394165039062| 2.0| | 4.0|(784,[134,135,161...| 1623.176513671875| 3.0| | 3.0|(784,[123,124,125...| 1533.863525390625| 4.0| | 5.0|(784,[216,217,218...| 1469.357177734375| 6.0| | 3.0|(784,[143,144,145...| 1736.765869140625| 4.0| | 6.0|(784,[72,73,74,99...| 1473.69384765625| 8.0| | 1.0|(784,[151,152,153...| 944.88720703125| 2.0| | 7.0|(784,[211,212,213...| 1285.9071044921875| 3.0| | 2.0|(784,[151,152,153...| 1635.0125732421875| 1.0| | 8.0|(784,[159,160,161...| 1436.3162841796875| 6.0| | 6.0|(784,[100,101,102...| 1499.7366943359375| 7.0| | 9.0|(784,[209,210,211...| 1364.6319580078125| 6.0| +-----+--------------------+-------------------+---------------+
您即可解譯資料,如下所示:
-
label
為 5 的手寫數字屬於叢集 4 (closest_cluster
)。 -
label
為 0 的手寫數字屬於叢集 5。 -
label
為 4 的手寫數字屬於叢集 9。 -
label
為 1 的手寫數字屬於叢集 6。
-