오뚝이개발자

허깅페이스(Huggingface) custom loss로 Trainer 학습시키기 본문

AI/AI 개발

허깅페이스(Huggingface) custom loss로 Trainer 학습시키기

땅어 2022. 1. 28. 14:59
728x90
300x250

 

 

허깅페이스의 transformers 패키지를 사용할 때 custom loss로 최적화를 해야하는 경우가 있다. 이럴 땐 Trainer클래스를 상속받아 새로운 CustomTrainer 클래스를 만들고 그 안의 compute_loss 함수를 새로 작성해주면 된다.

from torch import nn
from transformers import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights)
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

compute_loss 내에서 본인이 원하는 loss를 구해 최적화하도록 코드를 바꿔주면 된다.

 

 

 

728x90
300x250
Comments