使用 AWS SDK 训练自定义 Amazon Comprehend 分类器并对文档进行分类 - Amazon Comprehend

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 AWS SDK 训练自定义 Amazon Comprehend 分类器并对文档进行分类

以下代码示例展示了如何:

  • 创建 Amazon Comprehend 多标签分类器。

  • 在示例数据上训练分类器。

  • 对第二组数据运行分类任务。

  • 从 Amazon S3 提取任务输出数据。

Python
SDK for Python (Boto3)
注意

在 GitHub 上查看更多内容。在 AWS 代码示例存储库中查找完整示例,并了解如何进行设置和运行。

创建一个包装类来调用 Amazon Comprehend 文档分类器操作。

class ComprehendClassifier:
    """Encapsulates an Amazon Comprehend custom classifier."""

    def __init__(self, comprehend_client):
        """
        :param comprehend_client: A Boto3 Comprehend client.
        """
        self.comprehend_client = comprehend_client
        self.classifier_arn = None

    @staticmethod
    def _collect_pages(operation, result_key):
        """
        Calls a list operation repeatedly, following `NextToken`, and gathers
        the items found under `result_key` from every page.

        :param operation: A callable that accepts an optional `NextToken` keyword
                          argument and returns a response dictionary.
        :param result_key: The response key that holds the list of items.
        :return: A single list with the items from all pages, in page order.
        """
        response = operation()
        items = list(response.get(result_key, []))
        # A single call returns only one page of results; keep requesting pages
        # until the service stops handing back a continuation token.
        while response.get("NextToken"):
            response = operation(NextToken=response["NextToken"])
            items.extend(response.get(result_key, []))
        return items

    def create(
        self,
        name,
        language_code,
        training_bucket,
        training_key,
        data_access_role_arn,
        mode,
    ):
        """
        Creates a custom classifier. After the classifier is created, it immediately
        starts training on the data found in the specified Amazon S3 bucket. Training
        can take 30 minutes or longer. The `describe_document_classifier` function
        can be used to get training status and returns a status of TRAINED when the
        classifier is ready to use.

        :param name: The name of the classifier.
        :param language_code: The language the classifier can operate on.
        :param training_bucket: The Amazon S3 bucket that contains the training data.
        :param training_key: The prefix used to find training data in the training
                             bucket. If multiple objects have the same prefix, all
                             of them are used.
        :param data_access_role_arn: The Amazon Resource Name (ARN) of a role that
                                     grants Comprehend permission to read from the
                                     training bucket.
        :param mode: The classifier mode. Its `value` attribute is sent as the
                     `Mode` field of the request.
        :return: The ARN of the newly created classifier.
        :raises ClientError: If the service rejects the request. The error is
                             logged before it is re-raised.
        """
        try:
            response = self.comprehend_client.create_document_classifier(
                DocumentClassifierName=name,
                LanguageCode=language_code,
                InputDataConfig={"S3Uri": f"s3://{training_bucket}/{training_key}"},
                DataAccessRoleArn=data_access_role_arn,
                Mode=mode.value,
            )
            self.classifier_arn = response["DocumentClassifierArn"]
            logger.info("Started classifier creation. Arn is: %s.", self.classifier_arn)
        except ClientError:
            logger.exception("Couldn't create classifier %s.", name)
            raise
        else:
            return self.classifier_arn

    def describe(self, classifier_arn=None):
        """
        Gets metadata about a custom classifier, including its current status.

        :param classifier_arn: The ARN of the classifier to look up. When given,
                               it also replaces the ARN stored on this object.
        :return: Metadata about the classifier.
        :raises ClientError: If the service rejects the request.
        """
        if classifier_arn is not None:
            self.classifier_arn = classifier_arn
        try:
            response = self.comprehend_client.describe_document_classifier(
                DocumentClassifierArn=self.classifier_arn
            )
            classifier = response["DocumentClassifierProperties"]
            logger.info("Got classifier %s.", self.classifier_arn)
        except ClientError:
            logger.exception("Couldn't get classifier %s.", self.classifier_arn)
            raise
        else:
            return classifier

    def list(self):
        """
        Lists custom classifiers for the current account. All result pages are
        retrieved, not only the first one.

        :return: The list of classifiers.
        :raises ClientError: If the service rejects a request.
        """
        try:
            classifiers = self._collect_pages(
                self.comprehend_client.list_document_classifiers,
                "DocumentClassifierPropertiesList",
            )
            logger.info("Got %s classifiers.", len(classifiers))
        except ClientError:
            logger.exception(
                "Couldn't get classifiers.",
            )
            raise
        else:
            return classifiers

    def delete(self):
        """
        Deletes the classifier and clears the ARN stored on this object.

        :raises ClientError: If the service rejects the request.
        """
        try:
            self.comprehend_client.delete_document_classifier(
                DocumentClassifierArn=self.classifier_arn
            )
            logger.info("Deleted classifier %s.", self.classifier_arn)
            self.classifier_arn = None
        except ClientError:
            logger.exception("Couldn't delete classifier %s.", self.classifier_arn)
            raise

    def start_job(
        self,
        job_name,
        input_bucket,
        input_key,
        input_format,
        output_bucket,
        output_key,
        data_access_role_arn,
    ):
        """
        Starts a classification job. The classifier must be trained or the job
        will fail. Input is read from the specified Amazon S3 input bucket and
        written to the specified output bucket. Output data is stored in a tar
        archive compressed in gzip format. The job runs asynchronously, so you can
        call `describe_document_classification_job` to get job status until it
        returns a status of SUCCEEDED.

        :param job_name: The name of the job.
        :param input_bucket: The Amazon S3 bucket that contains input data.
        :param input_key: The prefix used to find input data in the input
                          bucket. If multiple objects have the same prefix, all
                          of them are used.
        :param input_format: The format of the input data, either one document per
                             file or one document per line. Its `value` attribute
                             is sent as the `InputFormat` field of the request.
        :param output_bucket: The Amazon S3 bucket where output data is written.
        :param output_key: The prefix prepended to the output data.
        :param data_access_role_arn: The Amazon Resource Name (ARN) of a role that
                                     grants Comprehend permission to read from the
                                     input bucket and write to the output bucket.
        :return: Information about the job, including the job ID.
        :raises ClientError: If the service rejects the request.
        """
        try:
            response = self.comprehend_client.start_document_classification_job(
                DocumentClassifierArn=self.classifier_arn,
                JobName=job_name,
                InputDataConfig={
                    "S3Uri": f"s3://{input_bucket}/{input_key}",
                    "InputFormat": input_format.value,
                },
                OutputDataConfig={"S3Uri": f"s3://{output_bucket}/{output_key}"},
                DataAccessRoleArn=data_access_role_arn,
            )
            logger.info(
                "Document classification job %s is %s.", job_name, response["JobStatus"]
            )
        except ClientError:
            logger.exception("Couldn't start classification job %s.", job_name)
            raise
        else:
            return response

    def describe_job(self, job_id):
        """
        Gets metadata about a classification job.

        :param job_id: The ID of the job to look up.
        :return: Metadata about the job.
        :raises ClientError: If the service rejects the request.
        """
        try:
            response = self.comprehend_client.describe_document_classification_job(
                JobId=job_id
            )
            job = response["DocumentClassificationJobProperties"]
            logger.info("Got classification job %s.", job["JobName"])
        except ClientError:
            logger.exception("Couldn't get classification job %s.", job_id)
            raise
        else:
            return job

    def list_jobs(self):
        """
        Lists the classification jobs for the current account. All result pages
        are retrieved, not only the first one.

        :return: The list of jobs.
        :raises ClientError: If the service rejects a request.
        """
        try:
            jobs = self._collect_pages(
                self.comprehend_client.list_document_classification_jobs,
                "DocumentClassificationJobPropertiesList",
            )
            logger.info("Got %s document classification jobs.", len(jobs))
        except ClientError:
            logger.exception(
                "Couldn't get document classification jobs.",
            )
            raise
        else:
            return jobs

