Distributing a training dataset (SDK) - Rekognition



Amazon Rekognition Custom Labels requires a training dataset and a test dataset to train your model.

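If you're using the API, you can use the DistributeDatasetEntries API to distribute 20% of the training dataset into an empty test dataset. Distributing the training dataset is useful if you only have a single manifest file available. Use the single manifest file to create your training dataset, then create an empty test dataset and use DistributeDatasetEntries to populate it.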
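The single-manifest workflow described above maps onto three Boto3 calls: `create_dataset` with a manifest for the training dataset, `create_dataset` without `DatasetSource` for the empty test dataset, and `distribute_dataset_entries`. The following is a minimal sketch, not an official sample. It assumes, as the CreateDataset API reference describes, that omitting `DatasetSource` creates an empty dataset, and that the manifest is an Amazon SageMaker Ground Truth format file in Amazon S3. `wait_for_dataset` and `distribute_single_manifest` are hypothetical helper names, and the polling interval and `max_polls` limit are arbitrary choices.

```python
import time

# Statuses during which a dataset is still being created or updated.
IN_PROGRESS = {"CREATE_IN_PROGRESS", "UPDATE_IN_PROGRESS"}


def wait_for_dataset(client, dataset_arn, expected, poll_seconds=5.0,
                     max_polls=180, sleep=time.sleep):
    """Poll DescribeDataset until the dataset leaves an *_IN_PROGRESS status.

    Returns the final status if it equals `expected`; otherwise raises.
    poll_seconds and max_polls are arbitrary values for this sketch.
    """
    for _ in range(max_polls):
        description = client.describe_dataset(DatasetArn=dataset_arn)["DatasetDescription"]
        status = description["Status"]
        if status not in IN_PROGRESS:
            if status != expected:
                message = description.get("StatusMessage", "")
                raise RuntimeError(f"{dataset_arn}: {status} {message}")
            return status
        sleep(poll_seconds)
    raise TimeoutError(f"{dataset_arn} is still in progress after {max_polls} polls")


def distribute_single_manifest(client, project_arn, bucket, manifest_key,
                               sleep=time.sleep):
    """Hypothetical helper: single manifest -> training dataset -> 80/20 split.

    Returns (training_dataset_arn, test_dataset_arn).
    """
    # 1. Create the training dataset from the single manifest file in Amazon S3.
    train_arn = client.create_dataset(
        ProjectArn=project_arn,
        DatasetType="TRAIN",
        DatasetSource={"GroundTruthManifest": {
            "S3Object": {"Bucket": bucket, "Name": manifest_key}}},
    )["DatasetArn"]
    wait_for_dataset(client, train_arn, "CREATE_COMPLETE", sleep=sleep)

    # 2. Create an empty test dataset by omitting DatasetSource.
    test_arn = client.create_dataset(
        ProjectArn=project_arn, DatasetType="TEST")["DatasetArn"]
    wait_for_dataset(client, test_arn, "CREATE_COMPLETE", sleep=sleep)

    # 3. Move 20% of the training entries into the test dataset, then wait
    #    for both datasets to finish updating.
    client.distribute_dataset_entries(
        Datasets=[{"Arn": train_arn}, {"Arn": test_arn}])
    for arn in (train_arn, test_arn):
        wait_for_dataset(client, arn, "UPDATE_COMPLETE", sleep=sleep)

    return train_arn, test_arn
```

With a real client, such as `boto3.Session(profile_name="custom-labels-access").client("rekognition")` as used in the Python example on this page, you would pass your project ARN, bucket name and manifest key. The `sleep` parameter exists only so the helper can be exercised without waiting.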

Note

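If you're using the Amazon Rekognition Custom Labels console and start with a single-dataset project, Amazon Rekognition Custom Labels splits (distributes) the training dataset during training to create a test dataset. 20% of the training dataset entries are moved to the test dataset.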
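To make "20% of the training dataset entries are moved" concrete, the following is a purely local illustration that splits the lines of a JSON Lines manifest. It does not call Amazon Rekognition and is not how the console or DistributeDatasetEntries is implemented: the random selection and the `round()` rule for the test count are assumptions, because this page doesn't say how entries are chosen or how the count is rounded. `split_manifest_lines` is a hypothetical name.

```python
import random


def split_manifest_lines(lines, test_fraction=0.2, seed=None):
    """Illustration only: move about test_fraction of the entries to a test list.

    Random selection and round() are assumptions, not the service's behavior.
    Returns (train_lines, test_lines), each in the original manifest order.
    """
    if not 0 < test_fraction < 1:
        raise ValueError("test_fraction must be between 0 and 1")
    rng = random.Random(seed)  # seeded so the split is reproducible
    test_count = round(len(lines) * test_fraction)
    test_indices = set(rng.sample(range(len(lines)), test_count))
    train = [line for i, line in enumerate(lines) if i not in test_indices]
    test = [line for i, line in enumerate(lines) if i in test_indices]
    return train, test
```

For a 10-entry manifest, this moves 2 entries to the test list and leaves 8 in the training list; every entry ends up in exactly one of the two lists.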

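To distribute a training dataset (SDK)
  1. If you haven't already done so, install and configure the AWS CLI and the AWS SDKs. For more information, see Step 4: Set up the AWS CLI and AWS SDKs.

  2. Create a project. For more information, see Creating an Amazon Rekognition Custom Labels project (SDK).

  3. Create a training dataset. For information about datasets, see Creating training and test datasets.

  4. Create an empty test dataset.

  5. Use the following example code to distribute 20% of the training dataset entries into the test dataset. You can get the Amazon Resource Names (ARNs) for the project's datasets by calling DescribeProjects. For example code, see Describing a project (SDK).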

    AWS CLI

    Change the values of training_dataset_arn and test_dataset_arn to the ARNs of the datasets that you want to use.

    aws rekognition distribute-dataset-entries --datasets '[{"Arn": "training_dataset_arn"}, {"Arn": "test_dataset_arn"}]' \
      --profile custom-labels-access
    Python

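    Use the following code. Supply the following command line parameters:

    • training_dataset_arn: The ARN of the training dataset that you distribute entries from.

    • test_dataset_arn: The ARN of the test dataset that you distribute entries to.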

    # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    # SPDX-License-Identifier: Apache-2.0

    import argparse
    import logging
    import time

    import boto3
    from botocore.exceptions import ClientError

    logger = logging.getLogger(__name__)


    def check_dataset_status(rek_client, dataset_arn):
        """
        Checks the current status of a dataset.
        :param rek_client: The Amazon Rekognition Custom Labels Boto3 client.
        :param dataset_arn: The dataset that you want to check.
        :return: The dataset status and status message.
        """
        while True:
            dataset = rek_client.describe_dataset(DatasetArn=dataset_arn)
            status = dataset['DatasetDescription']['Status']
            # StatusMessage is optional in the response, so don't index it directly.
            status_message = dataset['DatasetDescription'].get('StatusMessage', '')

            if status == "UPDATE_IN_PROGRESS":
                logger.info("Distributing dataset: %s ", dataset_arn)
                time.sleep(5)
                continue

            if status == "UPDATE_COMPLETE":
                logger.info(
                    "Dataset distribution complete: %s : %s : %s",
                    status, status_message, dataset_arn)
                break

            if status == "UPDATE_FAILED":
                # logger.error, not logger.exception: no exception is being handled here.
                logger.error(
                    "Dataset distribution failed: %s : %s : %s",
                    status, status_message, dataset_arn)
                break

            logger.error(
                "Failed. Unexpected state for dataset distribution: %s : %s : %s",
                status, status_message, dataset_arn)
            status_message = "An unexpected error occurred while distributing the dataset"
            break

        return status, status_message


    def distribute_dataset_entries(rek_client, training_dataset_arn, test_dataset_arn):
        """
        Distributes 20% of the supplied training dataset into the supplied test dataset.
        :param rek_client: The Amazon Rekognition Custom Labels Boto3 client.
        :param training_dataset_arn: The ARN of the training dataset that you distribute entries from.
        :param test_dataset_arn: The ARN of the test dataset that you distribute entries to.
        """
        try:
            logger.info("Distributing training dataset entries (%s) into test dataset (%s).",
                        training_dataset_arn, test_dataset_arn)

            # Build the Datasets parameter directly instead of parsing a concatenated JSON string.
            datasets = [{"Arn": training_dataset_arn}, {"Arn": test_dataset_arn}]

            rek_client.distribute_dataset_entries(Datasets=datasets)

            # Distribution runs asynchronously, so wait for both datasets to finish updating.
            training_dataset_status, training_dataset_status_message = check_dataset_status(
                rek_client, training_dataset_arn)
            test_dataset_status, test_dataset_status_message = check_dataset_status(
                rek_client, test_dataset_arn)

            if training_dataset_status == 'UPDATE_COMPLETE' and test_dataset_status == "UPDATE_COMPLETE":
                print("Distribution complete")
            else:
                print("Distribution failed:")
                print(
                    f"\ttraining dataset: {training_dataset_status} : {training_dataset_status_message}")
                print(
                    f"\ttest dataset: {test_dataset_status} : {test_dataset_status_message}")

        except ClientError as err:
            logger.exception(
                "Couldn't distribute dataset: %s", err.response['Error']['Message'])
            raise


    def add_arguments(parser):
        """
        Adds command line arguments to the parser.
        :param parser: The command line parser.
        """
        parser.add_argument(
            "training_dataset_arn",
            help="The ARN of the training dataset that you want to distribute from."
        )
        parser.add_argument(
            "test_dataset_arn",
            help="The ARN of the test dataset that you want to distribute to."
        )


    def main():

        logging.basicConfig(level=logging.INFO,
                            format="%(levelname)s: %(message)s")

        try:
            # Get command line arguments.
            parser = argparse.ArgumentParser(usage=argparse.SUPPRESS)
            add_arguments(parser)
            args = parser.parse_args()

            print(
                f"Distributing training dataset entries ({args.training_dataset_arn}) "
                f"into test dataset ({args.test_dataset_arn}).")

            # Distribute the datasets.
            session = boto3.Session(profile_name='custom-labels-access')
            rekognition_client = session.client("rekognition")

            distribute_dataset_entries(rekognition_client,
                                       args.training_dataset_arn,
                                       args.test_dataset_arn)

            print("Finished distributing datasets.")

        except ClientError as err:
            logger.exception("Problem distributing datasets: %s", err)
            print(f"Problem distributing datasets: {err}")
        except Exception as err:
            logger.exception("Problem distributing datasets: %s", err)
            print(f"Problem distributing datasets: {err}")


    if __name__ == "__main__":
        main()
    Java V2

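    Use the following code. Supply the following command line parameters:

    • training_dataset_arn: The ARN of the training dataset that you distribute entries from.

    • test_dataset_arn: The ARN of the test dataset that you distribute entries to.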

    /*
       Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
       SPDX-License-Identifier: Apache-2.0
    */

    package com.example.rekognition;

    import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
    import software.amazon.awssdk.regions.Region;
    import software.amazon.awssdk.services.rekognition.RekognitionClient;
    import software.amazon.awssdk.services.rekognition.model.DatasetDescription;
    import software.amazon.awssdk.services.rekognition.model.DatasetStatus;
    import software.amazon.awssdk.services.rekognition.model.DescribeDatasetRequest;
    import software.amazon.awssdk.services.rekognition.model.DescribeDatasetResponse;
    import software.amazon.awssdk.services.rekognition.model.DistributeDataset;
    import software.amazon.awssdk.services.rekognition.model.DistributeDatasetEntriesRequest;
    import software.amazon.awssdk.services.rekognition.model.RekognitionException;

    import java.util.ArrayList;
    import java.util.List;
    import java.util.logging.Level;
    import java.util.logging.Logger;

    public class DistributeDatasetEntries {

        public static final Logger logger = Logger.getLogger(DistributeDatasetEntries.class.getName());

        public static DatasetStatus checkDatasetStatus(RekognitionClient rekClient, String datasetArn)
                throws Exception, RekognitionException {

            boolean finished = false;
            DatasetStatus status = null;

            // Wait until distribution completes or fails.
            do {

                DescribeDatasetRequest describeDatasetRequest = DescribeDatasetRequest.builder().datasetArn(datasetArn)
                        .build();
                DescribeDatasetResponse describeDatasetResponse = rekClient.describeDataset(describeDatasetRequest);

                DatasetDescription datasetDescription = describeDatasetResponse.datasetDescription();

                status = datasetDescription.status();

                logger.log(Level.INFO, " dataset ARN: {0} ", datasetArn);

                switch (status) {

                    case UPDATE_COMPLETE:
                        logger.log(Level.INFO, "Dataset updated");
                        finished = true;
                        break;

                    case UPDATE_IN_PROGRESS:
                        Thread.sleep(5000);
                        break;

                    case UPDATE_FAILED:
                        String error = "Dataset distribution failed: " + datasetDescription.statusAsString() + " "
                                + datasetDescription.statusMessage() + " " + datasetArn;
                        logger.log(Level.SEVERE, error);
                        // Stop polling; the caller checks the returned status.
                        finished = true;
                        break;

                    default:
                        String unexpectedError = "Unexpected distribution state: "
                                + datasetDescription.statusAsString() + " " + datasetDescription.statusMessage() + " "
                                + datasetArn;
                        logger.log(Level.SEVERE, unexpectedError);
                        // Stop polling instead of looping forever on an unexpected state.
                        finished = true;
                }

            } while (!finished);

            return status;
        }

        public static void distributeMyDatasetEntries(RekognitionClient rekClient, String trainingDatasetArn,
                String testDatasetArn)
                throws Exception, RekognitionException {

            try {

                logger.log(Level.INFO, "Distributing {0} dataset to {1} ",
                        new Object[] { trainingDatasetArn, testDatasetArn });

                DistributeDataset distributeTrainingDataset = DistributeDataset.builder().arn(trainingDatasetArn).build();

                DistributeDataset distributeTestDataset = DistributeDataset.builder().arn(testDatasetArn).build();

                List<DistributeDataset> datasets = new ArrayList<>();

                datasets.add(distributeTrainingDataset);
                datasets.add(distributeTestDataset);

                DistributeDatasetEntriesRequest distributeDatasetEntriesRequest = DistributeDatasetEntriesRequest.builder()
                        .datasets(datasets).build();

                rekClient.distributeDatasetEntries(distributeDatasetEntriesRequest);

                DatasetStatus trainingStatus = checkDatasetStatus(rekClient, trainingDatasetArn);
                DatasetStatus testStatus = checkDatasetStatus(rekClient, testDatasetArn);

                if (trainingStatus == DatasetStatus.UPDATE_COMPLETE && testStatus == DatasetStatus.UPDATE_COMPLETE) {
                    logger.log(Level.INFO, "Successfully distributed dataset: {0}", trainingDatasetArn);
                } else {
                    throw new Exception("Failed to distribute dataset: " + trainingDatasetArn);
                }

            } catch (RekognitionException e) {
                logger.log(Level.SEVERE, "Could not distribute dataset: {0}", e.getMessage());
                throw e;
            }

        }

        public static void main(String[] args) {

            String trainingDatasetArn = null;
            String testDatasetArn = null;

            final String USAGE = "\n" + "Usage: " + "<training_dataset_arn> <test_dataset_arn>\n\n" + "Where:\n"
                    + "   training_dataset_arn - the ARN of the dataset that you want to distribute from.\n\n"
                    + "   test_dataset_arn - the ARN of the dataset that you want to distribute to.\n\n";

            if (args.length != 2) {
                System.out.println(USAGE);
                System.exit(1);
            }

            trainingDatasetArn = args[0];
            testDatasetArn = args[1];

            // Get the Rekognition client. try-with-resources closes it even if an error occurs.
            // Change the Region to the one that contains your project.
            try (RekognitionClient rekClient = RekognitionClient.builder()
                    .credentialsProvider(ProfileCredentialsProvider.create("custom-labels-access"))
                    .region(Region.US_WEST_2)
                    .build()) {

                // Distribute the dataset
                distributeMyDatasetEntries(rekClient, trainingDatasetArn, testDatasetArn);

                System.out.println("Datasets distributed.");

            } catch (RekognitionException rekError) {
                logger.log(Level.SEVERE, "Rekognition client error: {0}", rekError.getMessage());
                System.exit(1);
            } catch (Exception rekError) {
                logger.log(Level.SEVERE, "Error: {0}", rekError.getMessage());
                System.exit(1);
            }

        }

    }
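Step 5 needs the ARNs of the project's training and test datasets. The following minimal sketch reads them from the DescribeProjects response with Boto3, assuming you know the project name. `get_dataset_arns` and the project name `my-project` are hypothetical; the response fields it reads (`ProjectDescriptions`, `Datasets`, `DatasetType`, `DatasetArn`) are the ones DescribeProjects returns.

```python
def get_dataset_arns(client, project_name):
    """Return {DatasetType: DatasetArn} for one project via DescribeProjects.

    `client` is a Boto3 Rekognition client (or any object with the same
    describe_projects method). get_dataset_arns is a hypothetical helper.
    """
    response = client.describe_projects(ProjectNames=[project_name])
    descriptions = response.get("ProjectDescriptions", [])
    if not descriptions:
        raise ValueError(f"No project named {project_name!r} was found")
    # Each entry in Datasets describes one dataset, with DatasetType "TRAIN" or "TEST".
    return {
        dataset["DatasetType"]: dataset["DatasetArn"]
        for dataset in descriptions[0].get("Datasets", [])
    }
```

For example, `arns = get_dataset_arns(client, "my-project")` returns a dictionary whose `arns["TRAIN"]` and `arns["TEST"]` values you can pass as the two command line arguments of the Python example. If the project has no test dataset yet, there is no `"TEST"` key, which is a reminder to complete step 4 first.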