torch.nn.funtional.nll_loss와 torch.nn.NLLLoss의 차이

조호원·2020년 11월 18일
1

AI 공부 기록지

목록 보기
2/2

torch.nn.funtional.nll_loss와 torch.nn.NLLLoss는 서로 같은 이름을 가지고 있는데 심지어 사용 용도도 거의 같다. 왜 이렇게 번거롭게 같은 기능을 여러개 모듈로 만들어 놨을까? 사실 이 질문이 나오기전에 나는 아래의 오류를 만났다.

def fit(epoch, model, data_loader, phase='training'):
    if phase=='training':
        model.train()
    if phase=='validation':
        model.eval()
        volatile=True

    running_loss = 0.0
    running_correct = 0.0
    
    for batch_index, (data, target) in enumerate(data_loader):
        if is_cuda:
            data, target = data.cuda(), target.cuda()

        if phase== 'training':
            optimizer.zero_grad()

        output = model(data)
        print(torch.sum(output[0]))
        loss = nn.NLLLoss(output, target)

        if phase== 'training':
            loss.backward()
            optimizer.step()
        if phase== 'validation':
            exp_lr_scheduler.step()
        
    loss = running_loss/len(data_loader.dataset)
    accuracy = 100. * running_correct/len(data_loader.dataset)

    print ('[{}] epoch: {:2d} loss: {:.8f} accuracy: {:.8f}'.format(phase, epoch, loss, accuracy))
    return loss, accuracy

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

F.nll_loss는 잘먹히는데 nn.NLLLoss를 손실 함수로 사용하자 오류가 뜬 것이다. 왜 그런가 찾아보니 torch.nn.functional.nll_loss(이하 F.nll_loss)는 함수지만 torch.nn.NLLLoss(nn.NLLLoss)은 클래스이다. 그래서 손실을 구할 때 두 개의 차이를 살펴보자.

loss1 = torch.nn.functional.nll_loss(output, target)

loss2 = torch.nn.NLLLoss()(output, target)

그렇다. NLLLoss는 근본적으로 Class이고 __call__()을 호출해서 output의 손실을 구할 수 있는데 __init__()에 output과 target이 들어가게 된 것이다. 두 번째 인자는 boolean 타입을 받아야 하는데 target(토치 텐서)를 받았기 때문에 오류가 발생한 것이다. 이 두 함수와 클래스가 구체적으로 기능들이 어떻게 다른지는 나중에 더 공부해봐야 할 과제이다.

profile
세상을 바꾸는 사람

0개의 댓글