코드 설명을 시작하기 전 시행착오를 좀 적자면
처음에 기존 Transformer, 즉 Encoder Decoder 구조를 만드려다가,
약간 LLM처럼 Decoder Only Transformer도 같이 만들려고 하는 와중에
Encoder와 Decoder를 하나의 Encoder 모델로 만들려고 하다가, 생각해보니까 기존의 Encoder Decoder 구조에서의 Decoder엔 Encoder의 output과 cross attention을 해야하기 때문에 이 부분에 대해 다르게 만들었어야했는데, 그냥 Encoder처럼 만들다가 뭔가 꼬이게 되어버렸다.
애초에 Decoder Only Transformer는 Encoder Decoder Transformer 구조의 둘 다와 미묘하게 다른데다가 굳이 따지자면 Encoder에 더 가까운 편인데 그냥 생각없이 만들다보니까 뭔가 꼬인 느낌이다.
그래서 결국 Encoder Decoder구조의 Encoder와 Decoder를 따로 만들고, Decoder Only Transformer 구조의 Decoder를 따로 만들기로 했다.
원래는 Tokenizer나 Word Embedding이나 다 Pretrain된 것으로 사용하려고 했는데, 그러려고 보니까 Word Embedding이 애초에 Transformer의 앞 Layer에 붙어있는 구조이기 때문에 Transformer를 처음부터 Training 한다면 굳이 Embedding도 Pretrain 된 것을 이용할 필요가 없다고 느꼈고, 그렇게 되면 Tokenizer도 그냥 내가 train하는게 낫다고 느껴서 그냥 처음부터 scratch로 train 하기로 했다.
일단 학습은 Namuwiki extracted 데이터셋으로 하기로 했기 때문에, Decoder Only Transformer를 이용해 학습하기로 했다.
class MultiHeadAttentionLegacy(nn.Module):
def __init__(self,
embed_dim=768,
num_heads=8,
dropout=0.0,
bias=True,
kdim=None,
vdim=None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
self.head_dim = embed_dim // num_heads
if kdim is None:
kdim = embed_dim
if vdim is None:
vdim = embed_dim
self.kdim = kdim
self.vdim = vdim
# Linear layers for the query, key, and value for each head
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.k_proj = nn.Linear(kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(vdim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, attn_mask=None, is_causal=False):
# Get the batch size
bsz, tgt_len, embed_dim = query.size()
src_len = key.size(1)
# Project the query, key, and value
q = self.q_proj(query) # (bsz, tgt_len, embed_dim)
k = self.k_proj(key) # (bsz, src_len, embed_dim)
v = self.v_proj(value) # (bsz, src_len, embed_dim)
# Reshape the query, key, and value to have the same shape as the number of heads
q = q.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)
k = k.contiguous().view(bsz * self.num_heads, src_len, self.head_dim)
v = v.contiguous().view(bsz * self.num_heads, src_len, self.head_dim)
# compute the attention scores
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) / (self.head_dim ** 0.5)
attn_softmax = nn.functional.softmax(attn_output_weights, dim=-1)
# Apply causal mask
if is_causal:
attn_softmax = attn_softmax.tril()
# Apply the attention mask
if attn_mask is not None:
attn_softmax = attn_mask * attn_softmax
# Apply the dropout
attn_softmax = self.dropout(attn_softmax) if self.dropout is not None else attn_softmax
# Apply the attention to the value
attn_output = torch.bmm(attn_softmax, v)
# Reshape the attention output to have the same shape as the number of heads
attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim) # it's the same as concatenating the heads
# Apply the output projection
attn_output = self.out_proj(attn_output)
return attn_output
일단 참고를 위해 기존 attention 구현을 만들어 놓았으며, 이를 개선하기 위해 xformers에서 지원하는 BlockSparseAttention(BigBirdAttention)을 이용해 효율적인 attention도 지원하도록 하였다.
class MultiHeadBigBirdAttention(nn.Module):
def __init__(self,
block_size=32,
embed_dim=1024,
num_heads=8,
sequence_length=2048,
dropout=0.1,
):
super().__init__()
BLOCKS = sequence_length // block_size
causal_layout = torch.ones((num_heads, BLOCKS, BLOCKS))
block_attn = BlockSparseAttention(layout=causal_layout, block_size=block_size, dropout=dropout,
num_heads=num_heads)
self.attn = MultiHeadDispatch(
dim_model=embed_dim,
residual_dropout=dropout,
num_heads=num_heads,
attention=block_attn
)
def forward(
self,
q, k, v,
attn_mask=None,
is_causal=False
):
if is_causal:
causal_mask = torch.tril(torch.ones((q.size(1), q.size(1))).to(q.device))
attn_mask = attn_mask * causal_mask if attn_mask is not None else causal_mask
return self.attn(q, k, v, attn_mask=attn_mask)
이 후 DecoderOnlyBlock, DecoderOnly, DecoderOnlyTransformer를 구축했고,
class DecoderOnlyBlock(nn.Module):
def __init__(self,
max_seq_len=1024,
embed_dim=768,
num_heads=8,
use_legacy=False,
):
super().__init__()
# self.attn = MultiHeadAttentionLegacy(embed_dim=embed_dim, num_heads=num_heads) if use_legacy else nn.MultiheadAttention(embed_dim, num_heads)
self.attn = (MultiHeadBigBirdAttention
(embed_dim=embed_dim, num_heads=num_heads, sequence_length=max_seq_len, dropout=0.1)) \
if not use_legacy else MultiHeadAttentionLegacy(embed_dim=embed_dim, num_heads=num_heads)
self.ln1 = nn.LayerNorm(embed_dim)
self.ln2 = nn.LayerNorm(embed_dim)
self.ffn = PositionWiseFeedForward(embed_dim=embed_dim)
self.residual = ResidualConnection(embed_dim=embed_dim, pre_norm=True)
def forward(self, q, tgt_mask=None, is_causal=True):
q = self.residual(q, lambda x: self.attn(x, x, x, attn_mask=tgt_mask, is_causal=is_causal))
q = self.residual(q, lambda x: self.ffn(x))
return q
class DecoderOnly(nn.Module):
def __init__(self,
max_seq_len=1024,
embed_dim=768,
num_heads=8,
num_layer=6,
use_legacy=False,
):
super().__init__()
self.layers = nn.ModuleList([DecoderOnlyBlock(max_seq_len=max_seq_len, embed_dim=embed_dim, num_heads=num_heads,
use_legacy=use_legacy) for _ in range(num_layer)])
def forward(self, q, tgt_mask=None, is_causal=True):
for layer in self.layers:
q = layer(q, tgt_mask=tgt_mask, is_causal=is_causal)
return q
class DecoderOnlyTransformer(nn.Module):
def __init__(self,
max_seq_len=1024,
embed_dim=768,
num_heads=8,
num_layer=6,
vocab_size=32000,
use_legacy=False,
):
super().__init__()
self.decoder = DecoderOnly(max_seq_len=max_seq_len, embed_dim=embed_dim, num_heads=num_heads, num_layer=num_layer,
use_legacy=use_legacy)
self.te = TokenEmbedding(embed_dim, vocab_size)
self.pe = PositionalEncoding(embed_dim, max_seq_len)
self.generator = nn.Linear(embed_dim, vocab_size)
def decode(self, tgt, tgt_mask=None):
tgt = self.te(tgt)
tgt = self.pe(tgt)
return self.decoder(tgt, tgt_mask=tgt_mask, is_causal=True)
def forward(self, tgt):
tgt_mask = self.make_pad_mask(tgt, tgt)
decoder_out = self.decode(tgt, tgt_mask=tgt_mask)
out = self.generator(tgt)
out = nn.functional.log_softmax(out, dim=-1)
return out, decoder_out
def make_pad_mask(self, q, kv, pad_idx=0):
# q: [batch, q_len]
# kv: [batch, kv_len]
q_len = q.size(1)
kv_len = kv.size(1)
q_mask = q.ne(pad_idx)
q_mask = rearrange(q_mask, 'b i -> b 1 i 1')
q_mask = repeat(q_mask, 'b 1 i k -> b 1 i k', k=kv_len)
kv_mask = kv.ne(pad_idx)
kv_mask = rearrange(kv_mask, 'b i -> b 1 1 i')
kv_mask = repeat(kv_mask, 'b 1 1 i -> b 1 j i', j=q_len)
mask = q_mask & kv_mask
mask.requires_grad = False
return mask
이를 학습하기만 하면 되는데, nlp 쪽을 다뤄본적이 없어 Tokenizer를 이용해 Dataloader를 구축하는 것이 좀 힘들어 시간이 좀 오래 걸릴 것 같다.