Guarde y cargue los puntos de control mientras usa SMP - Amazon SageMaker

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

Guarde y cargue los puntos de control mientras usa SMP

La biblioteca SMP admite PyTorch las API para los puntos de control y proporciona API que ayudan a realizar los controles correctamente mientras se utiliza la biblioteca SMP.

PyTorch El FSDP admite tres tipos de puntos de control: completos, fragmentados y locales. Estos sirven para diferentes propósitos. Lo ideal sería utilizar un punto de control completo solo al exportar el modelo una vez finalizado el entrenamiento, ya que generar un punto de control completo es caro. El punto de control fragmentado es el enfoque recomendado para guardar y cargar los puntos de control durante el entrenamiento. Al utilizar puntos de control fragmentados, también puedes cambiar el tamaño del clúster al reanudar el entrenamiento. Los puntos de control locales son más restrictivos. Con los puntos de control locales, es necesario reanudar el entrenamiento con el mismo número de GPU y, actualmente, no se admite el paralelismo tensorial con SMP. Tenga en cuenta que los puntos de control del FSDP requieren la escritura en un sistema de archivos de red compartido, como FSx.

Puntos de control compartidos

El siguiente procedimiento describe lo que debe hacer para adaptar su guion de entrenamiento a fin de guardar y cargar puntos de control fragmentados con o sin la función de paralelismo tensorial SMP.

  1. torch.sagemakerImporte el paquete SMP.

    import torch.sagemaker as tsm
  2. Configure variables auxiliares para guardar y cargar los puntos de control.

    1. Establezca un rango de coordinador para realizar operaciones colectivas comunicativas, tales como. AllReduce

      coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group))
    2. Con las torch.sagemaker.state enumeraciones, configure el rango de acción para determinar si se debe permitir que los rangos participen en los puntos de control. Y añada una sentencia if para guardar los puntos de control en función del uso del paralelismo tensorial SMP v2.

      action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = ""
  3. Siga usando las API de puntos de control del PyTorch FSDP tal como están.

El siguiente ejemplo de código muestra un script de entrenamiento completo del PyTorch FSDP con las API de puntos de control del FSDP.

import torch.distributed as dist from torch.distributed.checkpoint.optimizer import ( load_sharded_optimizer_state_dict ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, StateDictType ) import torch.sagemaker as tsm sharding_strategy, state_dict_type = ..., ... global_rank = dist.get_rank() # 0. Auxiliary variables to save and load checkpoints. # Used when performing comm collectives such as allreduce. coordinator_rank: int = min(dist.get_process_group_ranks(model.process_group)) # To determine whether to take part in checkpointing. action_rank: bool = global_rank < (tsm.state.hybrid_shard_degree * tsm.state.tp_size) if tsm.state.tp_size > 1: # Tensor parallel groups will have their own sub directories. sub_dir = f"tp{tsm.state.tp_size}-{tsm.state.tp_rank}" else: sub_dir = "" # 1. Save checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # Save from one single replication group. if action_rank: dist.checkpoint.save_state_dict( state_dict=state_dict, storage_writer=dist.checkpoint.FileSystemWriter(os.path.join(save_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) # 2. Load checkpoints. with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 2.1 Load model and everything else except the optimizer. state_dict = { # All states except optimizer state can be passed here. "model": model.state_dict() } dist.checkpoint.load_state_dict( state_dict=state_dict, storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # Potentially process more customized and non-optimizer dict states. # 2.2 Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=dist.checkpoint.FileSystemReader(os.path.join(load_dir, sub_dir)), process_group=model.process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group, ) optimizer.load_state_dict(flattened_optimizer_state)

Puntos de control de modelo completo

Al final del entrenamiento, puede guardar un punto de control completo que combine todos los fragmentos de un modelo en un único archivo de puntos de control del modelo. La biblioteca SMP es totalmente compatible con la API de puntos de control del modelo PyTorch completo, por lo que no es necesario realizar ningún cambio.

Tenga en cuenta que si utiliza el SMPParalelismo de tensores, la biblioteca SMP transforma el modelo. Al comprobar el modelo completo en este caso, la biblioteca SMP vuelve a traducir el modelo al formato de punto de control Hugging Face Transformers de forma predeterminada.

En los casos en los que entrenes con el paralelismo tensorial SMP y desactives el proceso de traducción SMP, puedes usar el translate_on_save argumento de la PyTorch FullStateDictConfig API para activar o desactivar la traducción automática SMP según sea necesario. Por ejemplo, si te estás centrando en entrenar un modelo, no necesitas añadir el proceso de traducción, lo que supone una sobrecarga. En ese caso, le recomendamos que configuretranslate_on_save=False. Además, si planea seguir utilizando la traducción SMP del modelo para seguir formándose en el futuro, puede desactivarla para guardar la traducción SMP del modelo para usarla más adelante. Es necesario volver a traducir el modelo al formato de punto de control del modelo Hugging Face Transformers cuando termines el entrenamiento de tu modelo y lo utilices como inferencia.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Ten en cuenta que la opción FullStateDictConfig(rank0_only=True, offload_to_cpu=True) consiste en recopilar el modelo en la CPU del dispositivo de rango 0 para ahorrar memoria al entrenar modelos grandes.

Para volver a cargar el modelo para su inferencia, haga lo que se muestra en el siguiente ejemplo de código. Ten en cuenta que la clase AutoModelForCausalLM podría cambiar a otras clases de creación de factores en Hugging Face Transformers, por ejemplo, AutoModelForSeq2SeqLM según tu modelo. Para obtener más información, consulte la documentación de Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)