트랜스포머 바닥부터 구현하기

WooSeongkyun·2023년 3월 29일
0

Reference

"""
nn.Module: Neural network의 모든것을 포괄하는 신경망 모델의
base class다. 다른 말로, 모든 신경망 모델은 nn.Module의 subclass이다
nn.Module을 상속한 subclass가 신경망 모델이 되기 위해선 두 메소드를 오버드라이브해야한다
    1.__init__(self): 내가 사용하고 싶은 신경말의 구성품을 정의
    2.forward(self,x):init에서 정의된 구성품을 연결한다

nn.Lienar(in_features,out_features,bias=True)
    - in_features : input sample의 크기를 지정한다
    - out_features : output sample의 크기를 지정한다
    - bias: additive bias를 학습할 것인지 정한다

nn.ModuleList():모듈들을 리스트 형태로 저장하는 것


"""

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """embedding vector 차원 크기와 head수를 지정한다
    """
    def __init__(self,embed_size,heads):
        super(SelfAttention,self).__init__()
        self.embed_size= embed_size
        self.heads= heads
        """각 쿼리,키,값이 갖는 차원의 크기는 embed//heads"""
        self.head_dim= embed_size // heads
        assert (self.head_dim * heads == embed_size)
        self.values = nn.Linear(self.head_dim,
                                self.head_dim,
                                bias=False)
        self.keys = nn.Linear(self.head_dim,
                              self.head_dim,
                              bias=False)
        self.queries= nn.Linear(self.head_dim,
                                self.head_dim,
                                bias=False)
        """각 head가 학습한 쿼리,키,값 행렬들을 하나로 이어붙인다
        이어붙인 행렬은 embed_dim을 갖는다"""
        self.fc_out = nn.Linear(heads*self.head_dim,
                                embed_size)

    def forward(self,values,keys,queries,mask):
        """N: batchsize"""
        N=queries.shape[0]
        """source sentence length에 해당된다"""
        value_len, key_len, query_len = values.shape[1],keys.shape[1],queries.shape[1]
        """embedding 벡터를 self.head piece로 쪼개준다
        self.heads, self.head_dim으로 """
        values= values.reshape(N,value_len,self.heads,
                               self.head_dim)
        keys= keys.reshape(N,key_len,self.heads,
                           self.head_dim)
        queries= queries.reshape(N,query_len,self.heads,
                               self.head_dim)
        """
        쿼리 형태:    (N,query_len,heads,heads_dim)
        키 형태:      (N,key_len,heads,heads_dim)
        에너지 형태: (N,heads,query_len,key_len)
        에너지란?: 디코더가  인코더(소스)의 모든 hidden state를 참고하기 위하여, 각 인코더 셀의 중요도 가중치를 결정하기 위하여, 소프트맥스를 거쳐 출력한 값
        이전 상태의 디코더 상태와 각 인코더 셀의 히든스테이츠를 내적하여 계산한다.
        transformer에선 query와 key가 인코더와 디코더층에 대응된다.
        여기서 에너지는 각 배치마다, 임베딩 벡터를 내적하여 결정한다.
        """
        values= self.values(values)
        keys = self.keys(keys)
        queries =self.queries(queries)
        energy= torch.einsum('nqhd,nkhd->nhqk',[queries,keys])

        if mask is not None:
            """
            마스크 작업
            masked_fill(a,b): 텐서중 a라는 값을 갖는
            원소를 모두 b로 변환시킨다"""
            energy = energy.masked_fill(mask==0,float('-1e20'))

        """Attention(Q,K,V) 연산 시행하기"""
        attention= torch.softmax(energy/(self.embed_size **(1/2)),dim=3)
        """
        attention shape: (N,heads,query_len,key_len)
        value shape: (N,value_len,heads,heads_dim)
        (N,query_len,heads,head_dim)
        여기선 key_len과 value_len이 같은것을 활용
        >(N,query_len,heads *head_dim)으로 기존 마지막 2차원을 flatten하게 변환
        """
        out= torch.einsum('nhql,nlhd->nqhd',[attention,values]).reshape(N,query_len,self.heads *self.head_dim)
        out= self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self,embed_size, heads, dropout,
                 forward_expansion):
        """Module의 인자를 TransformerBlock이 받게끔 하는것
        nn.LayerNorm(): 정규화를 하나의 관측치마다 적용시키는 방법
        """

        super(TransformerBlock,self).__init__()
        self.attention= SelfAttention(embed_size,heads)
        self.norm1=nn.LayerNorm(embed_size)
        self.norm2=nn.LayerNorm(embed_size)
        self.feed_foward = nn.Sequential(
            nn.Linear(embed_size,forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size,embed_size)
        )
        self.dropout = nn.Dropout(dropout)
    """
    1.attention 연산 후 dropout
    2.attention과 query를 계산 후 normalization 후 dropout
    3.feedfoward 연산후 dropout
    4.x+feedforward(x) 더한 후 normlization
    """
    def forward(self,value,key,query,mask):
        attention= self.attention(value,key,query,mask)
        x=self.dropout(self.norm1(attention + query))
        forward= self.feed_foward(x)
        out= self.dropout(self.norm2(
            forward+x))
        return out

