[Bert] CODE

공부·2022년 12월 6일
0

I love Paris.

라는 문장을 Burt 모델에 집어넣어야 하는데 집어넣을 수 없다. 문장을 Burt input shape에 맞게 변형하여 만들어야 한다.

  1. 입력하고자 하는 문장을 벡터로 변환한다 [임베딩]

    임베딩하는 방법
    : 문장의 처음[cls], 단어의 represtation[R], 문장의 끝[sep] 을 표시한다.

  2. 모델을 Fine-Tuning하여 분석의 특징에 맞게 바꾸어준다.

  3. 1번의 결과를 2번에 넣어 결과를 추출한다.

burt를 통해 문장의 임베딩을 추출해낸다. Burt는 기본적으로 transformer를 써야한다. transformer로 나온 것들 중 가장 큰 오픈소스가 Hugging Face다.

from transformers import BertModel, BertTokenizer
import torch
  • transformer 안에 BertModel과 BertTokenizer가 존재한다.

bert-base-uncased 사용

  • 12개의 인코더가 있는 BERT 기반 모델이며 모두 소문자로 변환한 uncased 토큰으로 학습
  • BERT-base를 사용하고 있으므로 표현 벡터(임베딩) 크기는 768
    • 표현 벡터의 크기 = 은닉 벡터의 길이

감성분석은 대소문자가 구분되지 않아도 된다. 객체명 인식은 대소문자 구분이 필요한다.

model = BertModel.from_pretrained('bert-base-uncased')

bert-base-uncased 모델 다운로드 및 로드

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

bert-base-uncased 모델을 사전학습할 때 사용한 tokenizer 다운로드 및 로드

sentence = 'I love Paris'

내가 넣고 싶은 문장을 sentence에 대입한다.

tokens = tokenizer.tokenize(sentence)
print(tokens)

tokenizer 대로 우리 문장을 자른다.

['i', 'love', 'paris']

tokens = ['[CLS]'] + tokens + ['[SEP]']
print(tokens)

burt 모델은 두 문장씩 인식하기 때문에 입력의 처음을 알려주고, 토큰을 넣어주고, 입력의 끝을 알려주어야 한다.

['[CLS]', 'i', 'love', 'paris', '[SEP]']

tokens = tokens + ['[PAD]'] + ['[PAD]']
print(tokens)

처리하고자하는 길이가 각양각색이므로 최대 길이에 맞추어주어야 한다. 이미지의 padding처럼 padd식을 사용한다.

['[CLS]', 'i', 'love', 'paris', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

attention_mask = [1 if i!= '[PAD]' else 0 for i in tokens]
print(attention_mask)

pad라고 되어 있는 토큰은 attention score를 계산하지 않는다. 속도에 영향을 끼치기 때문에 반드시 해야 한다.

[1, 1, 1, 1, 1, 0, 0, 0, 0]

token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)

모든 토큰을 토큰 ID로 변환한다.

[101, 1045, 2293, 3000, 102, 0, 0, 0, 0]

token_ids = torch.tensor(token_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)

keras는 tensorflow와 다르게 tensor 구조를 사용하지 않는다. 1차원은 스칼라 값이고, 2차원은 벡터, 3차원 이상은 텐서라고 한다. keras는 그대로 집어넣을 수 있지만, tensorflow를 이용하려면 tensor구조로 변환시켜야 한다.

hidden_rep, cls_head = model(token_ids, attention_mask = attention_mask)
print(hidden_rep)

hidden_rep는 레이어의 가장 마지막 임베딩 값을, cls_head는 해당 값들을 모두 고려한 문장의 요약 값이 나온다. 문장 단위로 처리하고 싶으면 cls_head값을 사용한다.

  • last_hidden_state[0][0] : 첫 번째 토큰인 [CLS]의 표현 벡터
  • last_hidden_state[0][1] : 두 번째 토큰인 'i'의 표현 벡터
  • last_hidden_state[0][2] : 세 번째 토큰인 'love'의 표현 벡터
print(last_hidden_state[0][1])

확인한다.

