训练 Amazon Rekognition Custom Labels 模型 - Rekognition

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

训练 Amazon Rekognition Custom Labels 模型

可以使用 Amazon Rekognition Custom Labels 控制台或 Amazon Rekognition Custom Labels API 来训练模型。如果模型训练失败,请按照调试失败的模型训练中的说明查找失败的原因。

注意

您需要按照成功训练模型所花费的时间付费。通常,训练需要 30 分钟到 24 小时才能完成。有关更多信息,请参阅训练时长

每次训练模型都会创建一个新的模型版本。Amazon Rekognition Custom Labels 会为模型创建一个名称,该名称是项目名称和模型创建时的时间戳的组合。

为了训练您的模型,Amazon Rekognition Custom Labels 会复制您的源训练图像和测试图像。默认情况下,复制的图像使用 AWS 拥有和管理的密钥进行静态加密。您也可以选择使用自己的 AWS KMS key。如果使用自己的 KMS 密钥,则需要对该 KMS 密钥具有以下权限。

  • kms:CreateGrant

  • kms:DescribeKey

有关更多信息,请参阅 AWS Key Management Service 概念。源图像不受影响。

可以使用 KMS 服务器端加密 (SSE-KMS) 加密 Amazon S3 存储桶中的训练和测试图像,然后再将它们复制到 Amazon Rekognition Custom Labels 中。要允许 Amazon Rekognition Custom Labels 访问您的图像,您的 AWS 账户需要对 KMS 密钥拥有以下权限。

  • kms:GenerateDataKey

  • kms:Decrypt

有关更多信息,请参阅使用存储在 AWS Key Management Service 中的 KMS 密钥通过服务器端加密 (SSE-KMS) 保护数据

训练模型后,您可以评估其性能并进行改进。有关更多信息,请参阅改进经过训练的 Amazon Rekognition Custom Labels 模型

有关其他模型任务(例如标记模型),请参阅管理 Amazon Rekognition Custom Labels 模型

训练模型(控制台)

可以使用 Amazon Rekognition Custom Labels 控制台训练项目。

训练需要一个包含训练数据集和测试数据集的项目。如果项目没有测试数据集,Amazon Rekognition Custom Labels 控制台会在训练期间拆分训练数据集,为项目创建一个测试数据集。所选图像是具有代表性的采样,不会用于训练数据集。建议您仅在没有可供使用的替代测试数据集时才拆分训练数据集。拆分训练数据集会减少可用于训练的图像数量。

注意

您需要按照训练模型所花费的时间付费。有关更多信息,请参阅训练时长

训练模型(控制台)
  1. 通过以下网址打开 Amazon Rekognition 控制台:https://console.aws.amazon.com/rekognition/

  2. 选择使用自定义标签

  3. 在左侧导航窗格中,选择项目

  4. 项目页面上,选择包含要训练的模型的项目。

  5. 项目页面上,选择训练模型

  6. (可选)如果要使用自己的 AWS KMS 加密密钥,请执行以下操作:

    1. 图像数据加密中,选择自定义加密设置(高级)

    2. encryption.aws_kms_key 中,输入您的密钥的 Amazon 资源名称 (ARN),或者选择现有的 AWS KMS 密钥。要创建新密钥,请选择创建 AWS IMS 密钥

  7. (可选)如果要向模型添加标签,请执行以下操作:

    1. 标签部分中,选择添加新标签

    2. 输入以下信息:

      1. 中输入键名称。

      2. 中输入键值。

    3. 要添加更多标签,请重复步骤 6a 和 6b。

    4. (可选)如果要移除标签,请选择要移除的标签旁的移除。如果移除的是先前保存的标签,则会在保存更改时将其移除。

  8. 训练模型页面上,选择训练模型。项目的 Amazon 资源名称 (ARN) 应位于选择项目编辑框中。如果没有,请输入项目的 ARN。

  9. 是否要训练您的模型?对话框中,选择训练模型

  10. 在项目页面的模型部分,可以在 Model Status 列中查看当前状态,状态显示训练正在进行。训练模型需要一些时间才能完成。

  11. 训练完成后,选择模型名称。当模型状态为 TRAINING_COMPLETED 时,训练即告完成。如果训练失败,请参阅调试失败的模型训练

  12. 下一步:评估您的模型。有关更多信息,请参阅改进经过训练的 Amazon Rekognition Custom Labels 模型

