LightGBM - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

LightGBM

LightGBM 是梯度增強決策樹 (GBDT) 演算法的一種熱門且高效率的開放原始碼實作。GBDT 是一種監督式學習演算法,藉由結合一組較簡單且較脆弱的模型之預估值集合來嘗試準確地預測目標變數。LightGBM 使用其他技術來大幅改善傳統 GBDT 的效率和可擴展性。

如何使用 SageMaker 光電

您可以使用 LightGBM 作為 Amazon 的 SageMaker 內置算法。下一節將說明如何搭配 SageMaker Python 開發套件使用。如需如何從 Amazon SageMaker 工作室經典使用者介面使用 LightGBM 的相關資訊,請參閱。SageMaker JumpStart

  • 使用 LightGBM 作為內建演算法

    使用 LightGBM 內建演算法來建置 LightGBM 訓練容器,如下面的程式碼範例所示。您可以使用 SageMaker image_uris.retrieve API 自動發現 LightGBM 內建演算法影像 URI (如果使用 Amazon 開發套件第 2 版 SageMaker ,則為 get_image_uri API)。

    指定 LightGBM 影像 URI 之後,您可以使用 LightGBM 容器使用估算器 API 建構估算器,並啟動訓練工作。 SageMaker LightGBM 內建演算法會以指令碼模式執行,但是訓練指令碼是為您提供的,不需要取代它。如果您有使用指令碼模式建立 SageMaker 訓練工作的豐富經驗,則可以整合自己的 LightGbm 訓練指令碼。

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "lightgbm-classification-model", "*", "training" training_instance_type = "ml.m5.xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_multiclass/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "num_boost_round" ] = "500" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, # for distributed training, specify an instance_count greater than 1 instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "train": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    如需如何將 LightGBM 設定為內建演算法的更多相關資訊,請參閱下列筆記本範例。

LightGBM 演算法的輸入和輸出介面

梯度提升在表格式資料中操作,含有代表觀察的行、還有一個代表目標變數或標籤的欄,而剩下的欄則代表功能。

SageMaker 實作光 GBM 支援用於訓練和推論的 CSV:

  • 對於訓練 ContentType,有效輸入必須是文字 /csv

  • 對於推論 ContentType,有效輸入必須是文字 /csv。

注意

對於 CSV 訓練,演算法假設目標變數在第一個欄,且 CSV 沒有標題記錄。

對於 CSV 推論,演算法假設 CSV 輸入沒有標籤欄。

訓練資料、驗證資料和分類功能的輸入格式

請注意如何設定訓練資料的格式,以輸入至 LightGBM 模型。您必須提供包含訓練和驗證資料之 Amazon S3 儲存貯體的路徑。您也可以內涵分類功能清單。同時使用trainvalidation通道來提供您的輸入資料。或者,您可以只使用train頻道。

注意

traintraining 是 LightGBM 訓練的有效頻道名稱。

同時使用trainvalidation通道

您可以透過兩個 S3 路徑提供輸入資料,一個用於train通道,另一個用於validation通道。每個 S3 路徑可以是指向一或多個 CSV 檔案的 S3 前置詞,也可以是指向一個特定 CSV 檔案的完整 S3 路徑。目標變數應位於 CSV 檔案的第一欄中。預測變量 (功能) 應該位於其餘列中。如果為 trainvalidation 頻道提供了多個 CSV 檔案,則 LightGBM 演算法會串連檔案。驗證資料用於計算每次增加迭代結束時的驗證分數。當驗證分數停止改善時,會套用提前停止。

如果您的預測值包含分類功能,您可以提供名categorical_index.json為與訓練資料檔案相同的位置的 JSON 檔案。如果您提供用於分類功能的 JSON 檔案,您的train頻道必須指向 S3 前置詞,而不是特定的 CSV 檔案。這個文件應該包含一個 Python 字典,其中索引鍵是字串 "cat_index_list",該值是唯一整數的清單。值清單中的每個整數應指出訓練資料 CSV 檔案中對應分類特徵的欄索引。每個值都應該是一個正整數 (大於零,因為零表示目標值)、小於 Int32.MaxValue (2147483647),且小於資料欄的總數。應該只有一個分類索引 JSON 檔案。

僅使用train通道

或者,您也可以透過train通道的單一 S3 路徑提供輸入資料。此 S3 路徑應指向具有名為的子目錄,train/該目錄包含一或多個 CSV 檔案。您可以選擇性地將另一個子目錄包含在位於同一個位置,且同樣具有一或多個 CSV 檔案,名為 validation/ 的子目錄。如果未提供驗證資料,則會隨機抽樣 20% 的訓練資料,做為驗證資料。如果您的預測值包含分類功能,您可以提供名categorical_index.json為與訓練資料檔案相同的位置的 JSON 檔案。

注意

對於 CSV 訓練輸入模式,可供演算法使用的總記憶體 (執行個體計數乘以在 InstanceType 中可用的記憶體) 需可保留訓練資料集。

SageMaker LightGBM 使用 Python Jobblib 模塊序列化或反序列化模型,其可用於保存或加載模型。

若要在模組中使用經過 Li SageMaker ghtGBM 訓練的模型 JobLib
  • 使用以下 Python 程式碼:

    import joblib import tarfile t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() model = joblib.load(model_file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest)

適用於 LightGBM 演算法的 Amazon EC2 執行個體推薦服務

SageMaker LightGBM 目前支援單一執行個體和多執行個體 CPU 訓練。對於多執行個體 CPU 訓練 (分散式訓練),請在定義估算器時指定大於 1 的 instance_count。如需使用 LightGBM 進行分散式訓練的詳細資訊,請參閱使用 Dask 的 Amazon SageMaker LightGBM 分散式訓練。

LightGBM 為記憶體限制型 (相對於運算限制型) 演算法。因此,相較於運算最佳化執行個體 (例如 C4),一般用途的運算執行個體 (例如 M5) 是較好的選擇。此外,我們建議您在所選執行個體中需有足夠的總記憶體才可保留訓練資料。

LightGBM 範例筆記本

下表概述了解決 Amazon SageMaker LightGBM 演算法不同使用案例的各種範例筆記本。

筆記本標題 Description

使用 Amazon SageMaker LightGBM 和算法進行表格分類 CatBoost

本筆記本示範如何使用 Amazon SageMaker LightGBM 演算法來訓練和託管表格分類模型。

利用 Amazon SageMaker LightGBM 和演算法進行表格迴歸 CatBoost

本筆記本示範如何使用 Amazon SageMaker LightGBM 演算法來訓練和託管表格回歸模型。

使用達斯克的 Amazon SageMaker LightGBM 分佈式培訓

本筆記本示範使用 Dask 架構使用 Amazon SageMaker LightGBM 演算法進行的分散式訓練。

如需如何建立及存取 Jupyter 筆記本執行個體 (您可以用來執行中範例) 的指示 SageMaker,請參閱。Amazon SageMaker 筆記本實建立筆記本執行個體並開啟之後,請選擇 [SageMaker範例] 索引標籤以查看所有 SageMaker 範例的清單。若要開啟筆記本,請選擇其使用標籤,然後選擇建立複本