Créer une tâche de référence de qualité des modèles - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Créer une tâche de référence de qualité des modèles

Créez une tâche de référence qui compare les prédictions de votre modèle aux étiquettes Ground Truth d'un jeu de données de référence que vous avez stocké dans Amazon S3. En règle générale, vous utilisez un jeu de données d'entraînement comme jeu de données de référence. La tâche de référence calcule les métriques pour le modèle et suggère des contraintes à utiliser pour contrôler l'écart dans la qualité du modèle.

Pour créer une tâche de référence, vous devez disposer d'un jeu de données contenant des prédictions de votre modèle et des étiquettes Ground Truth de vos données.

Pour créer une tâche de référence, utilisez la ModelQualityMonitor classe fournie par le SDK SageMaker Python et effectuez les étapes suivantes.

Pour créer une tâche de référence de qualité de modèle
  1. Tout d'abord, créez une instance de la classe ModelQualityMonitor. L'exemple de code suivant vous montre comment procéder.

    from sagemaker import get_execution_role, session, Session from sagemaker.model_monitor import ModelQualityMonitor role = get_execution_role() session = Session() model_quality_monitor = ModelQualityMonitor( role=role, instance_count=1, instance_type='ml.m5.xlarge', volume_size_in_gb=20, max_runtime_in_seconds=1800, sagemaker_session=session )
  2. Maintenant, appelez la méthode suggest_baseline de l'objet ModelQualityMonitor pour exécuter une tâche de référence. L'extrait de code suivant suppose que le jeu de données de référence dont vous disposez contient des prédictions et des étiquettes stockées dans Amazon S3.

    baseline_job_name = "MyBaseLineJob" job = model_quality_monitor.suggest_baseline( job_name=baseline_job_name, baseline_dataset=baseline_dataset_uri, # The S3 location of the validation dataset. dataset_format=DatasetFormat.csv(header=True), output_s3_uri = baseline_results_uri, # The S3 location to store the results. problem_type='BinaryClassification', inference_attribute= "prediction", # The column in the dataset that contains predictions. probability_attribute= "probability", # The column in the dataset that contains probabilities. ground_truth_attribute= "label" # The column in the dataset that contains ground truth labels. ) job.wait(logs=False)
  3. Une fois la tâche de référence terminée, les contraintes générées par la tâche s'affichent. Tout d'abord, obtenez les résultats de la tâche de référence en appelant la méthode latest_baselining_job de l'objet ModelQualityMonitor.

    baseline_job = model_quality_monitor.latest_baselining_job
  4. La tâche de référence suggère des contraintes, qui sont des seuils pour les métriques mesurées par Model Monitor. Si une métrique dépasse le seuil suggéré, Model Monitor signale une violation. Pour afficher les contraintes générées par la tâche de référence, appelez la méthode suggested_constraints de la tâche de référence. L'extrait de code suivant charge les contraintes pour un modèle de classification binaire dans un dataframe Pandas.

    import pandas as pd pd.DataFrame(baseline_job.suggested_constraints().body_dict["binary_classification_constraints"]).T

    Nous vous recommandons d'afficher les contraintes générées et de les modifier si nécessaire avant de les utiliser pour la surveillance. Par exemple, si une contrainte est trop agressive, vous pourrez obtenir un nombre excessif d'alertes de violation.

    Si votre contrainte contient des nombres exprimés en notation scientifique, vous devrez les convertir en nombres à virgule flottante. L'exemple de script de prétraitement Python suivant montre comment convertir des nombres en notation scientifique en nombres à virgule flottante.

    import csv def fix_scientific_notation(col): try: return format(float(col), "f") except: return col def preprocess_handler(csv_line): reader = csv.reader([csv_line]) csv_record = next(reader) #skip baseline header, change HEADER_NAME to the first column's name if csv_record[0] == “HEADER_NAME”: return [] return { str(i).zfill(20) : fix_scientific_notation(d) for i, d in enumerate(csv_record)}

    Vous pouvez ajouter votre script de prétraitement à une base de référence ou à un calendrier de surveillance en tant que record_preprocessor_script, tel que défini dans la documentation de Model Monitor.

  5. Lorsque les contraintes vous conviennent, transmettez-les comme paramètre constraints dans le programme de surveillance que vous créez. Pour plus d’informations, consultez Planifier des tâches de surveillance de la qualité du modèle.

Les contraintes de référence suggérées sont contenues dans le fichier constraints.json à l'emplacement que vous spécifiez avec output_s3_uri. Pour de plus amples informations sur le schéma de ce fichier, veuillez consulter Schéma des contraintes (fichier constraints.json).