TabTransformer - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

TabTransformer

TabTransformer是一種用於監督學習的新型深度表格數據建模體系結構。該 TabTransformer 架構建立在 self-attention-based 變形金剛上。轉換器層將分類特徵的內嵌項目轉換為強大的內容內嵌項目,以實現較高的預測準確性。此外,從中學到的上下文嵌入 TabTransformer 非常強大,可防止丟失和嘈雜的數據功能,並提供更好的解釋性。

如何使用 SageMaker TabTransformer

您可以 TabTransformer 將其用作 Amazon SageMaker 內置算法。下一節將說明如何 TabTransformer 搭配 SageMaker Python 開發套件使用。如需有關如何 TabTransformer 從 Amazon SageMaker 工作室經典使用者介面使用的資訊,請參閱SageMaker JumpStart

  • 用 TabTransformer 作內置算法

    使用 TabTransformer 內建演算法建置 TabTransformer 訓練容器,如下列程式碼範例所示。您可以使用 SageMaker image_uris.retrieve API (如果使用 Amazon SageMaker Python 開發套件第 2 版,則可以使用 get_image_uri API 自動發現 TabTransformer 內建演算法影像 URI)。

    指定 TabTransformer 映像 URI 之後,您可以使用 TabTransformer 容器使用估計器 API 建構估 SageMaker 算器,並啟動訓練工作。 TabTransformer 內建演算法會以指令碼模式執行,但是訓練指令碼是為您提供的,而且不需要取代它。如果您有使用指令碼模式建立 SageMaker 訓練工作的豐富經驗,則可以合併自己的 TabTransformer 訓練指令碼。

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "pytorch-tabtransformerclassification-model", "*", "training" training_instance_type = "ml.p3.2xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_binary/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "n_epochs" ] = "50" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "training": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    如需有關如何設定 TabTransformer 為內建演算法的詳細資訊,請參閱下列筆記本範例。

TabTransformer 算法的輸入和輸出接口

TabTransformer 對表格資料進行操作,列表示觀測值,一欄表示目標變數或標示,其餘欄表示圖徵。

實 SageMaker 施 TabTransformer 支持 CSV 進行培訓和推論:

  • 對於訓練 ContentType,有效輸入必須是文字 /csv

  • 對於推論 ContentType,有效輸入必須是文字 /csv。

注意

對於 CSV 訓練,演算法假設目標變數在第一個欄,且 CSV 沒有標題記錄。

對於 CSV 推論,演算法假設 CSV 輸入沒有標籤欄。

訓練資料、驗證資料和分類功能的輸入格式

請注意如何格式化訓練資料,以便輸入至 TabTransformer 模型。您必須提供包含訓練和驗證資料之 Amazon S3 儲存貯體的路徑。您也可以內涵分類功能清單。同時使用trainingvalidation通道來提供您的輸入資料。或者,您可以只使用training頻道。

同時使用trainingvalidation通道

您可以透過兩個 S3 路徑提供輸入資料,一個用於training通道,另一個用於validation通道。每個 S3 路徑可以是指向一或多個 CSV 檔案的 S3 前置詞,也可以是指向一個特定 CSV 檔案的完整 S3 路徑。目標變數應位於 CSV 檔案的第一欄中。預測變量 (功能) 應該位於其餘列中。如果為trainingvalidation通道提供了多個 CSV 檔案, TabTransformer 演算法會將檔案串聯起來。驗證資料用於計算每次增加迭代結束時的驗證分數。當驗證分數停止改善時,會套用提前停止。

如果您的預測值包含分類功能,您可以提供名categorical_index.json為與訓練資料檔案相同的位置的 JSON 檔案。如果您提供用於分類功能的 JSON 檔案,您的training頻道必須指向 S3 前置詞,而不是特定的 CSV 檔案。這個文件應該包含一個 Python 字典,其中索引鍵是字串 "cat_index_list",該值是唯一整數的清單。值清單中的每個整數應指出訓練資料 CSV 檔案中對應分類特徵的欄索引。每個值都應該是一個正整數 (大於零,因為零表示目標值)、小於 Int32.MaxValue (2147483647),且小於資料欄的總數。應該只有一個分類索引 JSON 檔案。

僅使用training通道

或者,您也可以透過training通道的單一 S3 路徑提供輸入資料。此 S3 路徑應指向具有名為的子目錄,training/該目錄包含一或多個 CSV 檔案。您可以選擇性地將另一個子目錄包含在位於同一個位置,且同樣具有一或多個 CSV 檔案,名為 validation/ 的子目錄。如果未提供驗證資料,則會隨機抽樣 20% 的訓練資料,做為驗證資料。如果您的預測值包含分類功能,您可以提供名categorical_index.json為與訓練資料檔案相同的位置的 JSON 檔案。

注意

對於 CSV 訓練輸入模式,可供演算法使用的總記憶體 (執行個體計數乘以在 InstanceType 中可用的記憶體) 需可保留訓練資料集。

Amazon EC2 實例推薦 TabTransformer算法

SageMaker TabTransformer 支援單一執行個體 CPU 和單一執行個體 GPU 訓練。雖然每個執行個體的成本較高,但 GPU 的訓練速度更快,更具成本效益。若要充分利用 GPU 訓練,請將執行個體類型指定為其中一個 GPU 執行個體 (例如 P3)。 SageMaker TabTransformer 目前不支援多 GPU 訓練。

TabTransformer 範例筆記本

下表概述了解決 Amazon SageMaker TabTransformer 演算法不同使用案例的各種範例筆記本。

筆記本標題 Description

使用 Amazon SageMaker TabTransformer 算法表格分類

本筆記本示範如何使用 Amazon 演 SageMaker TabTransformer算法來訓練和託管表格分類模型。

使用 Amazon SageMaker TabTransformer 算法表格回歸

本筆記本示範如何使用 Amazon 演 SageMaker TabTransformer算法來訓練和託管表格迴歸模型。

如需如何建立及存取 Jupyter 筆記本執行個體 (您可以用來執行中範例) 的指示 SageMaker,請參閱。Amazon SageMaker 筆記本實建立筆記本執行個體並開啟之後,請選擇 [SageMaker範例] 索引標籤以查看所有 SageMaker 範例的清單。若要開啟筆記本,請選擇其使用標籤,然後選擇建立複本