Transformer 구현 및 학습(1)

정리용 블로그·2024년 3월 12일
0

언어모델

목록 보기
2/5

github

코드 설명을 시작하기 전 시행착오를 좀 적자면
처음에 기존 Transformer, 즉 Encoder Decoder 구조를 만드려다가,
약간 LLM처럼 Decoder Only Transformer도 같이 만들려고 하는 와중에
Encoder와 Decoder를 하나의 Encoder 모델로 만들려고 하다가, 생각해보니까 기존의 Encoder Decoder 구조에서의 Decoder엔 Encoder의 output과 cross attention을 해야하기 때문에 이 부분에 대해 다르게 만들었어야했는데, 그냥 Encoder처럼 만들다가 뭔가 꼬이게 되어버렸다.
애초에 Decoder Only Transformer는 Encoder Decoder Transformer 구조의 둘 다와 미묘하게 다른데다가 굳이 따지자면 Encoder에 더 가까운 편인데 그냥 생각없이 만들다보니까 뭔가 꼬인 느낌이다.
그래서 결국 Encoder Decoder구조의 Encoder와 Decoder를 따로 만들고, Decoder Only Transformer 구조의 Decoder를 따로 만들기로 했다.

원래는 Tokenizer나 Word Embedding이나 다 Pretrain된 것으로 사용하려고 했는데, 그러려고 보니까 Word Embedding이 애초에 Transformer의 앞 Layer에 붙어있는 구조이기 때문에 Transformer를 처음부터 Training 한다면 굳이 Embedding도 Pretrain 된 것을 이용할 필요가 없다고 느꼈고, 그렇게 되면 Tokenizer도 그냥 내가 train하는게 낫다고 느껴서 그냥 처음부터 scratch로 train 하기로 했다.

일단 학습은 Namuwiki extracted 데이터셋으로 하기로 했기 때문에, Decoder Only Transformer를 이용해 학습하기로 했다.

class MultiHeadAttentionLegacy(nn.Module):
    def __init__(self,
                 embed_dim=768,
                 num_heads=8,
                 dropout=0.0,
                 bias=True,
                 kdim=None,
                 vdim=None,
                 ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
        self.head_dim = embed_dim // num_heads
        if kdim is None:
            kdim = embed_dim
        if vdim is None:
            vdim = embed_dim

        self.kdim = kdim
        self.vdim = vdim

        # Linear layers for the query, key, and value for each head
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(vdim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None, is_causal=False):
        # Get the batch size
        bsz, tgt_len, embed_dim = query.size()
        src_len = key.size(1)

        # Project the query, key, and value
        q = self.q_proj(query)  # (bsz, tgt_len, embed_dim)
        k = self.k_proj(key)  # (bsz, src_len, embed_dim)
        v = self.v_proj(value)  # (bsz, src_len, embed_dim)

        # Reshape the query, key, and value to have the same shape as the number of heads
        q = q.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.contiguous().view(bsz * self.num_heads, src_len, self.head_dim)
        v = v.contiguous().view(bsz * self.num_heads, src_len, self.head_dim)

        # compute the attention scores
        attn_output_weights = torch.bmm(q, k.transpose(1, 2)) / (self.head_dim ** 0.5)
        attn_softmax = nn.functional.softmax(attn_output_weights, dim=-1)

        # Apply causal mask
        if is_causal:
            attn_softmax = attn_softmax.tril()
        # Apply the attention mask
        if attn_mask is not None:
            attn_softmax = attn_mask * attn_softmax

        # Apply the dropout
        attn_softmax = self.dropout(attn_softmax) if self.dropout is not None else attn_softmax

        # Apply the attention to the value
        attn_output = torch.bmm(attn_softmax, v)

        # Reshape the attention output to have the same shape as the number of heads
        attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim)  # it's the same as concatenating the heads

        # Apply the output projection
        attn_output = self.out_proj(attn_output)

        return attn_output

일단 참고를 위해 기존 attention 구현을 만들어 놓았으며, 이를 개선하기 위해 xformers에서 지원하는 BlockSparseAttention(BigBirdAttention)을 이용해 효율적인 attention도 지원하도록 하였다.