tensor([ 2.2365e-01, 6.5364e-01, -2.2941e-01, -4.4871e-01, -9.5561e-02,
2.1067e-01, -1.3226e-01, 1.4089e+00, 1.0668e-01, -2.9041e-02,
-2.0937e-01, -5.2475e-01, 3.4771e-02, 2.7329e-01, 2.9269e-01,
2.2714e-01, 4.7734e-01, 3.4942e-01, 1.2349e-01, 8.3038e-01,
6.9123e-01, 2.3612e-01, -8.5010e-01, -2.0250e-02, 3.0894e-01,
-2.4169e-01, -4.3335e-01, 1.5679e-01, 9.1365e-02, -3.6651e-01,
-1.4478e-02, -9.2568e-02, 5.8239e-01, 7.3787e-01, -7.5602e-01,
-1.9031e-01, 3.5895e-01, -2.0138e-01, -4.4486e-01, 1.4417e-01,
8.1281e-02, -3.0345e-01, -1.2730e-01, -6.9157e-01, 2.7232e-01,
-1.2914e+00, 2.2492e-01, -7.1515e-02, 7.0234e-01, -7.8603e-01,
-8.6375e-02, 1.8487e-02, 5.6987e-02, 2.9224e-01, -1.8191e-01,
1.1739e+00, -6.1941e-01, -3.6969e-01, 4.6749e-01, 5.2700e-01,
-2.3669e-02, -1.0395e-01, 5.6715e-01, -6.3963e-01, -4.2078e-01,
9.4930e-01, -4.8859e-01, 1.1718e-01, 1.6313e-01, -4.4423e-01,
7.0174e-01, 1.6644e-01, 1.5139e-02, 9.2753e-02, -2.4862e-01,
1.6582e-01, -5.7120e-01, 5.0441e-01, 1.5555e-02, -4.7822e-01,
-1.2932e-01, 9.3940e-01, -7.1131e-01, 6.2031e-01, -3.7397e-01,
1.0241e-01, -2.1538e-01, 1.5746e-01, -2.7759e-01, -1.0195e+00,
-4.5111e-01, -6.5975e-01, -4.5680e-01, 8.2152e-01, 2.9419e-01,
-1.4503e+00, -6.7572e-01, -4.5254e-01, 1.1727e-02, -4.1652e-01,
1.6480e-02, -5.1592e-01, 1.7268e-01, 7.2106e-01, -9.2694e-01,
-4.1237e-02, -2.0453e-01, 4.3615e-02, 1.6651e+00, -8.0247e-01,
5.2143e-01, -5.8073e-01, 2.2373e-01, -6.6165e-01, 2.8686e-01,
8.2596e-01, 8.5921e-01, -3.6443e-01, -7.9419e-01, -1.3362e-01,
-2.6584e-01, 3.9913e-01, -3.2963e-01, 1.7722e-01, -2.0135e-01,
-7.0184e-01, 2.4723e-01, 3.8644e-01, -2.4687e-01, -7.2885e-01,
-5.4910e-01, 4.7497e-01, -3.4729e-01, -9.8077e-01, 4.2597e-01,
3.8058e-01, 2.9377e-01, 4.7424e-01, 4.7811e-01, -5.7267e-01,
2.7706e-01, 4.4554e-01, 7.0947e-03, -1.0947e-01, 1.0781e+00,
-5.2112e-02, -5.9456e-01, 3.2369e-01, 7.6836e-01, 1.9795e-01,
5.7750e-02, -1.3433e-01, 2.5922e-02, 4.5950e-01, -5.7990e-01,
-4.5908e-01, -3.5885e-01, 6.8502e-01, 1.5277e-02, 3.0647e-01,
3.1753e-01, -5.0508e-01, 5.9512e-02, 4.9015e-01, -3.8842e-01,
5.6514e-01, 3.5490e-01, 5.8253e-01, 1.1216e-01, -4.0576e-01,
-3.4624e-01, -9.2309e-01, 4.8706e-01, 1.0393e+00, 8.4353e-01,
-3.3789e-01, 5.1315e-01, 4.7512e-01, 3.0839e-02, 2.5342e-01,
-4.8350e-01, 2.0173e-01, 8.0378e-02, 2.2809e-02, -4.6203e-01,
-2.2545e-01, 6.0600e-01, -5.9437e-01, 7.3096e-02, 6.1802e-01,
4.5159e-01, -1.5890e-01, 1.2062e-01, 7.6384e-01, 4.3349e-01,
-4.7153e-01, -3.4319e-02, -3.1152e-01, 4.4276e-01, 7.7329e-01,
1.7823e-01, -4.2157e-01, 6.4350e-01, 2.5408e-01, 3.3767e-01,
6.6232e-01, 5.3018e-01, -1.5990e-01, -1.8430e-03, 3.4505e-01,
-4.8595e-01, 3.1062e-01, -2.3385e-01, -4.9405e-01, 2.4844e-01,
-1.7157e-01, -1.1732e-01, -3.2414e-02, -3.2593e-02, 2.5656e-01,
4.0298e-01, -1.6135e-01, -1.5998e+00, 6.2897e-01, -2.5658e-01,
3.6113e-01, -3.5929e-01, -6.4096e-02, -6.0351e-01, 9.0629e-02,
-5.7726e-01, -5.7916e-01, 7.7140e-02, 7.0365e-01, 6.8171e-01,
4.6132e-01, -7.9373e-01, -6.5585e-01, -1.7960e-01, -1.0100e-01,
2.9783e-01, 3.7116e-01, 8.0531e-01, 1.3535e+00, 7.5272e-02,
8.8138e-02, 5.1565e-01, -7.7518e-01, -2.6482e-02, 3.4903e-01,
-2.5056e-01, -1.0148e+00, -8.3682e-01, -5.3850e-01, -6.8621e-01,
4.9744e-01, -3.3191e-01, 8.5343e-01, -3.3801e-01, 1.6430e-01,
2.2564e-01, -6.5955e-01, 1.5393e+00, -3.3408e-01, -3.1948e-01,
-2.7393e-01, -5.4071e-01, -2.4693e-01, 2.0252e-01, 3.4830e-01,
-3.6906e-01, 7.6637e-02, -1.0271e-02, 6.3987e-01, 9.1297e-02,
-5.2880e-01, 5.6669e-01, -7.9297e-02, 1.0211e-01, -4.6177e-01,
1.5421e-01, 1.0091e+00, 3.8423e-02, -4.6183e-03, -6.4856e-02,
5.6802e-02, -2.8851e-01, -3.6307e-01, 1.6994e-01, 3.7190e-02,
-1.4736e-01, 3.9060e-01, -1.3411e-01, -1.4662e-01, -1.1278e-01,
1.5644e-01, 1.0063e+00, 2.3521e-01, 4.5906e-01, -3.3830e-01,
-8.3492e-01, -3.8113e-01, -7.0029e-01, 8.7593e-01, 2.2131e-01,
-1.7978e-01, 2.1739e-03, 2.2068e-02, -3.9097e+00, -1.5422e-01,
3.5245e-02, -4.4309e-01, 7.3892e-01, -3.1915e-01, -2.4455e-01,
-6.8192e-02, -7.0023e-01, -4.9143e-01, 4.7838e-01, -3.4434e-01,
3.3744e-01, -7.1846e-02, 5.1823e-01, -7.7763e-01, 9.1886e-01,
-1.3507e+00, -3.0661e-01, 1.4994e+00, -2.6080e-01, 4.8350e-01,
1.3054e-02, -1.6011e+00, -2.6464e-01, 6.9146e-01, -1.8287e-01,
4.3331e-01, -9.2075e-01, 2.8617e-01, -6.5120e-01, 4.0433e-02,
7.6375e-01, -3.4808e-01, -1.2961e-01, 2.3617e-01, -4.7257e-01,
3.2213e-01, 4.9476e-01, -3.9025e-01, -3.9282e-02, -8.8296e-01,
-2.3676e-01, 2.0576e-01, 7.8493e-01, 2.9166e-01, -8.9267e-01,
-6.9995e-01, -1.1621e-01, 3.1348e-01, -3.0402e-02, -3.5697e-01,
7.6259e-01, -3.8881e-01, 1.4062e-01, 2.3462e-02, 3.3668e-01,
6.4391e-01, 4.1262e-01, -4.5005e-01, 9.8163e-01, 2.8540e-01,
-1.2271e+00, 5.9497e-01, 6.8793e-01, -2.0614e-01, -4.0656e-01,
-5.0118e-01, -2.5067e-01, 8.9082e-01, 2.3043e-01, 3.5030e-02,
-8.9694e-02, -9.6329e-01, -6.1572e-01, 2.4865e-01, -2.1465e-01,
-2.9813e-01, 2.6593e-01, -4.3126e-01, 1.3545e-01, -5.7002e-02,
-2.2326e-01, -7.0460e-01, -2.0718e-01, -9.5383e-01, 1.0063e-01,
1.5615e-01, 1.4894e-01, 2.5866e-01, 6.2881e-01, 5.2115e-02,
6.5751e-01, -1.2936e-01, 3.1473e-01, -8.7717e-02, 7.3883e-01,
-1.2833e+00, 8.0107e-01, -1.1498e-01, 7.0118e-01, -5.2350e-01,
9.8697e-02, -4.5489e-01, 4.4888e-01, -6.7135e-02, -6.9953e-01,
4.1966e-01, -1.4551e-01, -2.1602e-01, 1.3741e-01, -1.8891e-01,
-4.0281e-02, -4.7467e-01, -6.0825e-01, -9.2448e-01, 1.6341e-01,
7.3957e-01, -4.4879e-01, -5.4063e-01, -9.8654e-02, 9.1529e-01,
3.4073e-02, -9.5282e-01, -7.0746e-01, 4.2436e-01, 5.1130e-02,
-5.6547e-01, -6.2789e-01, -2.0289e-01, -4.0827e-01, 4.8593e-01,
-1.3983e-01, -4.5109e-01, -2.2471e-02, -5.5091e-02, -1.1020e+00,
-4.2352e-01, 2.5242e-01, -1.6897e-02, -1.7915e-01, -2.2832e-01,
-5.2508e-01, 9.4824e-01, 1.4480e-01, 3.5931e-01, -6.3862e-01,
-1.0716e+00, -9.4374e-02, 7.5291e-01, -9.4283e-02, 2.0895e-01,
4.2567e-01, -1.8934e-01, 6.6017e-01, -6.6652e-01, 4.4894e-01,
-1.3327e+00, 5.2749e-01, -4.4665e-01, -1.0009e+00, -9.4026e-02,
-8.0740e-02, 8.0899e-01, 9.7773e-01, 7.8945e-01, 7.9632e-01,
1.5203e-01, -4.2365e-01, -6.3749e-01, -6.2841e-01, -8.2197e-02,
2.9019e-01, -9.6255e-01, -2.2180e-01, 2.7472e-01, -5.2187e-01,
-2.9262e-01, -3.0104e-01, 7.0484e-01, -2.7861e-01, -4.5972e-02,
-2.6090e-01, 1.1933e-01, 2.5849e-01, 1.1659e+00, -7.4335e-01,
-6.9255e-01, 3.7246e-01, 4.7926e-01, -3.4053e-01, 4.6152e-01,
-7.7556e-03, -9.4658e-01, -1.3587e+00, -7.3678e-01, 4.8375e-01,
-6.2155e-01, 6.8931e-01, 1.0368e+00, 3.9489e-01, 4.7054e-01,
-5.4903e-01, 2.8023e-01, -1.4400e-01, -2.7306e-01, -5.3359e-01,
-4.0098e-01, -7.8677e-02, 1.1498e+00, -3.1929e-01, -1.1925e-01,
-5.4443e-01, -2.8112e-02, 1.2542e-01, -4.3709e-01, 4.2124e-01,
-5.7101e-02, -5.2534e-01, -6.8623e-01, 1.2638e-01, 3.1548e-02,
3.3472e-01, -1.5904e-01, -5.0015e-01, 3.2961e-01, -2.2866e-01,
-5.3680e-01, -1.0118e+00, -3.0877e-01, 7.4049e-01, -3.0200e-01,
5.3244e-01, 1.7533e-01, 9.0922e-01, 4.7946e-02, 3.4100e-01,
-6.4228e-01, 6.2724e-01, 2.8463e-01, 4.9425e-01, -1.3037e+00,
5.2868e-01, -1.9798e-01, 5.1998e-01, -6.6654e-01, -1.4844e-01,
-1.1480e-01, -1.4905e-01, 1.5045e-01, 1.0798e-02, 2.1238e-01,
1.7019e-01, -3.2850e-01, 4.2309e-02, -3.3738e-01, -9.3419e-01,
5.9148e-02, -2.8340e-02, 5.7189e-01, -1.2283e-01, -6.2760e-02,
-9.9991e-02, 4.2451e-02, 3.7635e-01, 6.4487e-01, 8.7779e-01,
8.0027e-01, -3.7165e-01, 8.3460e-01, -1.7753e-01, -5.8062e-01,
1.0067e-01, -3.4937e-01, 4.4698e-01, -4.0376e-02, 3.8843e-01,
-4.0283e-01, 3.9475e-01, -3.7148e-01, -1.0815e-01, 7.0277e-01,
4.5009e-03, -6.2236e-01, 6.3353e-01, -4.2915e-01, -4.1764e-01,
5.5494e-01, 8.2269e-02, 5.4140e-01, 1.1784e-01, 1.5233e-01,
7.4020e-02, 8.3469e-01, -3.5486e-02, 8.3431e-02, -1.9437e-01,
3.6847e-01, -1.0350e+00, 2.3906e-01, -5.5285e-01, 6.6109e-01,
8.6004e-01, 3.6660e-02, -8.7034e-02, -4.1473e-01, -1.1148e-01,
-6.6199e-01, 1.5014e-01, -5.5832e-01, -2.8433e-01, -1.8642e-01,
5.8694e-01, 1.1001e+00, 2.1808e-01, -5.7002e-01, 6.1907e-01,
-9.2778e-01, -5.1305e-01, -1.9022e-01, -8.5656e-01, -3.1219e-01,
8.1579e-01, 2.7145e-01, 1.6772e-01, 1.6957e-01, -7.8931e-01,
3.8314e-01, 4.4218e-01, 6.7751e-01, 7.7700e-02, 1.5916e-01,
-5.2882e-03, 1.0134e+00, -1.9337e-01, 8.5150e-01, -1.8481e-02,
-1.3644e+00, 3.9875e-02, 3.4745e-01, 6.5113e-02, 2.0463e-01,
4.6259e-01, 2.2908e-01, -5.5292e-01, -2.5372e-01, -9.7568e-02,
1.7253e-01, 1.1800e+00, -3.6090e-01, -5.1904e-02, 3.3350e-01,
-6.8763e-01, -1.3519e+00, 2.0438e-01, 5.2987e-01, 9.8921e-01,
4.4452e-01, 3.0961e-01, -6.5212e-01, -7.6537e-01, 5.3358e-02,
7.2843e-01, -6.6854e-01, 5.5170e-02, -3.3123e-02, 3.8828e-01,
-5.5752e-01, -3.0912e-01, -3.4038e-01, 2.6431e-01, 1.2934e-01,
4.3893e-01, 1.0860e+00, 7.8357e-02, 6.1868e-01, -2.5454e-01,
-9.0809e-01, 2.1211e-01, 4.3651e-01, -2.5165e-01, 4.8652e-01,
-4.7844e-01, -4.6564e-01, -2.7391e-01, -1.6360e+00, -1.0416e-01,
-1.1840e+00, 7.0985e-01, 5.8680e-02, 7.1903e-01, -3.1943e-01,
8.0358e-01, -7.6008e-01, 3.9513e-01, -5.8160e-01, 2.7178e-01,
7.6680e-02, -1.2132e-01, 2.9638e-01, 6.5445e-01, 5.0970e-01,
-5.9009e-01, -4.8557e-02, -3.2783e-01, 3.1994e-01, -8.8879e-01,
7.2325e-01, -7.4886e-02, -6.5689e-01, 3.8978e-01, 9.3510e-01,
-8.2564e-01, 4.3794e-01, 8.4586e-02, 4.3033e-01, 5.3826e-01,
-3.8785e-01, -5.8888e-01, -6.4476e-01, -7.0063e-02, -5.7163e-01,
5.7367e-01, 1.5694e-01, -1.3442e-01, -2.5962e-01, -4.4676e-01,
9.1701e-01, 3.4829e-01, -5.4059e-01, -6.0303e-01, 4.9440e-01,
-1.6879e-02, -3.2807e-01, -1.6992e-01, -6.0576e-01, 7.3596e-01,
-2.3484e-01, 2.0991e-01, -1.1968e+00, 8.6241e-01, 4.5729e-01,
4.7098e-01, -5.3314e-01, -1.0370e+00, 1.6372e-01, -8.5133e-02,
-3.4767e-01, 8.1252e-01, -8.2418e-01, 5.7300e-01, -5.1195e-01,
7.3901e-01, -4.2194e-01, 5.2264e-01, 2.6561e-01, 2.2641e-01,
-3.5472e-01, 5.5171e-01, -2.3674e-01], grad_fn=)

입력하고자 하는 문장을 벡터로 변환하였다.

burt모델은 중간중간 masking한다.

감성분석

import tensorflow_datasets as tfds
import tensorflow as tf
(ds_train, ds_test), ds_info = tfds.load('imdb_reviews',
          split = (tfds.Split.TRAIN, tfds.Split.TEST),
          as_supervised=True,
          with_info=True)

문자가 모두 숫자로 표현되어 있어 문자로 변환하기 위해 decode과정이 필요하다.

def convert_example_to_feature(review):
  return tokenizer.encode_plus(review,
                add_special_tokens = True, # add [CLS], [SEP]
                max_length = max_length, # max length of the text that can go to BERT
                pad_to_max_length = True, # add [PAD] tokens
                return_attention_mask = True, # add attention mask to not focus on pad tokens
              )

리뷰에다 cls + maxlength만큼 padding attention mask 하라고 encoding_plus에 들어가 있음. 1번을 한 큐에 처리에주는 함수. 전처리를 한번 하고

profile
리액트

0개의 댓글