SentencePiece 알고리즘 센텐스피스 subword tokenization

안 형준·2021년 10월 11일
2

NLP

목록 보기
1/1
post-thumbnail

Sentencepiece tokenizer는 언어에 무관하고, 띄어쓰기 유무에 영향을 받지 않으며, 매우 빠르고, 기존의 vocab_size를 벗어난 경우 발생하던 [UNK] 토큰을 확연히 줄여줍니다. 어절 안쪽을 쪼개서("안녕하세요" -> "안녕, 하세요") tokenize하기 때문에 더 발전된 언어 모델을 만들 수 있습니다. 그 원리는 무엇일까요?

KAIST AAILab youtube 기계학습 강좌 1~9강, SentencePiece Demystified을 참고했습니다

Unigram Language Model

글은 순서를 갖는 sequence 이기 때문에 어떠한 문장이든 앞에 나온 단어에 기반하여 뒤에 나올 단어를 유추할 수 있습니다.

  • ex) 오늘 마라탕 먹어야 하는데 같이 갈 사람 구함

위 문장에서 맥락을 고려한다면 "갈" 뒤에 "사람" 이 나올 확률은 "앵무새" 보다 높을 것입니다. 이전 몇 개의 토큰을 바탕으로 예측하는지에 따라 N-gram 에서 N이 바뀝니다. 이를 N-gram model이라고 합니다. 예를 들어, N=2인 경우 bigram model이 되며, 오늘 마라탕, 마라탕 먹어야 등에서 볼 수 있듯 직전 1개의 토큰을 고려한 확률분포를 사용합니다.

그러나 unigram model은 맥락을 전혀 신경쓰지 않습니다.어떠한 문장이 등장할 확률은 그저 전체 말뭉치(Corpus)에서 각 토큰이 등장할 확률을 곱한 것에 불과하고, 이는 토큰의 순서를 고려하지 않으므로 어절별로 토큰화한다고 했을 때

  • ex) 오늘 마라탕 먹어야 하는데 같이 갈 사람 구함

  • ex) 오늘 사람 갈 먹어야 하는데 마라탕 같이 구함

위의 두 문장은 같은 등장 확률을 보입니다.

img

구현한 SentencePiece 역시 Unigram Model을 사용하고, 정식 package는 Byte Pair Encoding 혹은 Unigram(특정한 tokenizer model)을 지원합니다.

Byte Pair Encoding

BPE는 말뭉치에서 자주 등장하는 연이은 토큰이 있다면, 그것을 하나의 토큰으로 합쳐버리는 과정입니다. 아래의 그림에서 AB가 X로 , XC가 Y로 바뀌는 과정이 핵심입니다. 이 때, 합쳐진 토큰의 등장 빈도도 함께 저장합니다. BPE는 SentencePiece 모델에서 초기 unigram 빈도를 얻는데 활용됩니다.

import re
import collections


class BytePairEncoder:
    def __init__(self):
        self.merges = None
        self.characters = None
        self.tokens = None
        self.vocab = None

    def format_word(self, text, space_token="_"):
        return " ".join(list(text)) + " " + space_token

    def initialize_vocab(self, text):
        text = re.sub("\s+", " ", text)
        all_words = text.split()
        vocab = {}
        for word in all_words:
            word = self.format_word(word)
            vocab[word] = vocab.get(word, 0) + 1
        tokens = collections.Counter(text)
        return vocab, tokens

    def get_bigram_counts(self, vocab):
        pairs = {}
        for word, count in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                pairs[pair] = pairs.get(pair, 0) + count
        return pairs

    def merge_vocab(self, pair, vocab_in):
        vocab_out = {}
        bigram = re.escape(" ".join(pair))
        p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        bytepair = "".join(pair)
        for word in vocab_in:
            w_out = p.sub(bytepair, word)
            vocab_out[w_out] = vocab_in[word]
        return vocab_out, (bigram, bytepair)

    def find_merges(self, vocab, tokens, num_merges):
        merges = []
        for i in range(num_merges):
            pairs = self.get_bigram_counts(vocab)
            best_pair = max(pairs, key=pairs.get)
            best_count = pairs[best_pair]
            vocab, (bigram, bytepair) = self.merge_vocab(best_pair, vocab)
            merges.append((r"(?<!\S)" + bigram + r"(?!\S)", bytepair))
            tokens[bytepair] = best_count
        return vocab, tokens, merges

    def fit(self, text, num_merges):
        vocab, tokens = self.initialize_vocab(text)
        self.characters = set(tokens.keys())
        self.vocab, self.tokens, self.merges = self.find_merges(
            vocab, tokens, num_merges
        )


