查看模型的混淆矩阵 - Rekognition

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

查看模型的混淆矩阵

您可以通过混淆矩阵查看模型与模型中的其他标签混淆的标签。通过使用混淆矩阵,您可以将改进的重点放在模型上。

在模型评估期间,Amazon Rekognition Custom Labels 会使用测试图像来识别错误识别(混淆)的标签,从而创建混淆矩阵。Amazon Rekognition Custom Labels 只会为分类模型创建混淆矩阵。可以从 Amazon Rekognition Custom Labels 在模型训练期间创建的摘要文件中获取分类矩阵。无法在 Amazon Rekognition Custom Labels 控制台中查看混淆矩阵。

使用混淆矩阵

下表是 Rooms 图像分类示例项目的混淆矩阵。列标题是分配给测试图像的标签(ground truth 标签)。行标题是模型为测试图像预测的标签。每个单元格是对标签(行)应为 ground truth 标签(列)的预测的百分比。例如,对浴室的预测有 67% 被正确地标注为浴室。33% 的浴室被错误地标注为厨房。当预测的标签与 ground truth 标签匹配时,高性能模型具有高单元格值。可以将这些看作是从第一个预测和 ground truth 标签到最后一个预测和 ground truth 标签的对角线。如果单元格值为 0,则表示对单元格的预测标签应为单元格的 ground truth 标签的预测为 0。

注意

由于模型是不确定的,您通过训练 Rooms 项目获得的混淆矩阵单元格值可能与下表不同。

混淆矩阵确定了需要关注的领域。例如,混淆矩阵显示,模型有 50% 的时间将衣柜与卧室混淆。在这种情况下,您就应该在训练数据集中添加更多衣柜和卧室的图像。此外,还要检查现有的衣柜和卧室图像是否被正确标注。这应该有助于模型更好地区分这两个标签。要向数据集中添加更多图像,请参阅向数据集中添加更多图像

虽然混淆矩阵很有帮助,但考虑其他指标也很重要。例如,100% 的预测正确找到了 floor_plan 标签,这表明性能优异。但是,测试数据集只有 2 张带有 floor_plan 标签的图像。它还有 11 张带有 living_space 标签的图像。这种不平衡也存在于训练数据集中(13 张 living_space 图像和 2 张衣柜图像)。要获得更准确的评估,请通过添加更多代表不足的标签的图像(本示例中的平面图)来平衡训练和测试数据集。要获取每个标签的测试图像数量,请参阅获取评估指标(控制台)

Ground Truth 标签
预测标签 backyard bathroom bedroom closet entry_way floor_plan front_yard kitchen living_space patio
backyard 75% 0% 0% 0% 0% 0% 33% 0% 0% 0%
bathroom 0% 67% 0% 0% 0% 0% 0% 0% 0% 0%
bedroom 0% 0% 82% 50% 0% 0% 0% 0% 9% 0%
closet 0% 0% 0% 50% 0% 0% 0% 0% 0% 0%
entry_way 0% 0% 0% 0% 33% 0% 0% 0% 0% 0%
floor_plan 0% 0% 0% 0% 0% 100% 0% 0% 0% 0%
front_yard 25% 0% 0% 0% 0% 0% 67% 0% 0% 0%
kitchen 0% 33% 0% 0% 0% 0% 0% 88% 0% 0%
living_space 0% 0% 18% 0% 67% 0% 0% 12% 91% 33%
patio 0% 0% 0% 0% 0% 0% 0% 0% 0% 67%

获取模型的混淆矩阵

以下代码使用DescribeProjectsDescribeProject版本操作来获取模型的摘要文件。然后,使用摘要文件来显示模型的混淆矩阵。