创建一个帮助运行该场景的类。

class ClassifierDemo:
    """
    Encapsulates functions used to run the demonstration.
    """

    # Upper bound, in seconds, on how long a single GitHub request may block.
    # Without a timeout a stalled connection would hang the demo indefinitely.
    _REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, demo_resources):
        """
        :param demo_resources: A ComprehendDemoResources class that manages resources
                               for the demonstration.
        """
        self.demo_resources = demo_resources
        self.training_prefix = "training/"
        self.input_prefix = "input/"
        self.input_format = JobInputFormat.per_line
        self.output_prefix = "output/"

    def setup(self):
        """Creates AWS resources used by the demo."""
        self.demo_resources.setup("comprehend-classifier-demo")

    def cleanup(self):
        """Deletes AWS resources used by the demo."""
        self.demo_resources.cleanup()

    @staticmethod
    def _sanitize_text(text):
        """
        Removes characters that cause errors for the document parser.

        Carriage returns and newlines become spaces and commas become semicolons,
        so each document stays on one line and contains no field separators.

        :param text: The text to clean. May be None or empty, which happens for
                     GitHub issues that have no description.
        :return: The cleaned text, or an empty string when there is no text.
        """
        if not text:
            return ""
        return text.replace("\r", " ").replace("\n", " ").replace(",", ";")

    @staticmethod
    def _get_issues(query, issue_count):
        """
        Gets issues from GitHub using the specified query parameters.

        :param query: The query string used to request issues from the GitHub API.
        :param issue_count: The number of issues to retrieve.
        :return: The list of issues retrieved from GitHub. Each issue is a dict with
                 a sanitized `title`, a sanitized `body`, and a set of `labels`.
                 The list is empty when GitHub responds with a non-200 status.
        """
        issues = []
        logger.info("Requesting issues from %s?%s.", GITHUB_SEARCH_URL, query)
        response = requests.get(
            f"{GITHUB_SEARCH_URL}?{query}&per_page={issue_count}",
            timeout=ClassifierDemo._REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code == 200:
            issue_page = response.json()["items"]
            logger.info("Got %s issues.", len(issue_page))
            issues = [
                {
                    "title": ClassifierDemo._sanitize_text(issue["title"]),
                    # The body is null for issues without a description;
                    # _sanitize_text turns that into an empty string.
                    "body": ClassifierDemo._sanitize_text(issue.get("body")),
                    "labels": {label["name"] for label in issue["labels"]},
                }
                for issue in issue_page
            ]
        else:
            logger.error(
                "GitHub returned error code %s with message %s.",
                response.status_code,
                response.json(),
            )
        logger.info("Found %s issues.", len(issues))
        return issues

    def get_training_issues(self, training_labels):
        """
        Gets issues used for training the custom classifier. Training issues are
        closed issues from the Boto3 repo that have known labels. Comprehend
        requires a minimum of ten training issues per label.

        :param training_labels: The issue labels to use for training.
        :return: The set of issues used for training. Each issue keeps only the
                 labels that are also in `training_labels`.
        """
        issues = []
        per_label_count = 15
        for label in training_labels:
            issues += self._get_issues(
                f"q=type:issue+repo:boto/boto3+state:closed+label:{label}",
                per_label_count,
            )
        for issue in issues:
            # Drop labels the classifier is not being trained to recognize.
            issue["labels"] = issue["labels"].intersection(training_labels)
        return issues

    def get_input_issues(self, training_labels):
        """
        Gets input issues from GitHub. For demonstration purposes, input issues
        are open issues from the Boto3 repo with known labels, though in practice
        any issue could be submitted to the classifier for labeling.

        :param training_labels: The set of labels to query for.
        :return: The set of issues used for input.
        """
        issues = []
        per_label_count = 5
        for label in training_labels:
            issues += self._get_issues(
                f"q=type:issue+repo:boto/boto3+state:open+label:{label}",
                per_label_count,
            )
        return issues

    def upload_issue_data(self, issues, training=False):
        """
        Uploads issue data to an Amazon S3 bucket, either for training or for input.
        The data is first put into the format expected by Comprehend. For training,
        the set of pipe-delimited labels is prepended to each document. For
        input, labels are not sent.

        :param issues: The set of issues to upload to Amazon S3.
        :param training: Indicates whether the issue data is used for training or
                         input.
        :raises ClientError: If the upload fails. The error is logged before it is
                             re-raised.
        """
        try:
            obj_key = (
                self.training_prefix if training else self.input_prefix
            ) + "issues.txt"
            if training:
                # Labels are stored in a set; sort them so the uploaded file is
                # identical from one run to the next for the same issues.
                issue_strings = [
                    f"{'|'.join(sorted(issue['labels']))},"
                    f"{issue['title']} {issue['body']}"
                    for issue in issues
                ]
            else:
                issue_strings = [
                    f"{issue['title']} {issue['body']}" for issue in issues
                ]
            issue_bytes = BytesIO("\n".join(issue_strings).encode("utf-8"))
            self.demo_resources.bucket.upload_fileobj(issue_bytes, obj_key)
            logger.info(
                "Uploaded data as %s to bucket %s.",
                obj_key,
                self.demo_resources.bucket.name,
            )
        except ClientError:
            logger.exception(
                "Couldn't upload data to bucket %s.", self.demo_resources.bucket.name
            )
            raise

    def extract_job_output(self, job):
        """Extracts job output from Amazon S3."""
        return self.demo_resources.extract_job_output(job)

    @staticmethod
    def reconcile_job_output(input_issues, output_dict, score_threshold=0.3):
        """
        Reconciles job output with the list of input issues. Because the input
        issues have known labels, these can be compared with the labels added by
        the classifier to judge the accuracy of the output.

        :param input_issues: The list of issues used as input.
        :param output_dict: The dictionary of data that is output by the classifier.
        :param score_threshold: A label assigned by the classifier is reported only
                                when its score is greater than this value.
        :return: The list of reconciled input and output data.
        """
        reconciled = []
        for archive in output_dict.values():
            for line in archive["data"]:
                # The output line number is used as the index into input_issues.
                in_line = int(line["Line"])
                in_labels = input_issues[in_line]["labels"]
                out_labels = {
                    label["Name"]
                    for label in line["Labels"]
                    if float(label["Score"]) > score_threshold
                }
                reconciled.append(
                    f"{line['File']}, line {in_line} has labels {in_labels}.\n"
                    f"\tClassifier assigned {out_labels}."
                )
        logger.info("Reconciled input and output labels.")
        return reconciled

使用一组带有已知标签的 GitHub 问题训练分类器,然后将第二组 GitHub 问题发送给分类器,以便为其添加标签。

def usage_demo():
    """
    Runs the custom document classifier demonstration from start to finish.

    The demo sets up storage and security resources, trains a multi-label
    classifier on GitHub issues, runs a classification job on a second set of
    issues, prints the reconciled results, optionally deletes the classifier,
    and finally cleans up the resources it created.
    """
    banner = "-" * 88
    print(banner)
    print("Welcome to the Amazon Comprehend custom document classifier demo!")
    print(banner)

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # Build the helper objects that wrap AWS resources and Comprehend actions.
    demo = ClassifierDemo(
        ComprehendDemoResources(boto3.resource("s3"), boto3.resource("iam"))
    )
    classifier = ComprehendClassifier(boto3.client("comprehend"))
    trained_waiter = ClassifierTrainedWaiter(classifier.comprehend_client)
    labels = {"bug", "feature-request", "dynamodb", "s3"}

    print("Setting up storage and security resources needed for the demo.")
    demo.setup()
    resources = demo.demo_resources

    # Train the classifier on labeled issues.
    print("Getting training data from GitHub and uploading it to Amazon S3.")
    demo.upload_issue_data(demo.get_training_issues(labels), training=True)

    classifier_name = "doc-example-classifier"
    print(f"Creating document classifier {classifier_name}.")
    classifier.create(
        name=classifier_name,
        language_code="en",
        training_bucket=resources.bucket.name,
        training_key=demo.training_prefix,
        data_access_role_arn=resources.data_access_role.arn,
        mode=ClassifierMode.multi_label,
    )
    print(
        f"Waiting until {classifier_name} is trained. This typically takes "
        f"30–40 minutes."
    )
    trained_waiter.wait(classifier.classifier_arn)
    print(f"Classifier {classifier_name} is trained:")
    pprint(classifier.describe())

    # Classify a second set of issues with the trained classifier.
    print("Getting input data from GitHub and uploading it to Amazon S3.")
    input_issues = demo.get_input_issues(labels)
    demo.upload_issue_data(input_issues)

    print("Starting classification job on input data.")
    started_job = classifier.start_job(
        job_name="issue_classification_job",
        input_bucket=resources.bucket.name,
        input_key=demo.input_prefix,
        input_format=demo.input_format,
        output_bucket=resources.bucket.name,
        output_key=demo.output_prefix,
        data_access_role_arn=resources.data_access_role.arn,
    )
    job_id = started_job["JobId"]
    print(f"Waiting for job {job_id} to complete.")
    JobCompleteWaiter(classifier.comprehend_client).wait(job_id)

    job = classifier.describe_job(job_id)
    print(f"Job {job['JobId']} complete:")
    pprint(job)

    # Fetch the results and compare them with the labels GitHub already has.
    print(
        f"Getting job output data from Amazon S3: "
        f"{job['OutputDataConfig']['S3Uri']}."
    )
    job_output = demo.extract_job_output(job)
    print("Job output:")
    pprint(job_output)

    print("Reconciling job output with labels from GitHub:")
    reconciled = demo.reconcile_job_output(input_issues, job_output)
    print("\n".join(reconciled))

    # Deleting the classifier is optional; other demo resources are always removed.
    reply = input(f"Do you want to delete the classifier {classifier_name} (y/n)? ")
    if reply.lower() == "y":
        print(f"Deleting {classifier_name}.")
        classifier.delete()

    print("Cleaning up resources created for the demo.")
    demo.cleanup()
    print("Thanks for watching!")
    print(banner)

有关 AWS SDK 开发者指南和代码示例的完整列表,请参阅将 Amazon Comprehend 与 AWS SDK 配合使用。本主题还包括有关入门的信息以及有关先前的 SDK 版本的详细信息。