Viewing the confusion matrix for a model - Rekognition


Viewing the confusion matrix for a model

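A confusion matrix shows which labels your model confuses with other labels in the model. You can use it to decide where to focus your improvements to the model.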

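During model evaluation, Amazon Rekognition Custom Labels creates a confusion matrix by using the test images to identify misidentified (confused) labels. Amazon Rekognition Custom Labels creates a confusion matrix only for classification models. You can access the confusion matrix from the summary file that Amazon Rekognition Custom Labels creates during model training. You can't view the confusion matrix in the Amazon Rekognition Custom Labels console.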

Using a confusion matrix

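The following table is the confusion matrix for the Rooms image classification example project. The column headings are the labels assigned to the test images (ground truth labels). The row headings are the labels that the model predicts for the test images. Each cell is the percentage of test images with the ground truth label (column) that the model predicted as the label (row). For example, 67% of the bathroom images were correctly predicted as bathroom, and 33% of the bathroom images were incorrectly predicted as kitchen. A high-performing model has high cell values where the predicted label matches the ground truth label. These cells form a diagonal that runs from the first to the last predicted and ground truth labels. If a cell value is 0%, the model made no predictions of the cell's predicted label for images that have the cell's ground truth label.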
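To make the cell definition concrete, the following is a minimal sketch of how such percentages could be computed from pairs of (ground truth label, predicted label). The function name `confusion_percentages` and the sample data are hypothetical. Amazon Rekognition Custom Labels computes its own matrix during evaluation, so this only illustrates the arithmetic.

```python
from collections import Counter, defaultdict


def confusion_percentages(pairs):
    """Return {predicted: {ground_truth: fraction}}, normalized per ground truth label."""
    pair_counts = Counter(pairs)  # (ground_truth, predicted) -> image count
    truth_totals = Counter(truth for truth, _ in pairs)
    matrix = defaultdict(dict)
    for (truth, predicted), count in pair_counts.items():
        matrix[predicted][truth] = count / truth_totals[truth]
    return dict(matrix)


# Hypothetical test results: 3 bathroom images, 1 of them predicted as kitchen.
sample_pairs = [
    ("bathroom", "bathroom"),
    ("bathroom", "bathroom"),
    ("bathroom", "kitchen"),
    ("kitchen", "kitchen"),
]
result = confusion_percentages(sample_pairs)
print(f"{result['bathroom']['bathroom']:.0%}")  # 67%
print(f"{result['kitchen']['bathroom']:.0%}")   # 33%
```

Because each value is divided by the number of images for its ground truth label, the values for one ground truth label add up to 100%.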

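Note

Because models are non-deterministic, the confusion matrix cell values that you get from training the Rooms project might differ from the following table.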

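The confusion matrix identifies areas to focus on. For example, the confusion matrix shows that the model predicts bedroom for 50% of the closet images. In this situation, add more images of closets and bedrooms to your training dataset. Also check that the existing closet and bedroom images are correctly labeled. This should help the model distinguish between the two labels more reliably. To add more images to a dataset, see Adding more images to a dataset.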
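The following sketch lists the most confused label pairs so that you can decide where to add or relabel images. It assumes that each confusion matrix entry is a record with `GroundTruthLabel`, `PredictedLabel`, and `Value` keys, with `Value` as a fraction between 0 and 1. The `Value` key name, the function `find_confusions`, and the 0.25 threshold are assumptions made for illustration.

```python
def find_confusions(records, threshold=0.25):
    """Return (ground_truth, predicted, value) for off-diagonal cells at or above threshold."""
    confusions = [
        (r["GroundTruthLabel"], r["PredictedLabel"], r["Value"])
        for r in records
        if r["GroundTruthLabel"] != r["PredictedLabel"] and r["Value"] >= threshold
    ]
    # Largest confusion first.
    return sorted(confusions, key=lambda item: item[2], reverse=True)


# A few cells from the Rooms table, written as assumed summary-file records.
records = [
    {"GroundTruthLabel": "closet", "PredictedLabel": "bedroom", "Value": 0.50},
    {"GroundTruthLabel": "closet", "PredictedLabel": "closet", "Value": 0.50},
    {"GroundTruthLabel": "bathroom", "PredictedLabel": "kitchen", "Value": 0.33},
    {"GroundTruthLabel": "living_space", "PredictedLabel": "bedroom", "Value": 0.09},
]
for truth, predicted, value in find_confusions(records):
    print(f"{truth} predicted as {predicted}: {value:.0%}")
```

With these sample records, the closet and bathroom confusions are reported, and the 9% living_space confusion falls below the threshold.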

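Although the confusion matrix is useful, it's important to consider other metrics. For example, 100% of the predictions correctly found the floor_plan label, which suggests excellent performance. However, the test dataset has only 2 images with the floor_plan label, compared with 11 images with the living_space label. This imbalance is also present in the training dataset (13 living_space images and 2 closet images). To get a more accurate evaluation, balance the training and test datasets by adding more images of under-represented labels (floor plans in this example). To get the number of test images for each label, see Accessing evaluation metrics (console).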
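One way to spot this kind of imbalance is to count the images for each label before you trust a high cell value. The following sketch is illustrative only: the function `underrepresented_labels` and the `min_images` cutoff of 10 are hypothetical, and you would supply the label list from your own dataset.

```python
from collections import Counter


def underrepresented_labels(image_labels, min_images=10):
    """Return {label: count} for labels with fewer than min_images images."""
    counts = Counter(image_labels)
    return {label: n for label, n in counts.items() if n < min_images}


# Test image counts taken from the example above.
test_labels = ["floor_plan"] * 2 + ["living_space"] * 11
print(underrepresented_labels(test_labels))  # {'floor_plan': 2}
```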

                 Ground truth label
Predicted label  backyard     bathroom     bedroom      closet       entry_way    floor_plan   front_yard   kitchen      living_space patio
backyard         75%          0%           0%           0%           0%           0%           33%          0%           0%           0%
bathroom         0%           67%          0%           0%           0%           0%           0%           0%           0%           0%
bedroom          0%           0%           82%          50%          0%           0%           0%           0%           9%           0%
closet           0%           0%           0%           50%          0%           0%           0%           0%           0%           0%
entry_way        0%           0%           0%           0%           33%          0%           0%           0%           0%           0%
floor_plan       0%           0%           0%           0%           0%           100%         0%           0%           0%           0%
front_yard       25%          0%           0%           0%           0%           0%           67%          0%           0%           0%
kitchen          0%           33%          0%           0%           0%           0%           0%           88%          0%           0%
living_space     0%           0%           18%          0%           67%          0%           0%           12%          91%          33%
patio            0%           0%           0%           0%           0%           0%           0%           0%           0%           67%
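As a quick check on how to read the table, the following sketch copies the table values into a list of rows, confirms that each ground truth column adds up to 100%, and reads the diagonal as the correctly predicted share for each label. The English label names and the variable names are chosen for this illustration.

```python
labels = [
    "backyard", "bathroom", "bedroom", "closet", "entry_way",
    "floor_plan", "front_yard", "kitchen", "living_space", "patio",
]

# Rows are predicted labels, columns are ground truth labels (values in percent).
table = [
    [75, 0, 0, 0, 0, 0, 33, 0, 0, 0],
    [0, 67, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 82, 50, 0, 0, 0, 0, 9, 0],
    [0, 0, 0, 50, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 33, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 100, 0, 0, 0, 0],
    [25, 0, 0, 0, 0, 0, 67, 0, 0, 0],
    [0, 33, 0, 0, 0, 0, 0, 88, 0, 0],
    [0, 0, 18, 0, 67, 0, 0, 12, 91, 33],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 67],
]

# Sum each ground truth column across all predicted-label rows.
column_sums = [sum(row[i] for row in table) for i in range(len(labels))]

# The diagonal holds the share of each label's images that were predicted correctly.
diagonal = {label: table[i][i] for i, label in enumerate(labels)}

print(column_sums)
print(diagonal)
```

Every column sums to 100, which matches reading each cell as a share of the images for its ground truth label.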

Getting the confusion matrix for a model

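The following code uses the DescribeProjects and DescribeProjectVersions operations to get the location of the summary file for a model. It then uses the summary file to display the confusion matrix for the model.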

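To display the confusion matrix for a model (SDK)
  1. If you haven't already done so, install and configure the AWS CLI and the AWS SDKs. For more information, see Step 4: Set up the AWS CLI and AWS SDKs.

  2. Use the following code to display the confusion matrix for a model. Supply the following command line arguments:

    • project_name: The name of the project that you want to use. You can get the project name from the projects page in the Amazon Rekognition Custom Labels console.

    • version_name: The version of the model that you want to use. You can get the version name from the project details page in the Amazon Rekognition Custom Labels console.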

    # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
    # SPDX-License-Identifier: Apache-2.0

    """
    Purpose

    Shows how to display the confusion matrix for an Amazon Rekognition Custom Labels
    image classification model.
    """

    import json
    import argparse
    import logging
    import boto3
    import pandas as pd
    from botocore.exceptions import ClientError

    logger = logging.getLogger(__name__)


    def get_model_summary_location(rek_client, project_name, version_name):
        """
        Get the summary file location for a model.

        :param rek_client: A Boto3 Rekognition client.
        :param project_name: The name of the project that contains the model.
        :param version_name: The version name of the model.
        :return: The location of the model summary file.
        """

        try:
            logger.info(
                "Getting summary file for model %s in project %s.",
                version_name, project_name)

            # Get the project ARN from the project name.
            response = rek_client.describe_projects(ProjectNames=[project_name])

            assert len(response['ProjectDescriptions']) > 0, \
                f"Project {project_name} not found."

            project_arn = response['ProjectDescriptions'][0]['ProjectArn']

            # Get the summary file location for the model.
            describe_response = rek_client.describe_project_versions(
                ProjectArn=project_arn, VersionNames=[version_name])

            assert len(describe_response['ProjectVersionDescriptions']) > 0, \
                f"Model {version_name} not found."

            model = describe_response['ProjectVersionDescriptions'][0]
            evaluation_results = model['EvaluationResult']

            summary_location = (
                f"s3://{evaluation_results['Summary']['S3Object']['Bucket']}"
                f"/{evaluation_results['Summary']['S3Object']['Name']}")

            return summary_location

        except ClientError as err:
            logger.exception(
                "Couldn't get summary file location: %s",
                err.response['Error']['Message'])
            raise


    def show_confusion_matrix(summary):
        """
        Shows the confusion matrix for an Amazon Rekognition Custom Labels
        image classification model.

        :param summary: The summary file JSON object.
        """

        # Display cell values as whole-number percentages.
        pd.options.display.float_format = '{:.0%}'.format

        # Load the confusion matrix entries from the summary into a DataFrame.
        summary_df = pd.DataFrame(
            summary['AggregatedEvaluationResults']['ConfusionMatrix'])

        # Pivot so that rows are predicted labels and columns are ground truth labels.
        confusion_matrix = summary_df.pivot_table(
            index='PredictedLabel', columns='GroundTruthLabel',
            fill_value=0.0).astype(float)

        # Display the confusion matrix.
        print(confusion_matrix)


    def get_summary(s3_resource, summary):
        """
        Gets the summary file.

        :param s3_resource: A Boto3 S3 resource.
        :param summary: The S3 location (s3://bucket/key) of the summary file.
        :return: The summary file in bytes.
        """

        # Split the location into bucket and key before calling Amazon S3,
        # so that both names are available for error logging.
        summary_bucket, summary_key = summary.replace(
            "s3://", "").split("/", 1)

        try:
            bucket = s3_resource.Bucket(summary_bucket)
            obj = bucket.Object(summary_key)
            body = obj.get()['Body'].read()
            logger.info(
                "Got summary file '%s' from bucket '%s'.",
                summary_key, summary_bucket)
        except ClientError:
            logger.exception(
                "Couldn't get summary file '%s' from bucket '%s'.",
                summary_key, summary_bucket)
            raise
        else:
            return body


    def add_arguments(parser):
        """
        Adds command line arguments to the parser.

        :param parser: The command line parser.
        """

        parser.add_argument(
            "project_name", help="The name of the project in which the model resides."
        )
        parser.add_argument(
            "version_name", help="The version of the model that you want to describe."
        )


    def main():
        """
        Entry point for script.
        """

        logging.basicConfig(level=logging.INFO,
                            format="%(levelname)s: %(message)s")

        try:
            # Get the command line arguments.
            parser = argparse.ArgumentParser(usage=argparse.SUPPRESS)
            add_arguments(parser)
            args = parser.parse_args()

            print(
                f"Showing confusion matrix for: {args.version_name} "
                f"for project {args.project_name}.")

            session = boto3.Session(profile_name='custom-labels-access')
            rekognition_client = session.client("rekognition")
            s3_resource = session.resource('s3')

            # Get the summary file for the model.
            summary_location = get_model_summary_location(
                rekognition_client, args.project_name, args.version_name)
            summary = json.loads(get_summary(s3_resource, summary_location))

            # Check that the confusion matrix is available.
            assert 'ConfusionMatrix' in summary['AggregatedEvaluationResults'], \
                "Confusion matrix not found in summary. Is the model a classification model?"

            # Show the confusion matrix.
            show_confusion_matrix(summary)
            print("Done")

        except ClientError as err:
            logger.exception("Problem showing confusion matrix: %s", err)
            print(f"Problem describing model: {err}")

        except AssertionError as err:
            logger.exception("Error: %s.\n", err)
            print(f"Error: {err}\n")


    if __name__ == "__main__":
        main()