모델의 오차 행렬 보기 - Rekognition

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

모델의 오차 행렬 보기

오차 행렬을 사용하면 모델이 모델의 다른 레이블과 혼동하는 레이블을 확인할 수 있습니다. 오차 행렬을 사용하면 모델 개선에 집중할 수 있습니다.

모델 평가 중에 Amazon Rekognition Custom Labels는 테스트 이미지를 사용하여 잘못 식별된(혼동된) 레이블을 식별함으로써 오차 행렬을 생성합니다. Amazon Rekognition Custom Labels는 분류 모델에 오차 행렬만 생성합니다. Amazon Rekognition Custom Labels가 모델 교육 중에 생성하는 요약 파일에서 분류 행렬에 액세스할 수 있습니다. Amazon Rekognition Custom Labels 콘솔에서는 오차 행렬을 볼 수 없습니다.

오차 행렬 사용

다음 표는 Rooms 이미지 분류 예제 프로젝트의 오차 행렬입니다. 열 제목은 테스트 이미지에 할당된 레이블(실측 정보 레이블)입니다. 행 제목은 모델이 테스트 이미지에 대해 예측하는 레이블입니다. 각 셀은 실측 정보 레이블(열)이 되어야 하는 레이블(행)에 대한 예측의 백분율입니다. 예를 들어, 욕실에 대한 예측의 67%가 욕실로 올바르게 레이블 지정되었고, 욕실의 33%가 주방으로 잘못 레이블 지정되었을 수 있습니다. 성능이 높은 모델은 예측 레이블이 실측 정보 레이블과 일치하여 셀 값이 높습니다. 이를 첫 번째 예측 레이블부터 마지막 예측 레이블 및 실측 정보 레이블까지 대각선으로 볼 수 있습니다. 셀 값이 0인 경우 셀의 실측 정보 레이블이어야 하는 셀의 예측 레이블에 대한 예측은 수행되지 않았다는 뜻입니다.

참고

모델은 비결정적이므로 Rooms 프로젝트를 훈련하여 얻은 오차 행렬 셀 값은 다음 표와 다를 수 있습니다.

오차 행렬은 집중해야 할 영역을 식별합니다. 예를 들어, 오차 행렬은 모델이 옷장을 침실과 혼동한 경우가 50%라는 것을 보여줍니다. 이 상황에서는 훈련 데이터 세트에 옷장과 침실 이미지를 더 추가해야 합니다. 또한 기존 옷장 및 침실 이미지에 레이블이 올바르게 지정되어 있는지도 확인해야 합니다. 이렇게 하면 모델이 두 레이블을 더 잘 구분할 수 있습니다. 데이터 세트에 이미지를 더 추가하려면 데이터 세트에 더 많은 이미지 추가 항목을 참조하세요.

오차 행렬도 유용하지만 다른 지표도 고려하는 것이 중요합니다. 예를 들어, 예측의 100%가 평면도 레이블을 올바르게 찾았는데, 이는 성능이 우수하다는 뜻입니다. 하지만 테스트 데이터 세트에는 평면도 레이블이 있는 이미지가 2개만 있습니다. 또한 거실 레이블이 지정된 11개의 이미지도 있습니다. 이러한 불균형은 훈련 데이터 세트(거실 이미지 13개, 옷장 이미지 2개)에도 있습니다. 더 정확하게 평가하려면 제대로 대표되지 않은 레이블의 이미지(이 예에서는 평면도)를 더 추가하여 훈련 데이터 세트와 테스트 데이터 세트의 균형을 맞추세요. 레이블당 테스트 이미지 수를 가져오려면 평가 지표 액세스(콘솔) 항목을 참조하세요.

실측 정보 레이블
예측 레이블 뒷마당 욕실 침실 옷장 진입로 평면도 앞마당 주방 거실 파티오
뒷마당 75% 0% 0% 0% 0% 0% 33% 0% 0% 0%
욕실 0% 67% 0% 0% 0% 0% 0% 0% 0% 0%
침실 0% 0% 82% 50% 0% 0% 0% 0% 9% 0%
옷장 0% 0% 0% 50% 0% 0% 0% 0% 0% 0%
진입로 0% 0% 0% 0% 33% 0% 0% 0% 0% 0%
평면도 0% 0% 0% 0% 0% 100% 0% 0% 0% 0%
앞마당 25% 0% 0% 0% 0% 0% 67% 0% 0% 0%
주방 0% 33% 0% 0% 0% 0% 0% 88% 0% 0%
거실 0% 0% 18% 0% 67% 0% 0% 12% 91% 33%
파티오 0% 0% 0% 0% 0% 0% 0% 0% 0% 67%

모델의 오차 행렬 가져오기

다음 코드는 DescribeProjectsDescribeProject버전 연산을 사용하여 모델의 요약 파일을 가져옵니다. 그런 다음 요약 파일을 사용하여 모델의 오차 행렬을 표시합니다.

