pytorch model의 'eval' 함수

leeway·2023년 1월 6일
0

python

목록 보기
3/4

자연어처리에는 여러 딥러닝 모델을 활용하며, 특히 PyTorch는 복잡한 모델을 쉽게 정의하고 훈련할 수 있는 유연하고 강력한 딥러닝 프레임워크이기 때문에, BERT 등의 모델을 훈련하고 미세 조정하는 데 많이 사용되고 있음

추론(inference)을 위해 저장된 PyTorch 모델을 불러와서 eval() 함수를 호출하는 것을 볼 수 있음

eval 함수 사용하는 이유

'파이토치 한국 사용자 모임'에서는 아래와 같이 설명함

[PyTorch Tutorials] 추론을 실행하기 전에 반드시 model.eval() 을 호출하여 드롭아웃 및 배치 정규화를 평가 모드로 설정하여야 합니다. 이 과정을 거치지 않으면 일관성 없는 추론 결과가 출력됩니다.

즉, evaluation 과정에서 layer는 훈련 중에 계산된 평균과 분산의 이동평균을 사용하므로써 드롭아웃이 비활성화되고 업데이트되지 않음

eval 사용하기

PyTorch에서는 모델의 eval() 방법을 사용하여 모델을 평가 모드로 전환할 수 있음

Example

[chatGPT] MNIST 데이터셋을 사용하여 PyTorch 모델을 평가하는 방법의 예시

import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Load the test dataset and create a DataLoader
dataset = MNIST(root='data/', train=False, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

# Set the model to evaluation mode
model.eval()

# Initialize a list to store the predictions
predictions = []

# Iterate over the test dataset
for inputs, labels in dataloader:
    # Pass the input data to the model to make predictions
    output = model(inputs)

    # Append the predictions to the list
    predictions.extend(output.argmax(dim=1).tolist())

# Compute the accuracy of the model
accuracy = (predictions == labels).float().mean().item()
print(f'Accuracy: {accuracy:.2f}')


Reference

profile
자연어처리 개발자

0개의 댓글