class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length):
        """max_length는 positional embedding을
        위하여 제한을 둔 것"""
        super(Encoder,self).__init__()
        self.embed_size= embed_size
        self.device= device
        self.word_embedding= nn.Embedding(src_vocab_size,embed_size)
        self.positional_embedding= nn.Embedding(max_length, embed_size)
        """ModuleList를 활용하여 TransformerBLock을
        num_layers 갯수만큼 만듬"""
        self.layers= nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout= dropout,
                forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

        """Encoder의 foward propagation 과정
        1.word embedding과 positional embedding을 더함
        2.multihead attention
        3.add and normalization
        4. feed foward NN
        5.add and normalization"""
    def forward(self,x,mask):
        N,seq_length = x.shape
        """(1,2,...,seq_length-1)을 N(시퀀스 수)만큼 만든다"""
        positions= torch.arange(0,seq_length).expand(N,seq_length).to(self.device)
        """1.단어 임베딩 과 positional embedding을 더함"""
        out=self.dropout(self.word_embedding(x) + self.positional_embedding(positions))

        for layer in self.layers:
            """
            여기서 self.layer는 TransformerBLock으로써, 2,3,4,5의 과정을 처리함.
            """
            out= layer(out,out,out,mask)
        return out

class DecoderBlock(nn.Module):
    def __init__(self,
                 embed_size,
                 heads,
                 forward_expansion,
                 dropout,
                 device):
        super(DecoderBlock,self).__init__()
        """인코더 트랜스포머블록과 다른점은 앞에
        Attention + Add&Norm 블록이 하나 더있다는 점
        2번째 Attention 블록에 encoder의 output값을 받는다는 점"""
        self.attention=SelfAttention(embed_size,heads)
        self.norm= nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size,
            heads,
            dropout,
            forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                value,
                key,
                src_mask,
                trg_mask):
        """
        1.masked multi-head attention 시행
            (이때 마스크는 다음 단어들의 정보를 막는 look-ahead mask)
        2. add & normalization 시행
        3. transformer_block실행
        """
        attention= self.attention(x,x,x,trg_mask)
        query= self.dropout(self.norm(attention +x))
        out= self.transformer_block(value,key,query,src_mask)
        return out

class Decoder(nn.Module):
    def __init__(self,
                 trg_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 device,
                 max_length):
        super(Decoder,self).__init__()
        self.device= device
        self.word_embedding= nn.Embedding(trg_vocab_size,embed_size)
        self.positional_embedding= nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size,
                          heads,
                          forward_expansion,
                          dropout,device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size,trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,enc_out,src_mask,trg_mask):
        """
        encoder와 구별되는 점으론 encoder 출력이 transformer에
        value와 key값으로써 들어가고, query에 이전 층 attention 출력이 들어간다는 점
        Decoder_Block의 attention은 첫번째 attention을,
        transformer_block은 2번째 attention 층 및 FFNN을 의미한다
        Attention>Attention>FFNN의 하나의 블록을 거쳐 나온 출력값은
        기존의 입력과 같은 크기(N,query_len,heads *head_dim)를 유지하기 때문에 계속하여 연산이 반복될 수 있다.(heads*head_dim=embed_dim)
        """
        N,seq_length= x.shape
        positions= torch.arange(0,seq_length).expand(N,seq_length).to(self.device)
        x= self.dropout((self.word_embedding(x)+self.positional_embedding(positions)))

        for layer in self.layers:
            x= layer(x,enc_out,enc_out,src_mask,trg_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_Size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device='mps',
                 max_length=100):
        super(Transformer,self).__init__()
        self.encoder= Encoder(
            src_vocab_size,
            embed_Size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )
        self.decoder= Decoder(
            trg_vocab_size,
            embed_Size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )

        self.src_pad_idx= src_pad_idx
        self.trg_pad_idx= trg_pad_idx
        self.device = device

        """src_pad_idx, trg_pad_idx 데이터셋에서 제공되어야 하는듯
            src는 (batch_size,src_vocab_size)를 갖는다"""
    def make_src_mask(self,src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self,trg):
        N,trg_len = trg.shape
        trg_mask = torch.tril(torch.ones(trg_len,trg_len)).expand(
            N,1,trg_len,trg_len
        )
        return trg_mask.to(self.device)

    def forward(self,src,trg):
        src_mask= self.make_src_mask(src)
        trg_mask= self.make_trg_mask(trg)
        enc_src= self.encoder(src,src_mask)
        out= self.decoder(trg,enc_src,src_mask, trg_mask)
        return out



profile
안녕하세요!

0개의 댓글