[PyTorch] 튜토리얼 (3)

rkqhwkrn·2023년 8월 10일
0

Python

목록 보기
12/13

Torchtext 라이브러리로 텍스트 분류하기

시작하기

Torchtext 데이터셋에 접근하기

import torch
from torchtext.datasets import AG_NEWS
# AG_NEWS(split='train') yields (label, text) pairs; wrap it in iter() so that
# next() can pull samples one at a time.
train_iter = iter(AG_NEWS(split='train'))

next(train_iter)  # take the next (label, text) sample from the iterator
next(train_iter)  # each call advances the iterator by one more sample

실행결과

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

다음과 같이 AG_NEWS 데이터셋의 label과 뉴스 문장이 (label, text) 튜플 형태로 출력됨

데이터 처리 파이프라인 준비하기

  • Torchtext 라이브러리는 가공되지 않은 텍스트 문장들을 만드는 (yield) 몇 가지 기초 데이터셋 iterator를 제공
  • Tokenizer 및 vocab을 이용하여 자연어로 구성된 데이터를 처리할 수 있음

1) 가공되지 않은 train dataset으로 vocab (어휘집) 생성

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')


def yield_tokens(data_iter):
    """Yield the token list of every text in *data_iter*.

    ``data_iter`` is iterated as ``(label, text)`` pairs; the label is ignored
    and each text is split with the module-level ``tokenizer``.
    """
    for _, text in data_iter:
        yield tokenizer(text)


# Build the vocabulary from the training split. "<unk>" is added as a special
# token and made the default index, so unknown tokens map to it.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
  • vocab 객체는 토큰 리스트를 정수 인덱스 리스트로 변환함
# Map each token of the sample sentence to its integer index in the vocabulary.
vocab(["here", "is", "an", "example"])

실행결과

[475, 21, 30, 5297]

2) 텍스트 처리 파이프라인을 준비

def text_pipeline(x):
    """Tokenize the raw string *x* and return the list of vocabulary indices."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a 1-based label (int or numeric string) to a 0-based class id."""
    return int(x) - 1


text_pipeline("here is the an example")
label_pipeline("10")

실행결과

[475, 21, 2, 30, 5297]
9

다음과 같이 text_pipeline을 통해 lookup table에 기반하여 텍스트 데이터를 정수로 변환함

데이터 배치 (batch)와 반복자 (iterator) 생성

from torch.utils.data import DataLoader

# Device selection: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# collate_batch: assemble a list of raw samples into one batch for the model.
def collate_batch(batch):
    """Collate ``(label, text)`` samples into the flat layout ``nn.EmbeddingBag`` uses.

    Returns a tuple ``(labels, text, offsets)``, all moved to ``device``:
        labels:  int64 tensor, one ``label_pipeline`` class id per sample.
        text:    the token ids of every sample concatenated into one 1-D tensor.
        offsets: start index of each sample inside ``text``.
    """
    label_list, text_list, offsets = [], [], [0]
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # offsets is [0, len_0, len_1, ...]; dropping the last length and taking the
    # cumulative sum gives the start position of each sample in the flat tensor.
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)  # concatenate all samples into one tensor
    return label_list.to(device), text_list.to(device), offsets.to(device)
    
# Dataset and DataLoader: batches of 8 samples in dataset order, built by collate_batch.
train_iter = AG_NEWS(split="train")
dataloader = DataLoader(
    train_iter,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)
  • 주어진 data batch의 텍스트 항목들을 list에 담은 뒤 nn.EmbeddingBag에 입력하기 위해 하나의 tensor로 concat하여 합침.
  • label은 개별 텍스트 항목의 label을 저장하는 tensor
  • offset은 텍스트 tensor에서 개별 시퀀스의 시작 인덱스를 표현하기 위한 구분자 tensor

모델 정의하기

  • 모델 구성: nn.EmbeddingBag layer + linear layer (분류 목적)
from torch import nn

