300x250
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- 운영체제
- 동적프로그래밍
- 알고리즘
- kick start
- 네트워크
- 리눅스
- google coding competition
- 순열
- PYTHON
- 동적 프로그래밍
- BFS
- 딥러닝
- AI
- 그래프
- OS
- 백준
- linux
- 프로그래머스
- dp
- 프로그래밍
- 코딩 테스트
- 파이썬
- 브루트포스
- 킥스타트
- 구글 킥스타트
- DFS
- 코딩테스트
- nlp
- 코딩
- CSS
Archives
- Today
- Total
오뚝이개발자
허깅페이스(Huggingface) transformers로 early stopping 사용하기 본문
728x90
300x250
허깅페이스의 transformers 패키지를 사용하는데 early stopping 방식으로 학습을 시키고 싶을 땐 아래와 같이 early stopping callback을 넣어주면 된다.
from transformers import EarlyStoppingCallback
batch_size = 3
args = Seq2SeqTrainingArguments(
"saved_model",
evaluation_strategy = "steps",
eval_steps = 5,
load_best_model_at_end = True,
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=2,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=5,
predict_with_generate=True,
fp16=False,
report_to='wandb',
run_name="ut_del_three_per_each_ver2_early_stop_4" # name of the W&B run (optional)
)
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)
다른 파라미터는 임의로 바꾸어도 되지만 아래의 4가지는 꼭 지정해주어야 한다.
- load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
- evaluation_strategy = 'steps' instead of 'epoch'.
- eval_steps = 50 (evaluate the metrics after N steps).
- metric_for_best_model = 'f1',
metric_for_best_model은 지정해주지 않으면 기본적으로 validation loss가 default이다. 즉, loss를 보고서 early stopping을 할 지점을 정한다.
참고로, early stopping은 훈련 단위가 step이기 때문에 오히려 epoch으로 훈련한 경우보다 test 성능이 더 안좋게 나오는 경우도 발생한다.
728x90
300x250
'AI > AI 개발' 카테고리의 다른 글
허깅페이스(Huggingface) 모델 inference(pipeline) GPU로 돌리기 (2) | 2022.06.26 |
---|---|
텍스트를 문장 단위로 분할하기(nltk, sentence tokenizing) (0) | 2022.05.15 |
텍스트로부터 키워드 추출하기(KeyBERT) (0) | 2022.02.21 |
허깅페이스(Huggingface) custom loss로 Trainer 학습시키기 (0) | 2022.01.28 |
허깅페이스(Huggingface)로 내 모델 포팅(porting)하기 (0) | 2022.01.20 |
Comments