Implementing and Training a Transformer (2)

First, I made a single class that handles the dataset and the tokenizer together.

import os
import sys

import torch
from datasets import load_dataset
from transformers import T5Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class NamuwikiDataset:
    def __init__(self, vocab_path, max_seq_len=1024, batch_size=8):
        self.dataset = load_dataset("heegyu/namuwiki-extracted", split='train').select_columns('text')
        self.max_seq_len = max_seq_len
        if os.path.exists(vocab_path):
            self.tokenizer = T5Tokenizer(vocab_path)
            # add bos token
            self.tokenizer.add_special_tokens({'bos_token': '<s>'})
        else:
            sys.stderr.write(f"Tokenizer model not found at {vocab_path}\n")
            sys.stderr.write("Please check the path and try again\n")
            sys.exit(1)

        self.vocab_size = self.tokenizer.vocab_size
        self.pad_id = self.tokenizer.pad_token_id
        self.unk_id = self.tokenizer.unk_token_id
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id

        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=batch_size, collate_fn=self.collate_fn)


    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        batch = [x['text'] for x in batch]
        # encode each text as <s> + ids + </s>
        src = [[self.bos_id] + self.tokenizer.encode(text, add_special_tokens=False) + [self.eos_id] for text in batch]  # [bsz, each_text_len + 2]
        max_len = max(len(ids) for ids in src)
        max_len = min(max_len, self.max_seq_len + 1)
        # sequences longer than max_len are truncated (losing their eos token);
        # shorter ones are right-padded with the pad token
        src = [ids[:max_len] if len(ids) > max_len else ids + [self.pad_id] * (max_len - len(ids)) for ids in src]  # [bsz, max_len]
        # result: <s> + text + </s> + pad if the text fits, else <s> + text[:max_len - 1]
        # at inference time the prompt is always shorter than max_seq_len, because max_prompt_len is set below max_seq_len
        src = torch.tensor(src, dtype=torch.long, device=device)    # [bsz, max_len]

        return src

    def decode(self, id):
        return self.tokenizer.decode(id)
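
For reference, instantiating the class looks roughly like this; the vocab path is just a placeholder for wherever the SentencePiece model was saved, not the path from the original post.

dataset = NamuwikiDataset("tokenizer/namuwiki_spm.model", max_seq_len=1024, batch_size=8)  # hypothetical path
print(dataset.vocab_size, dataset.pad_id, dataset.bos_id, dataset.eos_id)
for src in dataset.dataloader:
    print(src.shape)  # [batch_size, longest_len_in_batch], at most max_seq_len + 1
    break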

The tokenizer is loaded, wrapped in T5Tokenizer, and a bos token is added.
The question is how to handle padding and the eos token when loading data through the dataloader. First, the texts in each batch are encoded into ids and a bos token is prepended. If the resulting sequence is longer than max_seq_len+1, it is cut to that length; otherwise an eos token is appended and the remaining positions are filled with padding.
So a very long document becomes a sequence that starts with the bos token and is cut off in the middle of the text, while a shorter one starts with the bos token, ends with the eos token, and has the rest of its positions filled with pad tokens.
The reason for padding to max_seq_len+1 is the later split into src and tgt: src is taken as [:, :max_seq_len], which keeps the bos token, and tgt as [:, 1:], which drops the bos token and keeps the eos token. This makes both src and tgt exactly max_seq_len long. Since pad tokens are masked out of the attention map and excluded from the loss anyway, this causes no problems.
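
As a concrete sketch of that slicing, with a toy max_seq_len of 4 and made-up token ids:

import torch

# toy ids just for illustration: pad=0, bos=1, eos=2, real tokens >= 3
bos, eos, pad = 1, 2, 0
# with max_seq_len = 4, collate_fn produces rows of length at most 5
batch = torch.tensor([
    [bos, 11, 12, eos, pad],   # short text: <s> + text + </s> + padding
    [bos, 21, 22, 23, 24],     # long text: truncated, so no </s>
])
src = batch[:, :-1]  # [2, 4] -> model input, starts with <s>
tgt = batch[:, 1:]   # [2, 4] -> next-token targets, one step ahead of src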

    model.train()
    for epoch in tqdm(range(epochs)):
        for batch_idx, src in enumerate(tqdm(data_loader)):
            src = src.to(device)
            # the padded batch is [bsz, max_len <= max_seq_len + 1]
            # src = [:, :-1] is the model input, tgt = [:, 1:] is the next-token target
            tgt = src[:, 1:]
            src = src[:, :-1]
            optimizer.zero_grad()
            out, _ = model(src)  # out is [bsz, seq_len, vocab_size]
            # padding is ignored because criterion is built with ignore_index=pad_id
            loss = criterion(out.reshape(-1, out.size(-1)), tgt.reshape(-1))
            accelerate.backward(loss)  # accelerate is an Accelerator instance
            optimizer.step()

I train the DecoderOnlyTransformer model with this loop.
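
The loop above assumes a setup roughly like the following. This is my sketch, not code from the post: I am guessing that accelerate is a Hugging Face Accelerator, and the optimizer, learning rate, and epoch count are placeholders. The important part is ignore_index=pad_id, which is what actually makes the loss skip padding tokens.

from accelerate import Accelerator
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

epochs = 1  # placeholder
accelerate = Accelerator()  # backward() in the loop is routed through this object
criterion = nn.CrossEntropyLoss(ignore_index=dataset.pad_id)  # skip loss on pad positions
optimizer = optim.Adam(model.parameters(), lr=3e-4)  # optimizer and lr are guesses
model, optimizer = accelerate.prepare(model, optimizer)
data_loader = dataset.dataloader  # batches are already moved to device in collate_fn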
There is a problem, though. Since I built the model haphazardly without deciding in advance what it should be, the DecoderOnlyTransformer ended up doing self-attention as if it were an encoder-only model. Because of that, generating autoregressively means running self-attention over the full q, k, v at every step.
When I later looked at the llama2 model, it keeps a kv_cache inside the model that stores the previously computed K and V projections, so only the query of the single incoming token attends to them and each step costs much less. When I rebuild the model, I definitely want to add this. I also suspected that relying only on full self-attention like this changes both the procedure and the results, and that it might be part of why the outputs came out poorly.
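
As a rough sketch of that caching idea (this is not the llama2 code, just a minimal single-head version of the pattern):

import torch

class CachedSelfAttention(torch.nn.Module):
    """Minimal single-head sketch of KV caching for autoregressive decoding."""
    def __init__(self, d_model):
        super().__init__()
        self.wq = torch.nn.Linear(d_model, d_model, bias=False)
        self.wk = torch.nn.Linear(d_model, d_model, bias=False)
        self.wv = torch.nn.Linear(d_model, d_model, bias=False)
        self.cache_k = None  # [bsz, seen_len, d_model], grows as tokens are generated
        self.cache_v = None

    def forward(self, x):
        # x: [bsz, new_len, d_model]; during generation new_len is usually 1
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        if self.cache_k is not None:
            k = torch.cat([self.cache_k, k], dim=1)  # reuse K/V of earlier tokens
            v = torch.cat([self.cache_v, v], dim=1)
        self.cache_k, self.cache_v = k, v
        # note: when new_len > 1 (the prompt pass) a causal mask is still needed
        scores = q @ k.transpose(-2, -1) / (x.size(-1) ** 0.5)  # [bsz, new_len, seen_len]
        return scores.softmax(dim=-1) @ v  # [bsz, new_len, d_model]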


When I tried generating with text like this, using the generate function below,

def generate(model,
             dataset: NamuwikiDataset,
             input_text,
             max_gen_length=100):
    """
    Generate text using the model and tokenizer
    :param model:
    :param dataset:
    :param input_text: list[str]
    :param max_gen_length:
    :return:
    """
    prompt_len_limit = dataset.max_seq_len // 2 + 1  # +1 for the bos token; prompts may use at most half the context

    # tokenizer encodes batch of text to ids with padding of max length of the batch
    # input_text should not be beyond max_seq_len
    input_ids = [[dataset.bos_id] + dataset.tokenizer.encode(text, add_special_tokens=False) for text in input_text]
    max_prompt_len = max(len(ids) for ids in input_ids)
    min_prompt_len = min(len(ids) for ids in input_ids)
    full_len = min(dataset.max_seq_len, max_prompt_len + max_gen_length)
    assert max_prompt_len <= prompt_len_limit, f"input text length should be less than {prompt_len_limit}"

    # add padding to max_seq_len to make input_ids have the same length
    token = torch.full((len(input_ids), dataset.max_seq_len), dataset.pad_id, dtype=torch.long, device=device)
    for i, ids in enumerate(input_ids):
        token[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)

    # generate text
    is_eos = torch.zeros(len(input_ids), dtype=torch.bool)
    model.eval()
    with torch.no_grad():
        for i in range(min_prompt_len, full_len):
            out, _ = model(token[:, :i])
            next_token = out[:, -1].argmax(dim=-1) # [bsz] with the highest probability of vocab
            # if token[i] has pad token, replace it with next_token else leave it as it is
            w = token[:, i]
            token[:, i] = torch.where(w == dataset.pad_id, next_token, w)

            # if all tokens are eos token, break
            is_eos = is_eos | (token[:, i] == dataset.eos_id)
            if is_eos.all():
                break

    # decode token to text
    generated_text = [dataset.decode(ids) for ids in token]
    return generated_text
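
For reference, a call looks roughly like this; the prompts are made up for illustration and are not the ones used in the original experiment.

prompts = ["나무위키는", "대한민국의 수도는"]  # hypothetical prompts
outputs = generate(model, dataset, prompts, max_gen_length=100)
for text in outputs:
    print(text)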


The output was full of sentences that made no sense, as shown above.
There are several likely causes.
First, there is the decoder design problem mentioned above.
The namuwiki dataset alone may also be too small for training both the tokenizer and the model, and since it went through no preprocessing, it probably contains a lot of useless data.
Looking at the results, none of the three outputs end with an eos token.
When loading data, long documents were simply cut off without an eos token,
and the namuwiki dataset has even more long documents than I expected, so sequences that actually contain an eos may have been rare.
Or there may be a typo or flaw somewhere in the model that I have not found yet.
I should have picked a reference model and followed it, but since I just charged ahead, the model itself may not have been built correctly.
Next time I build a model, I want to fix the flaws above and properly train a language model on a more curated dataset.
