Instructions pour l'utilisation de points de contrôle avec le parallélisme de tenseur - Amazon SageMaker

Instructions pour l'utilisation de points de contrôle avec le parallélisme de tenseur

La bibliothèque de parallélisme de modèles SageMaker prend en charge l'enregistrement de points de contrôle partiels ou complets avec le parallélisme de tenseur. Le guide suivant explique comment modifier le script pour enregistrer et charger un point de contrôle lors de l'utilisation du parallélisme de tenseur.

  1. Préparez un objet de modèle et enveloppez-le avec la fonction wrapper smp.DistributedModel() de la bibliothèque.

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Préparez un optimiseur pour le modèle. Un ensemble de paramètres de modèle est un argument itérable requis par les fonctions de l'optimiseur. Pour préparer un ensemble de paramètres de modèle, vous devez traiter model.parameters() pour attribuer des ID uniques à des paramètres de modèle individuels.

    Si plusieurs paramètres partagent le même ID dans l'argument itérable de paramètres de modèle, le chargement de l'état de l'optimiseur à points de contrôle échoue. Pour créer un argument itérable de paramètres de modèle avec des ID uniques pour l'optimiseur, consultez le code suivant :

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Enveloppez l'optimiseur à l'aide de la fonction wrapper smp.DistributedOptimizer() de la bibliothèque.

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Enregistrez le modèle et l'état de l'optimiseur à l'aide de smp.save(). Selon la manière dont vous souhaitez enregistrer les points de contrôle, choisissez l'une des deux options suivantes :

    • Option 1 : enregistrez un modèle partiel sur chaque mp_rank pour un MP_GROUP unique.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Avec le parallélisme de tenseur, la bibliothèque enregistre les fichiers à points de contrôle nommés selon le format suivant : checkpoint.pt_{pp_rank}_{tp_rank}.

      Note

      Avec le parallélisme de tenseur, assurez-vous de définir l'instruction if comme if smp.rdp_rank() == 0 et non comme if smp.dp_rank() == 0. Si l'état de l'optimiseur est partitionné avec un parallélisme de tenseur, tous les rangs parallèles aux données réduites doivent enregistrer leur propre partition de l'état de l'optimiseur. L'utilisation d'une mauvaise instruction if pour les points de contrôle peut entraîner un blocage de la tâche d'entraînement. Pour plus d'informations sur l'utilisation de if smp.dp_rank() == 0 sans parallélisme de tenseur, consultez la section General Instruction for Saving and Loading dans la documentation du kit SDK Python SageMaker.

    • Option 2 : enregistrez le modèle complet.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      Note

      Tenez compte des points suivants pour la création de points de contrôle complets :

      • Si vous définissez gather_to_rank0=True, tous les rangs autres que 0 renvoient des dictionnaires vides.

      • Pour la création de points de contrôle complets, vous ne pouvez créer des points de contrôle que pour le modèle. La création de points de contrôle complets des états de l'optimiseur n'est actuellement pas prise en charge.

      • Le modèle complet doit uniquement être enregistré sur smp.rank() == 0.

  5. Chargez les points de contrôle à l'aide de smp.load(). Selon la manière dont vous avez enregistré les points de contrôle à l'étape précédente, choisissez l'une des deux options suivantes :

    • Option 1 : chargez les points de contrôle partiels.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Vous pouvez définir same_partition_load=True dans model.load_state_dict() pour une charge plus rapide si vous savez que la partition ne changera pas.

    • Option 2 : chargez les points de contrôle complets.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      La condition if smp.rdp_rank() == 0 n'est pas nécessaire, mais elle peut aider à éviter un chargement redondant entre différents MP_GROUP. La création de points de contrôle complets du dictionnaire des états de l'optimiseur n'est actuellement pas prise en charge avec le parallélisme de tenseur.