训练模型 (SDK)

您可以通过调用 CreateProjectVersion 来训练模型。要训练模型,需要提供以下信息:

  • 名称:模型版本的唯一名称。

  • 项目 ARN:管理模型的项目的 Amazon 资源名称 (ARN)。

  • 训练结果位置:存放结果的 Amazon S3 位置。可以使用与控制台 Amazon S3 存储桶相同的位置,也可以选择其他位置。建议您选择其他位置,因为这样您就可以设置权限,并避免与使用 Amazon Rekognition Custom Labels 控制台时的训练输出发生潜在的命名冲突。

训练使用与项目关联的训练和测试数据集。有关更多信息,请参阅管理数据集

注意

或者,也可以指定项目外部的训练和测试数据集清单文件。如果在使用外部清单文件训练模型后打开控制台,Amazon Rekognition Custom Labels 会使用最后一组用于训练的清单文件为您创建数据集。不能再通过指定外部清单文件来训练项目的模型版本。有关更多信息,请参阅 CreatePrjectVersion

CreateProjectVersion 的响应是一个 ARN,用于在后续请求中识别模型版本。您还可以使用 ARN 来保护模型版本。有关更多信息,请参阅保护 Amazon Rekognition Custom Labels 项目

训练模型版本需要一些时间才能完成。本主题中的 Python 和 Java 示例使用 waiter 来等待训练完成。waiter 是一种实用程序方法,用于轮询是否发生了特定状态。或者,您也可以通过调用 DescribeProjectVersions 获取训练的当前状态。当 Status 字段的值为 TRAINING_COMPLETED 时,即表示训练已完成。训练完成后,您可以通过查看评估结果来评估模型的质量。

训练模型 (SDK)

以下示例说明了如何使用与项目关联的训练和测试数据集来训练模型。