class MultiHeadBigBirdAttention(nn.Module):
    def __init__(self,
                 block_size=32,
                 embed_dim=1024,
                 num_heads=8,
                 sequence_length=2048,
                 dropout=0.1,
                 ):
        super().__init__()

        BLOCKS = sequence_length // block_size
        causal_layout = torch.ones((num_heads, BLOCKS, BLOCKS))

        block_attn = BlockSparseAttention(layout=causal_layout, block_size=block_size, dropout=dropout,
                                          num_heads=num_heads)

        self.attn = MultiHeadDispatch(
            dim_model=embed_dim,
            residual_dropout=dropout,
            num_heads=num_heads,
            attention=block_attn
        )

    def forward(
            self,
            q, k, v,
            attn_mask=None,
            is_causal=False
    ):
        if is_causal:
            causal_mask = torch.tril(torch.ones((q.size(1), q.size(1))).to(q.device))
            attn_mask = attn_mask * causal_mask if attn_mask is not None else causal_mask
        return self.attn(q, k, v, attn_mask=attn_mask)

이 후 DecoderOnlyBlock, DecoderOnly, DecoderOnlyTransformer를 구축했고,

class DecoderOnlyBlock(nn.Module):
    def __init__(self,
                 max_seq_len=1024,
                 embed_dim=768,
                 num_heads=8,
                 use_legacy=False,
                 ):
        super().__init__()
        # self.attn = MultiHeadAttentionLegacy(embed_dim=embed_dim, num_heads=num_heads) if use_legacy else nn.MultiheadAttention(embed_dim, num_heads)
        self.attn = (MultiHeadBigBirdAttention
                     (embed_dim=embed_dim, num_heads=num_heads, sequence_length=max_seq_len, dropout=0.1)) \
            if not use_legacy else MultiHeadAttentionLegacy(embed_dim=embed_dim, num_heads=num_heads)

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.ffn = PositionWiseFeedForward(embed_dim=embed_dim)
        self.residual = ResidualConnection(embed_dim=embed_dim, pre_norm=True)

    def forward(self, q, tgt_mask=None, is_causal=True):
        q = self.residual(q, lambda x: self.attn(x, x, x, attn_mask=tgt_mask, is_causal=is_causal))
        q = self.residual(q, lambda x: self.ffn(x))
        return q

class DecoderOnly(nn.Module):
    def __init__(self,
                 max_seq_len=1024,
                 embed_dim=768,
                 num_heads=8,
                 num_layer=6,
                 use_legacy=False,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([DecoderOnlyBlock(max_seq_len=max_seq_len, embed_dim=embed_dim, num_heads=num_heads,
                                                  use_legacy=use_legacy) for _ in range(num_layer)])

    def forward(self, q, tgt_mask=None, is_causal=True):
        for layer in self.layers:
            q = layer(q, tgt_mask=tgt_mask, is_causal=is_causal)
        return q
        
class DecoderOnlyTransformer(nn.Module):
    def __init__(self,
                 max_seq_len=1024,
                 embed_dim=768,
                 num_heads=8,
                 num_layer=6,
                 vocab_size=32000,
                 use_legacy=False,
                 ):
        super().__init__()
        self.decoder = DecoderOnly(max_seq_len=max_seq_len, embed_dim=embed_dim, num_heads=num_heads, num_layer=num_layer,
                               use_legacy=use_legacy)

        self.te = TokenEmbedding(embed_dim, vocab_size)
        self.pe = PositionalEncoding(embed_dim, max_seq_len)

        self.generator = nn.Linear(embed_dim, vocab_size)

    def decode(self, tgt, tgt_mask=None):
        tgt = self.te(tgt)
        tgt = self.pe(tgt)
        return self.decoder(tgt, tgt_mask=tgt_mask, is_causal=True)

    def forward(self, tgt):
        tgt_mask = self.make_pad_mask(tgt, tgt)
        decoder_out = self.decode(tgt, tgt_mask=tgt_mask)
        out = self.generator(tgt)
        out = nn.functional.log_softmax(out, dim=-1)
        return out, decoder_out

    def make_pad_mask(self, q, kv, pad_idx=0):
        # q: [batch, q_len]
        # kv: [batch, kv_len]
        q_len = q.size(1)
        kv_len = kv.size(1)
        q_mask = q.ne(pad_idx)
        q_mask = rearrange(q_mask, 'b i -> b 1 i 1')
        q_mask = repeat(q_mask, 'b 1 i k -> b 1 i k', k=kv_len)

        kv_mask = kv.ne(pad_idx)
        kv_mask = rearrange(kv_mask, 'b i -> b 1 1 i')
        kv_mask = repeat(kv_mask, 'b 1 1 i -> b 1 j i', j=q_len)

        mask = q_mask & kv_mask
        mask.requires_grad = False
        return mask

이를 학습하기만 하면 되는데, nlp 쪽을 다뤄본적이 없어 Tokenizer를 이용해 Dataloader를 구축하는 것이 좀 힘들어 시간이 좀 오래 걸릴 것 같다.

0개의 댓글

관련 채용 정보