오뚝이개발자

허깅페이스(Huggingface) transformers로 early stopping 사용하기 본문

AI/AI 개발

허깅페이스(Huggingface) transformers로 early stopping 사용하기

땅어 2022. 3. 28. 10:51
728x90
300x250

 

허깅페이스의 transformers 패키지를 사용하는데 early stopping 방식으로 학습을 시키고 싶을 땐 아래와 같이 early stopping callback을 넣어주면 된다.

from transformers import EarlyStoppingCallback

batch_size = 3
args = Seq2SeqTrainingArguments(
   "saved_model",
   evaluation_strategy = "steps",
   eval_steps = 5,
   load_best_model_at_end = True,
   learning_rate=2e-5,
   per_device_train_batch_size=batch_size,
   per_device_eval_batch_size=batch_size,
   gradient_accumulation_steps=2,
   weight_decay=0.01,
   save_total_limit=3,
   num_train_epochs=5,
   predict_with_generate=True,
   fp16=False,
   report_to='wandb',
   run_name="ut_del_three_per_each_ver2_early_stop_4"  # name of the W&B run (optional)
)
trainer = Seq2SeqTrainer(
   model,
   args,
   train_dataset=tokenized_datasets["train"],
   eval_dataset=tokenized_datasets["validation"],
   data_collator=data_collator,
   tokenizer=tokenizer,
   compute_metrics=compute_metrics,
   callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)

다른 파라미터는 임의로 바꾸어도 되지만 아래의 4가지는 꼭 지정해주어야 한다.

  1. load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
  2. evaluation_strategy = 'steps' instead of 'epoch'.
  3. eval_steps = 50 (evaluate the metrics after N steps).
  4. metric_for_best_model = 'f1',

metric_for_best_model은 지정해주지 않으면 기본적으로 validation loss가 default이다. 즉, loss를 보고서 early stopping을 할 지점을 정한다.

참고로, early stopping은 훈련 단위가 step이기 때문에 오히려 epoch으로 훈련한 경우보다 test 성능이 더 안좋게 나오는 경우도 발생한다.

 

728x90
300x250
Comments