nn.TransformerEncoder
Goal: train an nn.TransformerEncoder model on a language modeling task.

import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output


def generate_square_subsequent_mask(sz: int) -> Tensor:
    # Upper-triangular matrix of -inf above the diagonal, zeros elsewhere.
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
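As a quick sanity check (not part of the training pipeline; the size of 4 is purely illustrative), the causal mask can be printed directly. Positions marked -inf are masked out, so position i can only attend to positions 0..i:

# Illustration only: a small causal attention mask.
print(generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])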
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
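A minimal shape check for PositionalEncoding (illustration only; the _pe and _dummy names are hypothetical), using the [seq_len, batch_size, d_model] layout this model expects:

# Illustration only: positional encoding preserves the input shape.
_pe = PositionalEncoding(d_model=200)
_dummy = torch.zeros(35, 20, 200)   # [seq_len, batch_size, d_model]
print(_pe(_dummy).shape)            # torch.Size([35, 20, 200])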
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
train_iter = WikiText2(split="train")
tokenizer = get_tokenizer("basic_english")
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])
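As a brief illustration of the vocabulary (the example sentence is hypothetical), tokenizing a string and passing the token list to vocab returns a list of integer indices, with out-of-vocabulary tokens falling back to the index of '<unk>':

# Illustration only: map a tokenized sentence to vocabulary indices.
_example = "the quick brown fox jumps"
print(vocab(tokenizer(_example)))   # list of ints; unknown tokens map to vocab['<unk>']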
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
# train_iter was consumed while building the vocab, so create the iterators again.
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def batchify(data: Tensor, bsz: int) -> Tensor:
    # Divide the dataset into bsz parts.
    seq_len = data.size(0) // bsz
    # Trim off any extra elements that would not cleanly fit.
    data = data[:seq_len * bsz]
    # Evenly divide the data across the bsz batches (columns).
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)
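To see what batchify does (illustration only, toy data), a sequence of 26 token ids with bsz=4 keeps 24 elements and comes back as a [6, 4] tensor whose columns are independent sub-sequences:

# Illustration only: batchify a toy sequence of 26 token ids into 4 columns.
_toy = torch.arange(26)
print(batchify(_toy, 4).shape)   # torch.Size([6, 4]); the last 2 elements are dropped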
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    # Returns an input chunk of length up to bptt and the flattened,
    # one-step-shifted targets for next-token prediction.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target
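For clarity (illustration only; the _source, _data, and _targets names are hypothetical), get_batch returns an input chunk of up to bptt rows plus the same chunk shifted by one position and flattened, which is the form CrossEntropyLoss expects later:

# Illustration only: inspect the shapes produced by get_batch.
_source = batchify(torch.arange(100), 4)   # [25, 4]
_data, _targets = get_batch(_source, 0)
print(_data.shape, _targets.shape)         # torch.Size([24, 4]) torch.Size([96])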
ntokens = len(vocab)  # size of the vocabulary
emsize = 200          # embedding dimension
d_hid = 200           # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2           # number of nn.TransformerEncoderLayer layers in nn.TransformerEncoder
nhead = 2             # number of heads in nn.MultiheadAttention
dropout = 0.2         # dropout probability
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
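Before training, a quick sanity check (illustration only; num_params is just a name for this snippet) is to count the trainable parameters of the instantiated model:

# Illustration only: count trainable parameters.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')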
import copy
import time
criterion = nn.CrossEntropyLoss()
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        seq_len = data.size(0)
        if seq_len != bptt:  # only on the last batch
            src_mask = src_mask[:seq_len, :seq_len]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()
def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            if seq_len != bptt:
                src_mask = src_mask[:seq_len, :seq_len]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)
best_val_loss = float('inf')
epochs = 3

with TemporaryDirectory() as tempdir:
    best_model_params_path = os.path.join(tempdir, "best_model_params.pt")

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train(model)
        val_loss = evaluate(model, val_data)
        val_ppl = math.exp(val_loss)
        elapsed = time.time() - epoch_start_time
        print('-' * 89)
        print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
              f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
        print('-' * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_params_path)

        scheduler.step()
    model.load_state_dict(torch.load(best_model_params_path))  # load best model states
Output:
| epoch 1 | 200/ 2928 batches | lr 5.00 | ms/batch 17.85 | loss 8.25 | ppl 3839.33
| epoch 1 | 400/ 2928 batches | lr 5.00 | ms/batch 15.74 | loss 6.95 | ppl 1039.31
| epoch 1 | 600/ 2928 batches | lr 5.00 | ms/batch 17.71 | loss 6.47 | ppl 647.04
| epoch 1 | 800/ 2928 batches | lr 5.00 | ms/batch 15.34 | loss 6.31 | ppl 552.52
| epoch 1 | 1000/ 2928 batches | lr 5.00 | ms/batch 15.38 | loss 6.19 | ppl 488.62
| epoch 1 | 1200/ 2928 batches | lr 5.00 | ms/batch 15.71 | loss 6.16 | ppl 474.81
| epoch 1 | 1400/ 2928 batches | lr 5.00 | ms/batch 15.98 | loss 6.12 | ppl 452.61
| epoch 1 | 1600/ 2928 batches | lr 5.00 | ms/batch 15.51 | loss 6.10 | ppl 447.29
| epoch 1 | 1800/ 2928 batches | lr 5.00 | ms/batch 15.53 | loss 6.03 | ppl 415.70
| epoch 1 | 2000/ 2928 batches | lr 5.00 | ms/batch 16.38 | loss 6.02 | ppl 411.22
| epoch 1 | 2200/ 2928 batches | lr 5.00 | ms/batch 16.20 | loss 5.90 | ppl 365.17
| epoch 1 | 2400/ 2928 batches | lr 5.00 | ms/batch 15.66 | loss 5.97 | ppl 390.53
| epoch 1 | 2600/ 2928 batches | lr 5.00 | ms/batch 15.72 | loss 5.95 | ppl 384.23
| epoch 1 | 2800/ 2928 batches | lr 5.00 | ms/batch 15.74 | loss 5.88 | ppl 358.90
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 49.20s | valid loss 5.84 | valid ppl 345.37
-----------------------------------------------------------------------------------------
| epoch 2 | 200/ 2928 batches | lr 4.75 | ms/batch 15.96 | loss 5.87 | ppl 354.45
| epoch 2 | 400/ 2928 batches | lr 4.75 | ms/batch 15.86 | loss 5.86 | ppl 350.68
| epoch 2 | 600/ 2928 batches | lr 4.75 | ms/batch 15.92 | loss 5.67 | ppl 289.90
| epoch 2 | 800/ 2928 batches | lr 4.75 | ms/batch 16.46 | loss 5.71 | ppl 301.70
| epoch 2 | 1000/ 2928 batches | lr 4.75 | ms/batch 16.16 | loss 5.66 | ppl 285.79
| epoch 2 | 1200/ 2928 batches | lr 4.75 | ms/batch 16.02 | loss 5.69 | ppl 295.86
| epoch 2 | 1400/ 2928 batches | lr 4.75 | ms/batch 16.03 | loss 5.69 | ppl 296.75
| epoch 2 | 1600/ 2928 batches | lr 4.75 | ms/batch 16.27 | loss 5.71 | ppl 302.64
| epoch 2 | 1800/ 2928 batches | lr 4.75 | ms/batch 16.26 | loss 5.66 | ppl 286.10
| epoch 2 | 2000/ 2928 batches | lr 4.75 | ms/batch 15.95 | loss 5.67 | ppl 290.46
| epoch 2 | 2200/ 2928 batches | lr 4.75 | ms/batch 15.87 | loss 5.56 | ppl 259.70
| epoch 2 | 2400/ 2928 batches | lr 4.75 | ms/batch 16.02 | loss 5.66 | ppl 286.64
| epoch 2 | 2600/ 2928 batches | lr 4.75 | ms/batch 16.28 | loss 5.65 | ppl 285.15
| epoch 2 | 2800/ 2928 batches | lr 4.75 | ms/batch 15.80 | loss 5.59 | ppl 266.57
-----------------------------------------------------------------------------------------
| end of epoch 2 | time: 48.83s | valid loss 5.66 | valid ppl 288.23
-----------------------------------------------------------------------------------------
| epoch 3 | 200/ 2928 batches | lr 4.51 | ms/batch 16.00 | loss 5.61 | ppl 274.49
| epoch 3 | 400/ 2928 batches | lr 4.51 | ms/batch 16.20 | loss 5.63 | ppl 278.12
| epoch 3 | 600/ 2928 batches | lr 4.51 | ms/batch 15.70 | loss 5.43 | ppl 227.61
| epoch 3 | 800/ 2928 batches | lr 4.51 | ms/batch 15.71 | loss 5.49 | ppl 242.05
| epoch 3 | 1000/ 2928 batches | lr 4.51 | ms/batch 15.80 | loss 5.44 | ppl 229.89
| epoch 3 | 1200/ 2928 batches | lr 4.51 | ms/batch 16.29 | loss 5.47 | ppl 238.51
| epoch 3 | 1400/ 2928 batches | lr 4.51 | ms/batch 15.70 | loss 5.50 | ppl 244.43
| epoch 3 | 1600/ 2928 batches | lr 4.51 | ms/batch 15.77 | loss 5.52 | ppl 250.63
| epoch 3 | 1800/ 2928 batches | lr 4.51 | ms/batch 15.70 | loss 5.47 | ppl 236.73
| epoch 3 | 2000/ 2928 batches | lr 4.51 | ms/batch 16.12 | loss 5.48 | ppl 240.57
| epoch 3 | 2200/ 2928 batches | lr 4.51 | ms/batch 16.01 | loss 5.36 | ppl 212.49
| epoch 3 | 2400/ 2928 batches | lr 4.51 | ms/batch 15.73 | loss 5.46 | ppl 234.90
| epoch 3 | 2600/ 2928 batches | lr 4.51 | ms/batch 15.75 | loss 5.47 | ppl 236.71
| epoch 3 | 2800/ 2928 batches | lr 4.51 | ms/batch 15.99 | loss 5.41 | ppl 222.82
-----------------------------------------------------------------------------------------
| end of epoch 3 | time: 48.48s | valid loss 5.59 | valid ppl 268.29
-----------------------------------------------------------------------------------------
test_loss = evaluate(model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)
Output:
=========================================================================================
| End of training | test loss 5.50 | test ppl 245.70
=========================================================================================