if __name__ == "__main__":
    bpe = BytePairEncoder()
    for i in range(0, 3):
        bpe.fit("house home hostile", i)
        print("--" * 12)
        print(f"BPE vocabs - {i} merges")
        print(bpe.vocab, "\n")
        print(f"BPE tokens - {i} merges")
        print(bpe.tokens)


"house home hostile"로 BPE를 해본 결과입니다. 초기에는 'h o u s e _' 처럼 모든 글자가 분해되어 있고, 토큰 역시 한 글자씩 이루어진 것을 볼 수 있습니다. 1번 merge하면, 'ho u s e _', 'ho m e _'로 'ho'가 합쳐지고, 'ho'와 그 등장빈도 3번이 토큰에 기록되었습니다. 2번째 merge에서는 'e_'가 추가되었습니다.

Training SentencePiece

SentencePiece를 훈련하는 과정은 Variational inference의 일종으로, 관측한 데이터(Evidence)와 모델 파라미터(theta)가 있을 때, 가설(Hypothesis)에 대한 분포 PP를 variational parameter 를 도입해 QQ로 근사합니다. 근사의 목적은 PP 자체를 찾아내는 것은 너무 복잡하기 때문입니다. 이때 HH가 여러 개일 경우 각각은 서로 독립이고, 해당하는 숨겨진 변수에만 의존한다는 Mean Field theory를 사용합니다.

P(HE,θ)    Q(HE,λ)=iHqi(Hiλi)P(H| E, \theta) \rightarrow\;\; Q(H|E, \lambda) = \prod_i^{|H|} q_i(H_i|\lambda_i)

근사의 목적은 Evidence Lower Bound L(λ,θ)L(\lambda, \theta) 를 극대화하는 것으로, 이 값이 커질수록 근사 QQ와 실제의 사후분포 PP의 차이가 작아집니다.

L(λ,θ)=H[Q(HE,λ)ln  P(H,Eθ)Q(HE,λ)ln  Q(HE,λ)]L(\lambda, \theta) = \sum_H [Q(H|E, \lambda)\ln\; P(H, E|\theta) - Q(H|E, \lambda)\ln\; Q(H|E, \lambda)]

동시에 최적화(L(λ,θ)L(\lambda, \theta)의 극대화) 하는 대신, 각 λi\lambda_i마다 차례로 최적화한다면,
jj 번째 variational parameter λj\lambda_j에 대해 극대화해야 하는 식은 다음과 같습니다.

L(λj)=HiHqi(HiE,λi){ln  P(HE,θ)kHln  qk(HkE,λk)}L(\lambda_j) = \sum_{H} \prod_i^{|H|}q_i(H_i|E, \lambda_i)\{\ln \;P(H|E,\theta) -\sum_k^{|H|}\ln \;q_k(H_k|E,\lambda_k) \}

정리하면 아래의 식이 됩니다

HjKL(qj(HjE,λj)P~(H,Eθ))+C  ,(where  ln  P~(H,Eθ)=Eqij[ln  P(H,Eθ)]+C)\sum_{H_j} -KL(q_j(H_j|E,\lambda_j)||\tilde{P}(H, E|\theta)) + C'\;, \\(where\; \ln \; \tilde{P}(H, E|\theta) = E_{q_{i\neq j}}[\ln \; P(H,E|\theta)]+ C)

