Ajuste - Amazon SageMaker AI

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

Ajuste

El afinamiento es un proceso de entrenamiento continuo de modelos previamente entrenados para mejorar el rendimiento en casos de uso específicos.

Es fácil ajustar los modelos pequeños que caben completamente en una sola GPU, o aquellos que caben completamente en 8 copias del modelo. CPUs No se requiere ningún cambio especial con respecto al entrenamiento de FSDP ordinario. En el caso de modelos de mayor tamaño, hay que considerar la posibilidad de utilizar la función de inicialización diferida de parámetros, que puede resultar complicada.

Para solucionar este problema, la biblioteca de SMP carga el modelo completo en uno de los rangos, mientras que el resto crea modelos con ponderaciones vacías en un metadispositivo. A continuación, el PyTorch FSDP inicializa las ponderaciones de los rangos distintos de cero mediante la init_weights función y sincroniza las ponderaciones de todas las filas con las ponderaciones de la fila 0 si se establece en. sync_module_states True En el siguiente fragmento de código se muestra cómo debe configurarlo en su script de entrenamiento.

import torch.distributed as dist from transformers import AutoModelForCasalLM from accelerate import init_empty_weights from torch.sagemaker.delayed_param import DelayedParamIniter if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(..., low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) delayed_initer = DelayedParamIniter(model) model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None )

Afinamiento de un modelo de Hugging Face Transformer con paralelismo de tensores de SMP

En esta sección se analiza la carga de modelos de transformador para dos casos de uso: afinamiento de los modelos de transformador pequeños y ajuste de modelos de transformador grandes. Para modelos más pequeños sin demorar la inicialización de los parámetros, empaquete el modelo con la API antes de empaquetarlo con el torch.sagemaker.transform FSDP. PyTorch

import functools from transformers import AutoModelForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.sagemaker import transform model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True) # Transform model while loading state dictionary from rank 0. tp_model = transform(model, load_state_dict_from_rank0=True) # Wrap with FSDP. model = FSDP( tp_model, ... sync_module_states=True, )

En el caso de los modelos mayores, el método anterior hace que se agote la memoria de la CPU. Le recomendamos que utilice la inicialización diferida de parámetros para evitar estos problemas de memoria de la CPU. En este caso, puede aplicar la API torch.sagemaker.transform y la API torch.sagemaker.delayed_param.DelayedParamIniter como se muestra en el siguiente código de ejemplo.

from transformers import AutoModelForCausalLM from torch.sagemaker import transform from torch.sagemaker.delayed_param import DelayedParamIniter # Create one instance of model without delayed param # on CPU, on one rank. if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(...,low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) # Transform model while loading state dictionary from rank 0 model = transform(model, load_state_dict_from_rank0=True) if dist.get_rank() != 0: # For fine-tuning, delayed parameter on non-zero ranks delayed_initer = DelayedParamIniter(model) else: delayed_initer = None with ( delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext() ): # Wrap the model with FSDP model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None )