本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
您只需使用几行代码即可对内置算法或预训练模型进行微调 SageMaker Python SDK。
-
首先,在内置算法与预训练模型表
中找到所选模型的模型 ID。 -
使用模型 ID,将您的训练作业定义为 JumpStart估算器。
from sagemaker.jumpstart.estimator import JumpStartEstimator model_id =
"huggingface-textgeneration1-gpt-j-6b"
estimator = JumpStartEstimator(model_id=model_id) -
在模型上运行
estimator.fit()
,指向用于微调的训练数据。estimator.fit( {"train":
training_dataset_s3_path
, "validation":validation_dataset_s3_path
} ) -
然后,使用
deploy
方法自动部署模型进行推理。在这个例子中,我们使用来自的 GPT-J 6B 模型 Hugging Face.predictor = estimator.deploy()
-
然后,您就可以使用
predict
方法对已部署的模型进行推理。question =
"What is Southern California often abbreviated as?"
response = predictor.predict(question) print(response)
注意
此示例使用基础模型 GPT-J 6B,该模型适用于各种文本生成使用场景,包括问题解答、命名实体识别、摘要等。有关模型使用场景的更多信息,请参阅 可用的基础模型。
创建 JumpStartEstimator
时,您可以选择指定模型版本或实例类型。有关该JumpStartEstimator
类及其参数的更多信息,请参见JumpStartEstimator
检查默认实例类型
在使用 JumpStartEstimator
类对预训练模型进行微调时,您可以选择包含特定的模型版本或实例类型。所有 JumpStart 模型都有默认的实例类型。使用以下代码读取默认训练实例类型:
from sagemaker import instance_types
instance_type = instance_types.retrieve_default(
model_id=model_id,
model_version=model_version,
scope="training"
)
print(instance_type)
您可以使用instance_types.retrieve()
方法查看给定 JumpStart 模型的所有支持的实例类型。
检查默认超参数
要检查用于训练的默认超参数,可以使用 hyperparameters
类中的 retrieve_default()
方法。
from sagemaker import hyperparameters
my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)
print(my_hyperparameters)
# Optionally override default hyperparameters for fine-tuning
my_hyperparameters["epoch"] = "3"
my_hyperparameters["per_device_train_batch_size"] = "4"
# Optionally validate hyperparameters for the model
hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)
有关可用超参数的更多信息,请参阅 通常支持的微调超参数。
检查默认指标定义
您还可以检查默认指标定义:
print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))