KL Divergence는 0 이상이므로, L(λj)L(\lambda_j)를 최대화하려면 KLD를 0으로 만들면 되고, 이는 두 분포를 같게 만듦으로서 실현됩니다. 따라서 아래와 같은 결론을 얻습니다

qj(HjE,λj)=P~(H,Eθ)ln  qj(HjE,λj)=Eqij[ln  P(H,Eθ)]+Const.q_j^*(H_j|E,\lambda_j) = \tilde{P}(H, E|\theta) \\ \ln \;q_j^*(H_j|E,\lambda_j) = E_{q_{i\neq j}}[\ln \; P(H, E|\theta)]+ Const.

이 계산을 하다보면 HiH_i 의 분포를 계산할 때, HjH_j의 기댓값, 분산이 필요한 경우가 있습니다. 특히 서로의 기댓값 등을 필요로 하는 경우, HiH_i에 대해 최적화하고, 이를 바탕으로 HjH_j에 대해 최적화하고, 다시 HiH_i 에 대해 최적화하여 분포가 수렴할 때까지 iterative하게 반복하는 coordinated optimization을 사용하는데, SentencePiece도 그런 경우입니다.

SentencePiece의 evidence log-likelihood

img

unigram probability를 사용하기 위해 숨겨진 변수 π\pi 를 도입합니다. 또한 베이즈 정리를 사용하기 위해, 사전분포로 Dirichlet Distribution을 이용합니다. Dirichlet Distribution은 model parameter α\alpha에 의존하는 확률분포로 베이즈 정리에 의해 다음이 성립할 때,

PosteriorLikelihoodPriorPosterior \propto Likelihood * Prior

다항분포를 가능도로, Dirichlet Distribution을 사전분포로 하는 경우 사후분포 역시 Dirichlet Distribution이 되어 conjugacy를 가지므로, 다항 분포의 모델링에 장점이 있어 Dirichlet prior을 사용합니다.

p(πα)=Dir(π,aK)=k=1KπkαK1p(xπ)=n=1Nk=1Kπkxnkp(\pi|\alpha) = Dir(\pi, a_K) = \prod_{k=1}^K \pi_k^{\alpha_K -1} \\ p(x|\pi) = \prod_{n=1}^N\prod_{k=1}^K \pi_k^{x_{nk}}

xnkx_{nk}는 sequence에서 nn번째 토큰이 kk번째 unigram인 경우에는 1이고, 아닌 경우에는 0이 되어, p(xπ)p(x|\pi)는 sequence을 특정한 segmentation으로 tokenize 했을 때, 등장하는 각 토큰의 unigram probability을 곱한 것과 같습니다. (unigram language model)

중간정리) π\pi는 unigram probability를 표현하기 위해 도입한 숨겨진 변수(variational parameter)이며, zz는 BPE 등으로 얻어진 가능한 토큰을 표현하기 위해 도입한 숨겨진 변수(variational parameter)입니다.(ex- h ell o 로 쪼개졌다면 h, ell, o 모두 zjkz_{jk}로 표현할 수 있습니다 )


사후분포 QQ를 추정하기 위해 mean field 근사로 두 변수(π,z)(\pi, z)이 독립이라고 하면, 위 섹션의 결론인 (ln  qj(HjE,λj)=Eqij[ln  P(H,Eθ)]+Const.\ln \;q_j^*(H_j|E,\lambda_j) = E_{q_{i\neq j}}[\ln \; P(H, E|\theta)]+ Const.)에 의해 다음이 성립합니다.

img

img

실제 분포를 넣어서 계산하면 다음으로 정리할 수 있습니다.

ln  p(xπ,z)p(πα)=ln  [(n=1Nk=1Kπkznk)(k=1KπkαK1)]=n=1Nk=1K(znk+αK1)ln  πk\ln \;p(x|\pi,z)p(\pi|\alpha) = \ln \;[(\prod_{n=1}^N\prod_{k=1}^K \pi_k^{z_{nk}})(\prod_{k=1}^K \pi_k^{\alpha_K -1})] \\ =\sum_{n=1}^N \sum_{k=1}^K (z_{nk} + \alpha_K -1) \ln \; \pi_k

