훈련 데이터 세트 배포(SDK) - Rekognition

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

훈련 데이터 세트 배포(SDK)

Amazon Rekognition Custom Labels가 모델을 훈련하려면 훈련 데이터 세트와 테스트 데이터 세트가 필요합니다.

API를 사용하는 경우 DistributeDataSetEntries API를 사용하여 훈련 데이터 세트의 20%를 빈 테스트 데이터 세트에 배포할 수 있습니다. 사용 가능한 매니페스트 파일이 하나뿐인 경우 훈련 데이터 세트를 배포하는 것이 유용할 수 있습니다. 단일 매니페스트 파일을 사용하여 훈련 데이터 세트를 만드세요. 그런 다음 빈 테스트 데이터 세트를 만들고 DistributeDatasetEntries 항목을 사용하여 테스트 데이터 세트를 채웁니다.

참고

Amazon Rekognition Custom Labels 콘솔을 사용하고 데이터 세트 하나로 프로젝트를 시작하는 경우, Amazon Rekognition Custom Labels는 훈련 중에 훈련 데이터 세트를 분할(배포)하여 테스트 데이터 세트를 생성합니다. 훈련 데이터 세트 항목의 20%가 테스트 데이터 세트로 이동합니다.

훈련 데이터 세트를 배포하려면(SDK)
  1. 아직 하지 않았다면 AWS CLI 및 AWS SDK를 설치하고 구성하세요. 자세한 내용은 4단계: 설정 AWS CLI 그리고 AWS SDKs 섹션을 참조하세요.

  2. 프로젝트 생성 자세한 내용은 Amazon Rekognition Custom Labels 프로젝트 생성(SDK) 섹션을 참조하세요.

  3. 훈련 데이터 세트를 생성합니다. 데이터 세트에 관한 자세한 내용은 훈련 및 테스트 데이터 세트 생성 항목을 참조하세요.

  4. 빈 테스트 데이터 세트를 생성합니다.

  5. 다음 예제 코드를 사용하여 훈련 데이터 세트 항목의 20%를 테스트 데이터 세트에 배포합니다. DescribeProjects를 호출하여 프로젝트 데이터 세트의 Amazon 리소스 이름(ARN)을 가져올 수 있습니다. 예제 코드는 프로젝트 설명 () SDK 항목을 참조하세요.

    AWS CLI

    training_dataset-arntest_dataset_arn의 값을 사용하려는 데이터 세트의 ARN으로 변경합니다.

    aws rekognition distribute-dataset-entries --datasets ['{"Arn": "training_dataset_arn"}, {"Arn": "test_dataset_arn"}'] \ --profile custom-labels-access
    Python

    다음 코드를 사용합니다. 다음 명령줄 파라미터를 제공하세요.

    • training_dataset_arn: 항목을 배포하는 데 사용되는 훈련 데이터 세트의 ARN입니다.

    • test_dataset_arn: 항목을 배포하는 데 사용되는 테스트 데이터 세트의 ARN입니다.

    # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import argparse import logging import time import json import boto3 from botocore.exceptions import ClientError logger = logging.getLogger(__name__) def check_dataset_status(rek_client, dataset_arn): """ Checks the current status of a dataset. :param rek_client: The Amazon Rekognition Custom Labels Boto3 client. :param dataset_arn: The dataset that you want to check. :return: The dataset status and status message. """ finished = False status = "" status_message = "" while finished is False: dataset = rek_client.describe_dataset(DatasetArn=dataset_arn) status = dataset['DatasetDescription']['Status'] status_message = dataset['DatasetDescription']['StatusMessage'] if status == "UPDATE_IN_PROGRESS": logger.info("Distributing dataset: %s ", dataset_arn) time.sleep(5) continue if status == "UPDATE_COMPLETE": logger.info( "Dataset distribution complete: %s : %s : %s", status, status_message, dataset_arn) finished = True continue if status == "UPDATE_FAILED": logger.exception( "Dataset distribution failed: %s : %s : %s", status, status_message, dataset_arn) finished = True break logger.exception( "Failed. Unexpected state for dataset distribution: %s : %s : %s", status, status_message, dataset_arn) finished = True status_message = "An unexpected error occurred while distributing the dataset" break return status, status_message def distribute_dataset_entries(rek_client, training_dataset_arn, test_dataset_arn): """ Distributes 20% of the supplied training dataset into the supplied test dataset. :param rek_client: The Amazon Rekognition Custom Labels Boto3 client. :param training_dataset_arn: The ARN of the training dataset that you distribute entries from. :param test_dataset_arn: The ARN of the test dataset that you distribute entries to. """ try: # List dataset labels. logger.info("Distributing training dataset entries (%s) into test dataset (%s).", training_dataset_arn,test_dataset_arn) datasets = json.loads( '[{"Arn" : "' + str(training_dataset_arn) + '"},{"Arn" : "' + str(test_dataset_arn) + '"}]') rek_client.distribute_dataset_entries( Datasets=datasets ) training_dataset_status, training_dataset_status_message = check_dataset_status( rek_client, training_dataset_arn) test_dataset_status, test_dataset_status_message = check_dataset_status( rek_client, test_dataset_arn) if training_dataset_status == 'UPDATE_COMPLETE' and test_dataset_status == "UPDATE_COMPLETE": print("Distribution complete") else: print("Distribution failed:") print( f"\ttraining dataset: {training_dataset_status} : {training_dataset_status_message}") print( f"\ttest dataset: {test_dataset_status} : {test_dataset_status_message}") except ClientError as err: logger.exception( "Couldn't distribute dataset: %s",err.response['Error']['Message'] ) raise def add_arguments(parser): """ Adds command line arguments to the parser. :param parser: The command line parser. """ parser.add_argument( "training_dataset_arn", help="The ARN of the training dataset that you want to distribute from." ) parser.add_argument( "test_dataset_arn", help="The ARN of the test dataset that you want to distribute to." ) def main(): logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") try: # Get command line arguments. parser = argparse.ArgumentParser(usage=argparse.SUPPRESS) add_arguments(parser) args = parser.parse_args() print( f"Distributing training dataset entries ({args.training_dataset_arn}) "\ f"into test dataset ({args.test_dataset_arn}).") # Distribute the datasets. session = boto3.Session(profile_name='custom-labels-access') rekognition_client = session.client("rekognition") distribute_dataset_entries(rekognition_client, args.training_dataset_arn, args.test_dataset_arn) print("Finished distributing datasets.") except ClientError as err: logger.exception("Problem distributing datasets: %s", err) print(f"Problem listing dataset labels: {err}") except Exception as err: logger.exception("Problem distributing datasets: %s", err) print(f"Problem distributing datasets: {err}") if __name__ == "__main__": main()
    Java V2

    다음 코드를 사용합니다. 다음 명령줄 파라미터를 제공하세요.

    • training_dataset_arn: 항목을 배포하는 데 사용되는 훈련 데이터 세트의 ARN입니다.

    • test_dataset_arn: 항목을 배포하는 데 사용되는 테스트 데이터 세트의 ARN입니다.

    /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 */ package com.example.rekognition; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.rekognition.RekognitionClient; import software.amazon.awssdk.services.rekognition.model.DatasetDescription; import software.amazon.awssdk.services.rekognition.model.DatasetStatus; import software.amazon.awssdk.services.rekognition.model.DescribeDatasetRequest; import software.amazon.awssdk.services.rekognition.model.DescribeDatasetResponse; import software.amazon.awssdk.services.rekognition.model.DistributeDataset; import software.amazon.awssdk.services.rekognition.model.DistributeDatasetEntriesRequest; import software.amazon.awssdk.services.rekognition.model.RekognitionException; import java.util.ArrayList; import java.util.logging.Level; import java.util.logging.Logger; public class DistributeDatasetEntries { public static final Logger logger = Logger.getLogger(DistributeDatasetEntries.class.getName()); public static DatasetStatus checkDatasetStatus(RekognitionClient rekClient, String datasetArn) throws Exception, RekognitionException { boolean distributed = false; DatasetStatus status = null; // Wait until distribution completes do { DescribeDatasetRequest describeDatasetRequest = DescribeDatasetRequest.builder().datasetArn(datasetArn) .build(); DescribeDatasetResponse describeDatasetResponse = rekClient.describeDataset(describeDatasetRequest); DatasetDescription datasetDescription = describeDatasetResponse.datasetDescription(); status = datasetDescription.status(); logger.log(Level.INFO, " dataset ARN: {0} ", datasetArn); switch (status) { case UPDATE_COMPLETE: logger.log(Level.INFO, "Dataset updated"); distributed = true; break; case UPDATE_IN_PROGRESS: Thread.sleep(5000); break; case UPDATE_FAILED: String error = "Dataset distribution failed: " + datasetDescription.statusAsString() + " " + datasetDescription.statusMessage() + " " + datasetArn; logger.log(Level.SEVERE, error); break; default: String unexpectedError = "Unexpected distribution state: " + datasetDescription.statusAsString() + " " + datasetDescription.statusMessage() + " " + datasetArn; logger.log(Level.SEVERE, unexpectedError); } } while (distributed == false); return status; } public static void distributeMyDatasetEntries(RekognitionClient rekClient, String trainingDatasetArn, String testDatasetArn) throws Exception, RekognitionException { try { logger.log(Level.INFO, "Distributing {0} dataset to {1} ", new Object[] { trainingDatasetArn, testDatasetArn }); DistributeDataset distributeTrainingDataset = DistributeDataset.builder().arn(trainingDatasetArn).build(); DistributeDataset distributeTestDataset = DistributeDataset.builder().arn(testDatasetArn).build(); ArrayList<DistributeDataset> datasets = new ArrayList(); datasets.add(distributeTrainingDataset); datasets.add(distributeTestDataset); DistributeDatasetEntriesRequest distributeDatasetEntriesRequest = DistributeDatasetEntriesRequest.builder() .datasets(datasets).build(); rekClient.distributeDatasetEntries(distributeDatasetEntriesRequest); DatasetStatus trainingStatus = checkDatasetStatus(rekClient, trainingDatasetArn); DatasetStatus testStatus = checkDatasetStatus(rekClient, testDatasetArn); if (trainingStatus == DatasetStatus.UPDATE_COMPLETE && testStatus == DatasetStatus.UPDATE_COMPLETE) { logger.log(Level.INFO, "Successfully distributed dataset: {0}", trainingDatasetArn); } else { throw new Exception("Failed to distribute dataset: " + trainingDatasetArn); } } catch (RekognitionException e) { logger.log(Level.SEVERE, "Could not distribute dataset: {0}", e.getMessage()); throw e; } } public static void main(String[] args) { String trainingDatasetArn = null; String testDatasetArn = null; final String USAGE = "\n" + "Usage: " + "<training_dataset_arn> <test_dataset_arn>\n\n" + "Where:\n" + " training_dataset_arn - the ARN of the dataset that you want to distribute from.\n\n" + " test_dataset_arn - the ARN of the dataset that you want to distribute to.\n\n"; if (args.length != 2) { System.out.println(USAGE); System.exit(1); } trainingDatasetArn = args[0]; testDatasetArn = args[1]; try { // Get the Rekognition client. RekognitionClient rekClient = RekognitionClient.builder() .credentialsProvider(ProfileCredentialsProvider.create("custom-labels-access")) .region(Region.US_WEST_2) .build(); // Distribute the dataset distributeMyDatasetEntries(rekClient, trainingDatasetArn, testDatasetArn); System.out.println("Datasets distributed."); rekClient.close(); } catch (RekognitionException rekError) { logger.log(Level.SEVERE, "Rekognition client error: {0}", rekError.getMessage()); System.exit(1); } catch (Exception rekError) { logger.log(Level.SEVERE, "Error: {0}", rekError.getMessage()); System.exit(1); } } }