기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.
JumpStartEstimator
클래스를 사용하여 공개적으로 사용 가능한 파운데이션 모델을 미세 조정합니다.
를 사용하여 몇 줄의 코드만으로 내장 알고리즘 또는 사전 훈련된 모델을 미세 조정할 수 있습니다. SageMaker Python SDK.
-
모델 ID를 사용하여 훈련 작업을 JumpStart 추정기로 정의합니다.
from sagemaker.jumpstart.estimator import JumpStartEstimator model_id =
"huggingface-textgeneration1-gpt-j-6b"
estimator = JumpStartEstimator(model_id=model_id) -
미세 조정
estimator.fit()
에 사용할 훈련 데이터를 가리키면서 모델에서 를 실행합니다.estimator.fit( {"train":
training_dataset_s3_path
, "validation":validation_dataset_s3_path
} ) -
그런 다음
deploy
메서드를 사용하여 추론을 위해 모델을 자동으로 배포합니다. 이 예제에서는 의 GPT-J 6B 모델을 사용합니다.Hugging Face.predictor = estimator.deploy()
-
그런 다음
predict
메서드를 사용하여 배포된 모델로 추론을 실행할 수 있습니다.question =
"What is Southern California often abbreviated as?"
response = predictor.predict(question) print(response)
참고
이 예제에서는 파운데이션 모델 GPT-J 6B 를 사용합니다. 이 모델은 질문 응답, 명명된 엔터티 인식, 요약 등을 포함한 다양한 텍스트 생성 사용 사례에 적합합니다. 모델 사용 사례에 대한 자세한 내용은 섹션을 참조하세요사용 가능한 파운데이션 모델.
를 생성할 때 선택적으로 모델 버전 또는 인스턴스 유형을 지정할 수 있습니다JumpStartEstimator
. JumpStartEstimator
클래스 및 해당 파라미터에 대한 자세한 내용은 섹션을 참조하세요JumpStartEstimator
기본 인스턴스 유형 확인
JumpStartEstimator
클래스를 사용하여 사전 학습된 모델을 미세 조정할 때 선택적으로 특정 모델 버전 또는 인스턴스 유형을 포함할 수 있습니다. 모든 JumpStart 모델에는 기본 인스턴스 유형이 있습니다. 다음 코드를 사용하여 기본 훈련 인스턴스 유형을 검색합니다.
from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope=
"training"
) print(instance_type)
instance_types.retrieve()
메서드를 사용하여 지정된 JumpStart 모델에 대해 지원되는 모든 인스턴스 유형을 볼 수 있습니다.
기본 하이퍼파라미터 확인
훈련에 사용되는 기본 하이퍼파라미터를 확인하려면 hyperparameters
클래스의 retrieve_default()
메서드를 사용할 수 있습니다.
from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)
사용 가능한 하이퍼파라미터에 대한 자세한 내용은 섹션을 참조하세요일반적으로 지원되는 미세 조정 하이퍼파라미터.
기본 지표 정의 확인
기본 지표 정의를 확인할 수도 있습니다.
print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))