π\pizz에 대해 기댓값을 구해보면,

img

img

이 되므로 π\pizz의 기댓값이 엮여 있어 coordinated optimization을 시행해야 합니다.


추가적인 정리를 하면 qπq_\pi는 Dirichlet Distribution의 형태인 것을 알 수 있고, znkz_{nk} 는 등장 유무를 0,10,1 로 표현하는 변수이므로, 기댓값은 등장 횟수에 의존합니다. kk 번째 토큰에 대한 zz의 기댓값은 새로운 변수CkC_k를 도입하여 다음과 같이 표현할 수 있습니다.

img

Dirichlet Distribution의 기댓값은 Digamma function 을 이용해 표현할 수 있으므로, (wikipedia)

img

ln  qz\ln \;q_zln  qπ\ln \;q_\pi에서 기댓값E[znk],E[ln  πk]E[z_{nk}], E[ln\; \pi_k]을 위에서 얻은 수식을 통해 정리할 수 있습니다.

ln  qπk=1K(αK+Ck1)ln  πkln  qzn=1Nk=1KznkEqπ[ln  π]=n=1Nk=1Kznk(ψ(αk)ψ(kKαk))\ln \;q_\pi \propto \sum_{k=1}^K (\alpha_K + C_k -1) \ln \; \pi_k \\ \ln \;q_z \propto \sum_{n=1}^N \sum_{k=1}^K z_{nk} E_{q_\pi}[\ln \; \pi] = \sum_{n=1}^N \sum_{k=1}^K z_{nk}(\psi(\alpha_k) - \psi(\sum_{k'}^K\alpha_k'))

α\alphaα+Ck\alpha + C_k로 업데이트하여 계산합니다. 이때 qπkq_{\pi_k}는 model parameter가 (αK+Ck)(\alpha_K + C_k)로 변화한 것을 고려하여 coordinated optimization 한다면 다음과 같습니다.

ln  qz(z)n=1Nk=1Kznk{ψ(αK+Ck)ψ(KαK+kKCk)}qz(z)n=1Nk=1K[eψ(αK+Ck)eψ(KαK+kKCk)]znk\ln \; q_z(z) \propto \sum_{n=1}^N \sum_{k=1}^K z_{nk}\{ \psi(\alpha_K + C_k) - \psi(K \alpha_K + \sum_{k'}^KC_{k'})\} \\ q_z(z) \propto \prod_{n=1}^N \prod_{k=1}^K [\frac { e^{\psi(\alpha_K + C_k)}} {e^{\psi(K \alpha_K + \sum_{k'}^KC_{k'})}}]^{z_{nk}}

마지막으로, αK\alpha_K 를 0으로 지정합니다.

π\pizz에 대해 ln  qz\ln \; q_zln  qπ\ln \; q_\pi가 수렴할때까지 교대로 최적화한다면 qzqπq_z q_\pi는 원래의 분포의 좋은 근사가 됩니다.

qz(z)n=1Nk=1K[eψ(Ck)eψ(kKCk)]znkqπ(π)k=1Kπk(Ck1)q_z(z) \propto \prod_{n=1}^N \prod_{k=1}^K [\frac { e^{\psi(C_k)}} {e^{\psi(\sum_{k'}^KC_{k'})}}]^{z_{nk}} \\ q_\pi(\pi) \propto \prod_{k=1}^K \pi_k ^{(C_k -1)}

python code에서는 가장 높은 확률을 갖는 tokenization을 구한 뒤, 토큰의 등장 빈도를 업데이트하는 과정의 반복으로 구현되었습니다.

자세한 python 코드 구현은 reference를 참고 바랍니다

Reference

SentencePiece Demystified: 코드 구현 및 설명의 일부를 이용했습니다.

이항분포, 다항분포, 베타분포, 디리클레분포

Dirichlet Distribution: Conjugate Prior for Multinomial Distribution

profile
물리학과 졸업/ 인공지능 개발자로의 한 걸음

0개의 댓글