모델의 오차 행렬 표시하기(SDK)
  1. 아직 설치하지 않았다면 및 AWS SDK를 설치하고 구성하세요. AWS CLI 자세한 정보는 4단계: 및 SDK 설정 AWS CLIAWS을 참조하세요.

  2. 다음 코드를 사용하여 모델의 오차 행렬을 표시할 수 있습니다. 다음 명령줄 인수를 제공하세요.

    • project_name: 사용하려는 프로젝트의 이름 Amazon Rekognition Custom Labels 콘솔의 프로젝트 페이지에서 프로젝트 이름을 가져올 수 있습니다.

    • version_name: 사용하려는 모델의 버전 Amazon Rekognition Custom Labels 콘솔의 프로젝트 세부 정보 페이지에서 버전 이름을 가져올 수 있습니다.

    # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 """ Purpose Shows how to display the confusion matrix for an Amazon Rekognition Custom labels image classification model. """ import json import argparse import logging import boto3 import pandas as pd from botocore.exceptions import ClientError logger = logging.getLogger(__name__) def get_model_summary_location(rek_client, project_name, version_name): """ Get the summary file location for a model. :param rek_client: A Boto3 Rekognition client. :param project_arn: The Amazon Resource Name (ARN) of the project that contains the model. :param model_arn: The Amazon Resource Name (ARN) of the model. :return: The location of the model summary file. """ try: logger.info( "Getting summary file for model %s in project %s.", version_name, project_name) summary_location = "" # Get the project ARN from the project name. response = rek_client.describe_projects(ProjectNames=[project_name]) assert len(response['ProjectDescriptions']) > 0, \ f"Project {project_name} not found." project_arn = response['ProjectDescriptions'][0]['ProjectArn'] # Get the summary file location for the model. describe_response = rek_client.describe_project_versions(ProjectArn=project_arn, VersionNames=[version_name]) assert len(describe_response['ProjectVersionDescriptions']) > 0, \ f"Model {version_name} not found." model=describe_response['ProjectVersionDescriptions'][0] evaluation_results=model['EvaluationResult'] summary_location=(f"s3://{evaluation_results['Summary']['S3Object']['Bucket']}" f"/{evaluation_results['Summary']['S3Object']['Name']}") return summary_location except ClientError as err: logger.exception( "Couldn't get summary file location: %s", err.response['Error']['Message']) raise def show_confusion_matrix(summary): """ Shows the confusion matrix for an Amazon Rekognition Custom Labels image classification model. :param summary: The summary file JSON object. """ pd.options.display.float_format = '{:.0%}'.format # Load the model summary JSON into a DataFrame. summary_df = pd.DataFrame( summary['AggregatedEvaluationResults']['ConfusionMatrix']) # Get the confusion matrix. confusion_matrix = summary_df.pivot_table(index='PredictedLabel', columns='GroundTruthLabel', fill_value=0.0).astype(float) # Display the confusion matrix. print(confusion_matrix) def get_summary(s3_resource, summary): """ Gets the summary file. : return: The summary file in bytes. """ try: summary_bucket, summary_key = summary.replace( "s3://", "").split("/", 1) bucket = s3_resource.Bucket(summary_bucket) obj = bucket.Object(summary_key) body = obj.get()['Body'].read() logger.info( "Got summary file '%s' from bucket '%s'.", obj.key, obj.bucket_name) except ClientError: logger.exception( "Couldn't get summary file '%s' from bucket '%s'.", obj.key, obj.bucket_name) raise else: return body def add_arguments(parser): """ Adds command line arguments to the parser. : param parser: The command line parser. """ parser.add_argument( "project_name", help="The ARN of the project in which the model resides." ) parser.add_argument( "version_name", help="The version of the model that you want to describe." ) def main(): """ Entry point for script. """ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") try: # Get the command line arguments. parser = argparse.ArgumentParser(usage=argparse.SUPPRESS) add_arguments(parser) args = parser.parse_args() print( f"Showing confusion matrix for: {args.version_name} for project {args.project_name}.") session = boto3.Session(profile_name='custom-labels-access') rekognition_client = session.client("rekognition") s3_resource = session.resource('s3') # Get the summary file for the model. summary_location = get_model_summary_location(rekognition_client, args.project_name, args.version_name ) summary = json.loads(get_summary(s3_resource, summary_location)) # Check that the confusion matrix is available. assert 'ConfusionMatrix' in summary['AggregatedEvaluationResults'], \ "Confusion matrix not found in summary. Is the model a classification model?" # Show the confusion matrix. show_confusion_matrix(summary) print("Done") except ClientError as err: logger.exception("Problem showing confusion matrix: %s", err) print(f"Problem describing model: {err}") except AssertionError as err: logger.exception( "Error: %s.\n", err) print( f"Error: {err}\n") if __name__ == "__main__": main()