AI/AI 개발
허깅페이스(Huggingface) transformers로 early stopping 사용하기
땅어
2022. 3. 28. 10:51
728x90
300x250
허깅페이스의 transformers 패키지를 사용하는데 early stopping 방식으로 학습을 시키고 싶을 땐 아래와 같이 early stopping callback을 넣어주면 된다.
from transformers import EarlyStoppingCallback
batch_size = 3
args = Seq2SeqTrainingArguments(
"saved_model",
evaluation_strategy = "steps",
eval_steps = 5,
load_best_model_at_end = True,
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=2,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=5,
predict_with_generate=True,
fp16=False,
report_to='wandb',
run_name="ut_del_three_per_each_ver2_early_stop_4" # name of the W&B run (optional)
)
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)
다른 파라미터는 임의로 바꾸어도 되지만 아래의 4가지는 꼭 지정해주어야 한다.
- load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
- evaluation_strategy = 'steps' instead of 'epoch'.
- eval_steps = 50 (evaluate the metrics after N steps).
- metric_for_best_model = 'f1',
metric_for_best_model은 지정해주지 않으면 기본적으로 validation loss가 default이다. 즉, loss를 보고서 early stopping을 할 지점을 정한다.
참고로, early stopping은 훈련 단위가 step이기 때문에 오히려 epoch으로 훈련한 경우보다 test 성능이 더 안좋게 나오는 경우도 발생한다.
728x90
300x250