First, I wrote a single class that handles both the dataset and the tokenizer.
import os
import sys

import torch
from datasets import load_dataset
from transformers import T5Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"  # global device used throughout


class NamuwikiDataset:
    def __init__(self, vocab_path, max_seq_len=1024, batch_size=8):
        self.dataset = load_dataset("heegyu/namuwiki-extracted", split='train').select_columns('text')
        self.max_seq_len = max_seq_len
        if os.path.exists(vocab_path):
            self.tokenizer = T5Tokenizer(vocab_path)
            # add a bos token (T5Tokenizer has no bos token by default)
            self.tokenizer.add_special_tokens({'bos_token': '<s>'})
        else:
            sys.stderr.write(f"Tokenizer model not found at {vocab_path}\n")
            sys.stderr.write("Please check the path and try again\n")
            sys.exit(1)
        self.vocab_size = self.tokenizer.vocab_size
        self.pad_id = self.tokenizer.pad_token_id
        self.unk_id = self.tokenizer.unk_token_id
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=batch_size, collate_fn=self.collate_fn)

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        batch = [x['text'] for x in batch]
        # encode each text to ids and wrap it with bos/eos
        src = [[self.bos_id] + self.tokenizer.encode(text, add_special_tokens=False) + [self.eos_id] for text in batch]  # [bsz, each_text_len + 2]
        max_len = max(len(ids) for ids in src)
        max_len = min(max_len, self.max_seq_len + 1)
        # if a sample exceeds max_len, truncate it (its eos token is dropped); otherwise pad it to max_len
        src = [ids[:max_len] if len(ids) > max_len else ids + [self.pad_id] * (max_len - len(ids)) for ids in src]  # [bsz, max_len] with max_len <= max_seq_len + 1
        # <s> + text + </s> + pad if the text is shorter than max_seq_len, else <s> + text[:max_seq_len]
        # at inference time the prompt is always shorter than max_seq_len, because max_prompt_len is set below max_seq_len
        src = torch.tensor(src, dtype=torch.long, device=device)  # [bsz, max_len]
        return src

    def decode(self, id):
        return self.tokenizer.decode(id)
I load the tokenizer model, wrap it with T5Tokenizer, and add a bos token.
The question is how to handle padding and eos when the dataloader builds a batch. Each text in the batch is encoded into ids and a bos token is prepended. If the resulting sequence is longer than max_seq_len+1, it is truncated to that length; otherwise an eos token is appended and the remaining positions are filled with padding.
So a very long document becomes a sequence that starts with bos and is cut off mid-text, while a shorter one starts with bos, ends with eos, and has the rest filled with pad tokens.
The reason for using max_seq_len+1 is the src/tgt split later: src is taken as [:, :max_seq_len] (it keeps the bos token), and tgt as [:, 1:] (it drops bos and keeps eos), so both end up with length max_seq_len. Since pad positions are masked out of the attention map and excluded from the loss, they cause no problem.
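For reference, a minimal sketch of how one batch flows through this class and the src/tgt split (the vocab path here is a placeholder):

# sketch: one batch through NamuwikiDataset and the shift into src/tgt
dataset = NamuwikiDataset("spiece.model", max_seq_len=1024, batch_size=8)  # "spiece.model" is a placeholder path
batch = next(iter(dataset.dataloader))  # [bsz, max_len] with max_len <= max_seq_len + 1
tgt = batch[:, 1:]   # drops bos, keeps eos (when it survived truncation)
src = batch[:, :-1]  # keeps bos, drops the last position
print(src.shape, tgt.shape)  # both [bsz, max_len - 1]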
model.train()
for epoch in tqdm(range(epochs)):
    for batch_idx, src in enumerate(tqdm(data_loader)):
        src = src.to(device)
        # src keeps [:, :max_seq_len] (starts with bos); tgt is the same batch shifted by one, [:, 1:max_seq_len+1]
        tgt = src[:, 1:]
        src = src[:, :-1]
        optimizer.zero_grad()
        out, _ = model(src)  # out is [bsz, max_seq_len, vocab_size]
        loss = criterion(out.reshape(-1, out.size(-1)), tgt.reshape(-1))  # ignore_index=pad_id keeps padding out of the loss
        accelerate.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
With this loop I train the DecoderOnlyTransformer model.
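The loop above assumes that model, criterion, optimizer, accelerate, data_loader, and epochs already exist. For reference, the setup looked roughly like this; the key part is ignore_index=dataset.pad_id, which is what keeps pad positions out of the loss (the constructor arguments and hyperparameters below are placeholders rather than exact values):

from accelerate import Accelerator
from torch import nn, optim
from tqdm import tqdm

dataset = NamuwikiDataset("spiece.model", max_seq_len=1024, batch_size=8)  # placeholder vocab path
data_loader = dataset.dataloader
model = DecoderOnlyTransformer(vocab_size=dataset.vocab_size).to(device)  # constructor args are placeholders
criterion = nn.CrossEntropyLoss(ignore_index=dataset.pad_id)  # pad positions contribute no loss
optimizer = optim.AdamW(model.parameters(), lr=3e-4)  # lr is a placeholder
accelerate = Accelerator()  # the loop calls accelerate.backward(loss)
model, optimizer = accelerate.prepare(model, optimizer)
epochs = 1  # placeholder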
There is a problem, though: since I started building without deciding on the model design first, the DecoderOnlyTransformer ended up doing self-attention the way an encoder-only model would. Because of that, generating autoregressively means re-running self-attention over the full q, k, v at every step.
Later, when I looked at the llama2 model, it keeps a kv_cache inside the model that stores the previously computed K and V projections (the outputs of WK and WV), so only the query of the newly arrived token attends to them and each step becomes much cheaper. When I rebuild the model, I definitely want to include this. I also suspected that doing plain full self-attention would behave differently from that approach and that the results might suffer because of it.
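As a reminder to myself, here is a minimal sketch of that idea, heavily simplified from llama2 (single head, no causal mask, inference only; this is my own simplification, not the actual llama2 code):

import torch
from torch import nn

class CachedSelfAttention(nn.Module):
    """Single-head self-attention with a kv cache (simplified sketch)."""
    def __init__(self, d_model, max_seq_len, max_batch_size=8):
        super().__init__()
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        # caches holding the K/V projections of every position seen so far
        self.register_buffer("cache_k", torch.zeros(max_batch_size, max_seq_len, d_model))
        self.register_buffer("cache_v", torch.zeros(max_batch_size, max_seq_len, d_model))

    def forward(self, x, start_pos):
        # x: [bsz, seq_len, d_model]; during generation seq_len is 1 (only the newly produced token)
        bsz, seq_len, d_model = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # write the new K/V into the cache, then let q attend over everything cached so far
        self.cache_k[:bsz, start_pos:start_pos + seq_len] = k
        self.cache_v[:bsz, start_pos:start_pos + seq_len] = v
        keys = self.cache_k[:bsz, :start_pos + seq_len]    # [bsz, total_len, d_model]
        values = self.cache_v[:bsz, :start_pos + seq_len]
        attn = torch.softmax(q @ keys.transpose(1, 2) / d_model ** 0.5, dim=-1)
        return attn @ values  # [bsz, seq_len, d_model]

With this, each generation step only feeds the single new token (plus its start_pos) instead of re-running attention over the whole prefix; a causal mask is still needed when the prompt is processed in one pass with seq_len > 1.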
I then tried generating from a few prompts with the following function.
def generate(model,
             dataset: NamuwikiDataset,
             input_text,
             max_gen_length=100):
    """
    Generate text using the model and tokenizer
    :param model:
    :param dataset:
    :param input_text: list[str]
    :param max_gen_length:
    :return: list[str]
    """
    prompt_len_limit = dataset.max_seq_len // 2 + 1  # +1 for the bos token
    # the tokenizer encodes the batch of texts to ids; prompts must stay within the limit above
    input_ids = [[dataset.bos_id] + dataset.tokenizer.encode(text, add_special_tokens=False) for text in input_text]
    max_prompt_len = max(len(ids) for ids in input_ids)
    min_prompt_len = min(len(ids) for ids in input_ids)
    full_len = min(dataset.max_seq_len, max_prompt_len + max_gen_length)
    assert max_prompt_len <= prompt_len_limit, f"input text length should not exceed {prompt_len_limit}"
    # pad every prompt to max_seq_len so all rows have the same length
    token = torch.full((len(input_ids), dataset.max_seq_len), dataset.pad_id, dtype=torch.long, device=device)
    for i, ids in enumerate(input_ids):
        token[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
    # generate text
    is_eos = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
    model.eval()
    with torch.no_grad():
        for i in range(min_prompt_len, full_len):
            out, _ = model(token[:, :i])
            next_token = out[:, -1].argmax(dim=-1)  # [bsz], greedy pick of the most likely token
            # if position i is still pad (the prompt has ended), write next_token; otherwise keep the prompt token
            w = token[:, i]
            token[:, i] = torch.where(w == dataset.pad_id, next_token, w)
            # stop once every sequence has produced an eos token
            is_eos = is_eos | (token[:, i] == dataset.eos_id)
            if is_eos.all():
                break
    # decode token ids back to text
    generated_text = [dataset.decode(ids) for ids in token]
    return generated_text
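Calling it looked roughly like this (the prompts here are placeholders, not the three I actually used):

prompts = ["나무위키는", "서울특별시는", "축구는"]  # placeholder prompts
outputs = generate(model, dataset, prompts, max_gen_length=100)
for text in outputs:
    print(text)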
The output was sentences that made no sense.
There are several likely causes.
First, there is the decoder design problem described above.
It could also be that the namuwiki dataset alone is too small to train both the tokenizer and the pretraining run, and since I skipped preprocessing, it probably contained a lot of useless data.
Looking at the results, none of the three sentences ends with an eos token.
When loading the data, long documents were simply cut off without an eos token being appended,
and the namuwiki dataset seems to contain far more long documents than I expected, so relatively few training sequences may have contained an eos at all.
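A quick way to check this would be to count how many documents exceed max_seq_len after tokenization, i.e. how many training sequences lose their eos (a rough sketch; tokenizing the whole dataset this way is slow):

# rough check: what fraction of documents get truncated and therefore end without eos?
truncated, total = 0, 0
for example in dataset.dataset:
    length = len(dataset.tokenizer.encode(example['text'], add_special_tokens=False)) + 2  # +2 for bos/eos
    truncated += length > dataset.max_seq_len + 1
    total += 1
print(f"{truncated}/{total} documents are truncated and end without eos")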
Or there may be a typo or some flaw inside the model that I have not found yet.
I should have picked a reference model and followed it, but since I just charged ahead, the model itself may not have been built correctly.
Next time I build a model, I want to fix the flaws above and properly train a language model on a better-curated dataset.