class TextClassificationModel(nn.Module):
    """Text classifier: an ``nn.EmbeddingBag`` followed by a linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each embedding vector.
        num_class: number of output classes (output width of the linear layer).
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw the embedding and linear weights from U(-0.5, 0.5); zero the bias."""
        initrange = 0.5
        nn.init.uniform_(self.embedding.weight, -initrange, initrange)
        nn.init.uniform_(self.fc.weight, -initrange, initrange)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return one row of class logits per bag.

        ``text`` holds the concatenated token ids of the batch and ``offsets`` the
        start index of each bag within it (the layout built by ``collate_batch``).
        """
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

train_iter = AG_NEWS(split="train")
# Count the distinct labels in the training split to size the output layer.
num_classes = len({label for label, _ in train_iter})
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_classes).to(device)

모델을 학습하고 결과를 평가하는 함수 정의

import time

def train(dataloader):
    """Run one training epoch over *dataloader*, printing running accuracy.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``epoch``.
    Every ``log_interval`` batches the accuracy since the previous report is
    printed and the counters are reset.
    """
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1 before applying the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc / total_count))
            total_acc, total_count = 0, 0
            
def evaluate(dataloader):
    """Return the accuracy of the module-level ``model`` on *dataloader*.

    The model is put in eval mode and gradients are disabled. The result is the
    fraction of samples whose arg-max logit equals the label.
    """
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

데이터셋을 분할하고 모델 수행하기

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Hyperparameters
EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each scheduler.step() multiplies the learning rate by 0.1 (step_size is an int).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
# Hold out 5% of the training split as a validation set.
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
# Accuracy does not depend on sample order, so evaluation loaders are not shuffled.
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    # total_accu only ever increases, so it is the best validation accuracy so
    # far; decay the learning rate whenever this epoch falls below it.
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

실행결과

| epoch   1 |   500/ 1782 batches | accuracy    0.685
| epoch   1 |  1000/ 1782 batches | accuracy    0.855
| epoch   1 |  1500/ 1782 batches | accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time: 11.03s | valid accuracy    0.874 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.898
| epoch   2 |  1000/ 1782 batches | accuracy    0.900
| epoch   2 |  1500/ 1782 batches | accuracy    0.906
-----------------------------------------------------------
| end of epoch   2 | time:  9.61s | valid accuracy    0.891 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.916
| epoch   3 |  1000/ 1782 batches | accuracy    0.915
| epoch   3 |  1500/ 1782 batches | accuracy    0.918
-----------------------------------------------------------
| end of epoch   3 | time: 10.54s | valid accuracy    0.888 
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.933
| epoch   4 |  1000/ 1782 batches | accuracy    0.929
| epoch   4 |  1500/ 1782 batches | accuracy    0.930
-----------------------------------------------------------
| end of epoch   4 | time: 10.44s | valid accuracy    0.898 
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.931
| epoch   5 |  1000/ 1782 batches | accuracy    0.931
| epoch   5 |  1500/ 1782 batches | accuracy    0.932
-----------------------------------------------------------
| end of epoch   5 | time: 10.58s | valid accuracy    0.901 
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.936
| epoch   6 |  1000/ 1782 batches | accuracy    0.932
| epoch   6 |  1500/ 1782 batches | accuracy    0.931
-----------------------------------------------------------
| end of epoch   6 | time:  9.34s | valid accuracy    0.900 
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.935
| epoch   7 |  1000/ 1782 batches | accuracy    0.936
| epoch   7 |  1500/ 1782 batches | accuracy    0.934
-----------------------------------------------------------
| end of epoch   7 | time: 10.82s | valid accuracy    0.901 
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.932
| epoch   8 |  1000/ 1782 batches | accuracy    0.935
| epoch   8 |  1500/ 1782 batches | accuracy    0.936
-----------------------------------------------------------
| end of epoch   8 | time: 10.54s | valid accuracy    0.900 
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.933
| epoch   9 |  1000/ 1782 batches | accuracy    0.938
| epoch   9 |  1500/ 1782 batches | accuracy    0.934
-----------------------------------------------------------
| end of epoch   9 | time: 10.55s | valid accuracy    0.901 
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.934
| epoch  10 |  1000/ 1782 batches | accuracy    0.935
| epoch  10 |  1500/ 1782 batches | accuracy    0.937
-----------------------------------------------------------
| end of epoch  10 | time:  9.40s | valid accuracy    0.901 
-----------------------------------------------------------

모델 평가

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)  # fraction of test samples classified correctly
print("test accuracy {:8.3f}".format(accu_test))

실행결과

Checking the results of test dataset.
test accuracy    0.904

임의의 뉴스로 평가

# Maps the 1-based class number returned by predict() to a readable name.
ag_news_label = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tec",
}


def predict(text, text_pipeline, net=None):
    """Return the 1-based predicted class of a single raw *text*.

    Args:
        text: raw input string.
        text_pipeline: callable turning *text* into a list of token ids.
        net: model to use; defaults to the module-level ``model``.

    The token ids are built as int64 on the model's own device, so this works
    whether the model is on the CPU or a GPU, and also for a text that yields
    no tokens (a plain ``torch.tensor([])`` would be float and rejected).
    """
    if net is None:
        net = model
    device = next(net.parameters()).device
    with torch.no_grad():
        tokens = torch.tensor(text_pipeline(text), dtype=torch.int64, device=device)
        # The whole text is one bag starting at position 0.
        offsets = torch.tensor([0], device=device)
        output = net(tokens, offsets)
        # +1 turns the 0-based arg-max into the 1-based keys of ag_news_label.
        return output.argmax(1).item() + 1

# Sample news paragraph that is classified with predict() below. Each
# backslash-newline inside the literal is a line continuation, so the leading
# spaces of the following line become part of the string.
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

# Move the trained model to the CPU, then classify the sample paragraph and
# print its category name.
model = model.to("cpu")

print(
    "This is a %s news"
    % ag_news_label[predict(ex_text_str, text_pipeline)]
)

실행결과

This is a Sports news

0개의 댓글