训练模型 (SDK)
  1. 安装并配置 AWS CLI 和 AWS SDK(如果尚未如此)。有关更多信息,请参阅步骤 4:设置 AWS CLI 和 AWS 软件开发工具包

  2. 使用以下示例代码来训练项目。

    AWS CLI

    以下示例会创建模型。会拆分训练数据集以创建测试数据集。替换以下内容:

    • my_project_arn 替换为项目的 Amazon 资源名称 (ARN)。

    • version_name 替换为您选择的唯一版本名称。

    • output_bucket 替换为 Amazon Rekognition Custom Labels 保存训练结果的 Amazon S3 存储桶的名称。

    • output_folder 替换为保存训练结果的文件夹的名称。

    • (可选参数)将 --kms-key-id 替换为您的 AWS Key Management Service 客户主密钥的标识符。

    aws rekognition create-project-version \ --project-arn project_arn \ --version-name version_name \ --output-config '{"S3Bucket":"output_bucket", "S3KeyPrefix":"output_folder"}' \ --profile custom-labels-access
    Python

    以下示例会创建模型。提供以下命令行参数:

    • project_arn:项目的 Amazon 资源名称 (ARN)。

    • version_name:您选择的模型的唯一版本名称。

    • output_bucket:Amazon Rekognition Custom Labels 保存训练结果的 Amazon S3 存储桶的名称。

    • output_folder:保存训练结果的文件夹的名称。

    或者,提供以下命令行参数以将标签附加到模型:

    • tag:您选择的要附加到模型的标签名称。

    • tag_value:标签值。

    #Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. #PDX-License-Identifier: MIT-0 (For details, see https://github.com/awsdocs/amazon-rekognition-custom-labels-developer-guide/blob/master/LICENSE-SAMPLECODE.) import argparse import logging import json import boto3 from botocore.exceptions import ClientError logger = logging.getLogger(__name__) def train_model(rek_client, project_arn, version_name, output_bucket, output_folder, tag_key, tag_key_value): """ Trains an Amazon Rekognition Custom Labels model. :param rek_client: The Amazon Rekognition Custom Labels Boto3 client. :param project_arn: The ARN of the project in which you want to train a model. :param version_name: A version for the model. :param output_bucket: The S3 bucket that hosts training output. :param output_folder: The path for the training output within output_bucket :param tag_key: The name of a tag to attach to the model. Pass None to exclude :param tag_key_value: The value of the tag. Pass None to exclude """ try: #Train the model status="" logger.info("training model version %s for project %s", version_name, project_arn) output_config = json.loads( '{"S3Bucket": "' + output_bucket + '", "S3KeyPrefix": "' + output_folder + '" } ' ) tags={} if tag_key is not None and tag_key_value is not None: tags = json.loads( '{"' + tag_key + '":"' + tag_key_value + '"}' ) response=rek_client.create_project_version( ProjectArn=project_arn, VersionName=version_name, OutputConfig=output_config, Tags=tags ) logger.info("Started training: %s", response['ProjectVersionArn']) # Wait for the project version training to complete. project_version_training_completed_waiter = rek_client.get_waiter('project_version_training_completed') project_version_training_completed_waiter.wait(ProjectArn=project_arn, VersionNames=[version_name]) # Get the completion status. describe_response=rek_client.describe_project_versions(ProjectArn=project_arn, VersionNames=[version_name]) for model in describe_response['ProjectVersionDescriptions']: logger.info("Status: %s", model['Status']) logger.info("Message: %s", model['StatusMessage']) status=model['Status'] logger.info("finished training") return response['ProjectVersionArn'], status except ClientError as err: logger.exception("Couldn't create model: %s", err.response['Error']['Message'] ) raise def add_arguments(parser): """ Adds command line arguments to the parser. :param parser: The command line parser. """ parser.add_argument( "project_arn", help="The ARN of the project in which you want to train a model" ) parser.add_argument( "version_name", help="A version name of your choosing." ) parser.add_argument( "output_bucket", help="The S3 bucket that receives the training results." ) parser.add_argument( "output_folder", help="The folder in the S3 bucket where training results are stored." ) parser.add_argument( "--tag_name", help="The name of a tag to attach to the model", required=False ) parser.add_argument( "--tag_value", help="The value for the tag.", required=False ) def main(): logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") try: # Get command line arguments. parser = argparse.ArgumentParser(usage=argparse.SUPPRESS) add_arguments(parser) args = parser.parse_args() print(f"Training model version {args.version_name} for project {args.project_arn}") # Train the model. session = boto3.Session(profile_name='custom-labels-access') rekognition_client = session.client("rekognition") model_arn, status=train_model(rekognition_client, args.project_arn, args.version_name, args.output_bucket, args.output_folder, args.tag_name, args.tag_value) print(f"Finished training model: {model_arn}") print(f"Status: {status}") except ClientError as err: logger.exception("Problem training model: %s", err) print(f"Problem training model: {err}") except Exception as err: logger.exception("Problem training model: %s", err) print(f"Problem training model: {err}") if __name__ == "__main__": main()
    Java V2

    以下示例会训练模型。提供以下命令行参数:

    • project_arn:项目的 Amazon 资源名称 (ARN)。

    • version_name:您选择的模型的唯一版本名称。

    • output_bucket:Amazon Rekognition Custom Labels 保存训练结果的 Amazon S3 存储桶的名称。

    • output_folder:保存训练结果的文件夹的名称。

    /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier: Apache-2.0 */ package com.example.rekognition; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.core.waiters.WaiterResponse; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.rekognition.RekognitionClient; import software.amazon.awssdk.services.rekognition.model.CreateProjectVersionRequest; import software.amazon.awssdk.services.rekognition.model.CreateProjectVersionResponse; import software.amazon.awssdk.services.rekognition.model.DescribeProjectVersionsRequest; import software.amazon.awssdk.services.rekognition.model.DescribeProjectVersionsResponse; import software.amazon.awssdk.services.rekognition.model.OutputConfig; import software.amazon.awssdk.services.rekognition.model.ProjectVersionDescription; import software.amazon.awssdk.services.rekognition.model.RekognitionException; import software.amazon.awssdk.services.rekognition.waiters.RekognitionWaiter; import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; public class TrainModel { public static final Logger logger = Logger.getLogger(TrainModel.class.getName()); public static String trainMyModel(RekognitionClient rekClient, String projectArn, String versionName, String outputBucket, String outputFolder) { try { OutputConfig outputConfig = OutputConfig.builder().s3Bucket(outputBucket).s3KeyPrefix(outputFolder).build(); logger.log(Level.INFO, "Training Model for project {0}", projectArn); CreateProjectVersionRequest createProjectVersionRequest = CreateProjectVersionRequest.builder() .projectArn(projectArn).versionName(versionName).outputConfig(outputConfig).build(); CreateProjectVersionResponse response = rekClient.createProjectVersion(createProjectVersionRequest); logger.log(Level.INFO, "Model ARN: {0}", response.projectVersionArn()); logger.log(Level.INFO, "Training model..."); // wait until training completes DescribeProjectVersionsRequest describeProjectVersionsRequest = DescribeProjectVersionsRequest.builder() .versionNames(versionName) .projectArn(projectArn) .build(); RekognitionWaiter waiter = rekClient.waiter(); WaiterResponse<DescribeProjectVersionsResponse> waiterResponse = waiter .waitUntilProjectVersionTrainingCompleted(describeProjectVersionsRequest); Optional<DescribeProjectVersionsResponse> optionalResponse = waiterResponse.matched().response(); DescribeProjectVersionsResponse describeProjectVersionsResponse = optionalResponse.get(); for (ProjectVersionDescription projectVersionDescription : describeProjectVersionsResponse .projectVersionDescriptions()) { System.out.println("ARN: " + projectVersionDescription.projectVersionArn()); System.out.println("Status: " + projectVersionDescription.statusAsString()); System.out.println("Message: " + projectVersionDescription.statusMessage()); } return response.projectVersionArn(); } catch (RekognitionException e) { logger.log(Level.SEVERE, "Could not train model: {0}", e.getMessage()); throw e; } } public static void main(String args[]) { String versionName = null; String projectArn = null; String projectVersionArn = null; String bucket = null; String location = null; final String USAGE = "\n" + "Usage: " + "<project_name> <version_name> <output_bucket> <output_folder>\n\n" + "Where:\n" + " project_arn - The ARN of the project that you want to use. \n\n" + " version_name - A version name for the model.\n\n" + " output_bucket - The S3 bucket in which to place the training output. \n\n" + " output_folder - The folder within the bucket that the training output is stored in. \n\n"; if (args.length != 4) { System.out.println(USAGE); System.exit(1); } projectArn = args[0]; versionName = args[1]; bucket = args[2]; location = args[3]; try { // Get the Rekognition client. RekognitionClient rekClient = RekognitionClient.builder() .credentialsProvider(ProfileCredentialsProvider.create("custom-labels-access")) .region(Region.US_WEST_2) .build(); // Train model projectVersionArn = trainMyModel(rekClient, projectArn, versionName, bucket, location); System.out.println(String.format("Created model: %s for Project ARN: %s", projectVersionArn, projectArn)); rekClient.close(); } catch (RekognitionException rekError) { logger.log(Level.SEVERE, "Rekognition client error: {0}", rekError.getMessage()); System.exit(1); } } }
  3. 如果训练失败,请参阅调试失败的模型训练