选择您的 Cookie 首选项

我们使用必要 Cookie 和类似工具提供我们的网站和服务。我们使用性能 Cookie 收集匿名统计数据,以便我们可以了解客户如何使用我们的网站并进行改进。必要 Cookie 无法停用,但您可以单击“自定义”或“拒绝”来拒绝性能 Cookie。

如果您同意,AWS 和经批准的第三方还将使用 Cookie 提供有用的网站功能、记住您的首选项并显示相关内容,包括相关广告。要接受或拒绝所有非必要 Cookie,请单击“接受”或“拒绝”。要做出更详细的选择,请单击“自定义”。

使用 JumpStartEstimator 类微调公开可用的基础模型

聚焦模式
使用 JumpStartEstimator 类微调公开可用的基础模型 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

您只需使用几行代码即可对内置算法或预训练模型进行微调 SageMaker Python SDK。

  1. 首先,在内置算法与预训练模型表中找到所选模型的模型 ID。

  2. 使用模型 ID,将您的训练作业定义为 JumpStart估算器。

    from sagemaker.jumpstart.estimator import JumpStartEstimator model_id = "huggingface-textgeneration1-gpt-j-6b" estimator = JumpStartEstimator(model_id=model_id)
  3. 在模型上运行 estimator.fit(),指向用于微调的训练数据。

    estimator.fit( {"train": training_dataset_s3_path, "validation": validation_dataset_s3_path} )
  4. 然后,使用 deploy 方法自动部署模型进行推理。在这个例子中,我们使用来自的 GPT-J 6B 模型 Hugging Face.

    predictor = estimator.deploy()
  5. 然后,您就可以使用 predict 方法对已部署的模型进行推理。

    question = "What is Southern California often abbreviated as?" response = predictor.predict(question) print(response)
注意

此示例使用基础模型 GPT-J 6B,该模型适用于各种文本生成使用场景,包括问题解答、命名实体识别、摘要等。有关模型使用场景的更多信息,请参阅 可用的基础模型

创建 JumpStartEstimator 时,您可以选择指定模型版本或实例类型。有关该JumpStartEstimator 类及其参数的更多信息,请参见JumpStartEstimator

检查默认实例类型

在使用 JumpStartEstimator 类对预训练模型进行微调时,您可以选择包含特定的模型版本或实例类型。所有 JumpStart 模型都有默认的实例类型。使用以下代码读取默认训练实例类型:

from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope="training") print(instance_type)

您可以使用instance_types.retrieve()方法查看给定 JumpStart 模型的所有支持的实例类型。

检查默认超参数

要检查用于训练的默认超参数,可以使用 hyperparameters 类中的 retrieve_default() 方法。

from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)

有关可用超参数的更多信息,请参阅 通常支持的微调超参数

检查默认指标定义

您还可以检查默认指标定义:

print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))
隐私网站条款Cookie 首选项
© 2025, Amazon Web Services, Inc. 或其附属公司。保留所有权利。