显示模型的混淆矩阵 (SDK)
  1. 如果您尚未这样做,请安装和配置和 AWS SDK。 AWS CLI 有关更多信息,请参阅 步骤 4:设置 AWS CLI 和 AWS 软件开发工具包

  2. 使用以下代码显示模型的混淆矩阵。提供以下命令行参数:

    • project_name:要使用的项目的名称。可以从 Amazon Rekognition Custom Labels 控制台的项目页面获取项目名称。

    • version_name:要使用的模型版本。可以从 Amazon Rekognition Custom Labels 控制台的项目详细信息页面获取版本名称。

    # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 """ Purpose Shows how to display the confusion matrix for an Amazon Rekognition Custom labels image classification model. """ import json import argparse import logging import boto3 import pandas as pd from botocore.exceptions import ClientError logger = logging.getLogger(__name__) def get_model_summary_location(rek_client, project_name, version_name): """ Get the summary file location for a model. :param rek_client: A Boto3 Rekognition client. :param project_arn: The Amazon Resource Name (ARN) of the project that contains the model. :param model_arn: The Amazon Resource Name (ARN) of the model. :return: The location of the model summary file. """ try: logger.info( "Getting summary file for model %s in project %s.", version_name, project_name) summary_location = "" # Get the project ARN from the project name. response = rek_client.describe_projects(ProjectNames=[project_name]) assert len(response['ProjectDescriptions']) > 0, \ f"Project {project_name} not found." project_arn = response['ProjectDescriptions'][0]['ProjectArn'] # Get the summary file location for the model. describe_response = rek_client.describe_project_versions(ProjectArn=project_arn, VersionNames=[version_name]) assert len(describe_response['ProjectVersionDescriptions']) > 0, \ f"Model {version_name} not found." model=describe_response['ProjectVersionDescriptions'][0] evaluation_results=model['EvaluationResult'] summary_location=(f"s3://{evaluation_results['Summary']['S3Object']['Bucket']}" f"/{evaluation_results['Summary']['S3Object']['Name']}") return summary_location except ClientError as err: logger.exception( "Couldn't get summary file location: %s", err.response['Error']['Message']) raise def show_confusion_matrix(summary): """ Shows the confusion matrix for an Amazon Rekognition Custom Labels image classification model. :param summary: The summary file JSON object. """ pd.options.display.float_format = '{:.0%}'.format # Load the model summary JSON into a DataFrame. summary_df = pd.DataFrame( summary['AggregatedEvaluationResults']['ConfusionMatrix']) # Get the confusion matrix. confusion_matrix = summary_df.pivot_table(index='PredictedLabel', columns='GroundTruthLabel', fill_value=0.0).astype(float) # Display the confusion matrix. print(confusion_matrix) def get_summary(s3_resource, summary): """ Gets the summary file. : return: The summary file in bytes. """ try: summary_bucket, summary_key = summary.replace( "s3://", "").split("/", 1) bucket = s3_resource.Bucket(summary_bucket) obj = bucket.Object(summary_key) body = obj.get()['Body'].read() logger.info( "Got summary file '%s' from bucket '%s'.", obj.key, obj.bucket_name) except ClientError: logger.exception( "Couldn't get summary file '%s' from bucket '%s'.", obj.key, obj.bucket_name) raise else: return body def add_arguments(parser): """ Adds command line arguments to the parser. : param parser: The command line parser. """ parser.add_argument( "project_name", help="The ARN of the project in which the model resides." ) parser.add_argument( "version_name", help="The version of the model that you want to describe." ) def main(): """ Entry point for script. """ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") try: # Get the command line arguments. parser = argparse.ArgumentParser(usage=argparse.SUPPRESS) add_arguments(parser) args = parser.parse_args() print( f"Showing confusion matrix for: {args.version_name} for project {args.project_name}.") session = boto3.Session(profile_name='custom-labels-access') rekognition_client = session.client("rekognition") s3_resource = session.resource('s3') # Get the summary file for the model. summary_location = get_model_summary_location(rekognition_client, args.project_name, args.version_name ) summary = json.loads(get_summary(s3_resource, summary_location)) # Check that the confusion matrix is available. assert 'ConfusionMatrix' in summary['AggregatedEvaluationResults'], \ "Confusion matrix not found in summary. Is the model a classification model?" # Show the confusion matrix. show_confusion_matrix(summary) print("Done") except ClientError as err: logger.exception("Problem showing confusion matrix: %s", err) print(f"Problem describing model: {err}") except AssertionError as err: logger.exception( "Error: %s.\n", err) print( f"Error: {err}\n") if __name__ == "__main__": main()