이 code review는 다음 colab code를 기반으로 합니다.
D4RL은 Datasets for Deep-Dricen Reinforcement Learning의 줄임말로,
Deep Reinforcement Learning을 위한 데이터셋을 제공합니다.
Maze2D, AntMaze, Adroit, Gym, Flow, CARLA 등 다양한 환경에서 수집한 데이터셋을 제공합니다.
dataset = "medium" # medium / medium-replay / medium-expert
rtg_scale = 1000 # scale to normalize returns to go
이 min decision transformer code에서는 D4RL의 Walker2d 환경에서 데이터를 수집하고, Decision Transformer를 학습시킵니다.
Decision Transformer에서 dataset을 다음과 같이 소개합니다!
medium datasets은 1백만 steps 후에 생성된 정말 "medium" dataset이며 전문가 policy의 약 1/3정도의 스코어를 냅니다.
medium-replay datasets은 medium policy를 통해 학습된 policy가 생성한 dataset(이 환경에선 25k-400k steps의 데이터셋이라 합니다.)
medium-expert datasets은 전문가의 시연과 suboptimal data, 부분 학습된 policy나 random policy가 섞인 데이터셋입니다.
env_name = 'Walker2d-v3'
rtg_target = 5000
env_d4rl_name = f'walker2d-{dataset}-v2'
가장 기초가 되는 Transfomer의 구조부터 살펴봅시다.
class MaskedCausalAttention(nn.Module):
def __init__(self, h_dim, max_T, n_heads, drop_p):
super().__init__()
self.n_heads = n_heads
self.max_T = max_T
self.q_net = nn.Linear(h_dim, h_dim)
self.k_net = nn.Linear(h_dim, h_dim)
self.v_net = nn.Linear(h_dim, h_dim)
self.proj_net = nn.Linear(h_dim, h_dim)
self.att_drop = nn.Dropout(drop_p)
self.proj_drop = nn.Dropout(drop_p)
ones = torch.ones((max_T, max_T))
mask = torch.tril(ones).view(1, 1, max_T, max_T)
# register buffer makes sure mask does not get updated
# during backpropagation
self.register_buffer('mask',mask)
*n_heads는 multi-head attention에서 head의 개수를 의미합니다.
head의 개수가 많을수록 병렬적으로 attention을 많이 계산할 수 있습니다.
병렬적으로 attention으로 계산하는 것은 다양한 시각으로 정보를 학습할 수 있다고 해석 가능합니다!
q_net, k_net, v_net은 각각 query, key, value를 위한 fully connected layer(h_dim -> h_dim)입니다.
여기서 h_dim은 hidden dimension을 의미합니다.
forward 코드를 계속해서 봅시다.
def forward(self, x):
B, T, C = x.shape # batch size, seq length, attention_dim * n_heads
# N = num heads(병렬 attention = 얼마나 다양한 시각으로 볼 것인지), D = attention_dim
N, D = self.n_heads, C // self.n_heads
# rearrange q, k, v as (B, N, T, D) ->(Batch_size, num_heads, seq_length, attention_dim)
q = self.q_net(x).view(B, T, N, D).transpose(1,2)
k = self.k_net(x).view(B, T, N, D).transpose(1,2)
v = self.v_net(x).view(B, T, N, D).transpose(1,2)
# weights (B, N, T, T)
weights = q @ k.transpose(2,3) / math.sqrt(D)
# causal mask applied to weights
weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf'))
# normalize weights, all -inf -> 0 after softmax
normalized_weights = F.softmax(weights, dim=-1)
# attention (B, N, T, D)
attention = self.att_drop(normalized_weights @ v)
# gather heads and project (B, N, T, D) -> (B, T, N*D)
attention = attention.transpose(1, 2).contiguous().view(B,T,N*D)
out = self.proj_drop(self.proj_net(attention))
return out
일단 batch_size는 빼고 생각해봅시다!
input으로 넣어주는 raw data는 그림과 같습니다.
가로축은 time step 수만큼의 sequence length를 의미하고, 새로축은 hidden dimension을 의미합니다.
현재 새로축이 h_dim * num_heads로 나타나져 있는데,
q_net, k_net, v_net을 통과하면서 Sequnce length(가로)와 hidden dimension(세로)이 유지됩니다.
이는 각 network가 network를 통과하면서 query, key, value를 만들어내도록 학습되겠다고 해석할 수 있습니다.
# rearrange q, k, v as (B, N, T, D) ->(Batch_size, num_heads, seq_length, attention_dim)
q = self.q_net(x).view(B, T, N, D).transpose(1,2)
k = self.k_net(x).view(B, T, N, D).transpose(1,2)
v = self.v_net(x).view(B, T, N, D).transpose(1,2)
이제 새로축인 hidden dimension을 num_heads와 attention_dim으로 나누어줍시다.
num_heads는 병렬적으로 attention을 계산할 때 몇개의 head를 사용할 것인지를 의미하고,
attention_dim은 각 head에서 attention을 계산할 때 사용할 dimension을 의미합니다.
쉽게 생각하면, 각 병렬 attention에서의 hidden dimension이라고 생각하면 됩니다.
다른 색깔의 block이 각각 다른 attention을 계산하게 됩니다.
총 3개가 나오게 되죠?
하나의 input sequence를 각각 query, key, value에 대해 network를 통과시켰기 때문입니다.
그래서 각 층마다 query, key, value가 나오게 됩니다.
# weights (B, N, T, T)
weights = q @ k.transpose(2,3) / math.sqrt(D)
이제 각 병렬 attention에서의 query와 key를 곱해줍니다.
query는 attention을 계산할 때 사용할 정보를 담고 있고, key는 attention을 계산할 때 참고할 정보를 담고 있습니다.
쉽게 말해서 query가 이거랑 관련있는 놈 누구야! 하면 key가 관련있는 놈들을 찾아줍니다.(관련이 있으면 코사인 유사도(행렬곱)이 크니까!)
# causal mask applied to weights
weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf'))
# normalize weights, all -inf -> 0 after softmax
normalized_weights = F.softmax(weights, dim=-1)
아- 벌써 코드가 어려워요
예시를 들어가며 한번 이해해봅시다.
import torch
weights = torch.tensor([[1,1,1],[1,1,1],[1,1,1]], dtype=torch.float32)
ones = torch.ones((3,3))
mask = torch.tril(ones).view(1, 3, 3)
# mask = [[1,0,0],
# [1,1,0],
# [1,1,1]]
저는 논문과 비슷하게 다음과 같이 예시를 들어보았습니다.
weights = weights.masked_fill(mask[...,:3,:3] == 0, float('-inf'))
그리고 masking을 해주기 위해 masked_fill을 사용해주었습니다.
mask가 0인 부분은 -inf로 채워주고, 1인 부분은 그대로 두는 것입니다.
이후 softmax를 통해 normalize해줍니다.
normalized_weights = F.softmax(weights, dim=-1)
# normalized_weights = [[1.0000, 0.0000, 0.0000],
# [0.5000, 0.5000, 0.0000],
# [0.3333, 0.3333, 0.3333]]
attention map을 잘 뽑아내는 것을 확인할 수 있습니다.
논문 리뷰에서 뒤를 보고 예측하면 반칙이라고 했죠?
casual mask는 이런식으로 구현됩니다.
뒤에 것을 참고하지 않고 앞에 것만 참고하도록 mask를 씌워줍니다.
또한 softmax를 통해 각 weight를 normalize해줍니다.
더 잘 이해해보기 위해,
query의 맨 위 파란색 줄을 가져왔습니다.
조그만 블럭 하나가 float값을 가지고 있으며, 새로줄 하나가 한 state, action, reward에 대한 정보를 가지고 있습니다.
이걸로 key에 대해 나랑 비슷한 놈들을 찾아보자!(행렬곱이 높은놈) 라고 하면,
(D,T)@(T,D)가 되어서 (D,D)가 됩니다.
그리고 이걸 softmax를 통해 normalize해주면 (D,D)가 나오게 됩니다.
즉, attention weight이 나오게 되겠죠?
# attention (B, N, T, D)
attention = self.att_drop(normalized_weights @ v)
이걸 이후 value에 대해 곱해주면 (D,D)@(D,T)가 되어서 (D,T)가 됩니다.
즉, attention을 계산한 결과가 나오게 됩니다.
value를 한국말로 직역하는 과정에서 헷갈릴 수 있는데,
정말 가치! 를 말하는 것이 아닌, 해당 state, action, reward에 대한 정보를 말합니다.
(저는 처음 attention 공부할 때 많이 헷갈렸던 기억이 있네요 ㅎㅎ)
# (B, N, T, D)
context = normalized_weights @ v
논문에서처럼 여러개의 attention을 쌓아서 사용해야 하는데,
그럼 원래 raw input과 dimension이 같아야겠죠?
원래 raw input dimension은 (Sequence len, attention_dim * num_heads)이었고, 현재 나온 결과는 (Sequence len, attention_dim)이기 때문에,
raw input dimension과 같지 않습니다!
따라서 병렬 attention heads(여러 관점에서 본다고 했죠?)에서 계산했던 결과를 다시 합쳐줘야 합니다.
# gather heads and project (B, N, T, D) -> (B, T, N*D)
attention = attention.transpose(1, 2).contiguous().view(B,T,N*D)
이후 linear projection을 통해 마무리합니다.
out = self.proj_drop(self.proj_net(attention))
자 이제 Attention 코드를 바탕으로 Block을 만들어봅시다.
지금 얼마나 했냐고요?
빨간박스만큼 했습니다. 힘내봅시다!
class Block(nn.Module):
def __init__(self, h_dim, max_T, n_heads, drop_p):
super().__init__()
self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
self.mlp = nn.Sequential(
nn.Linear(h_dim, 4*h_dim),
nn.GELU(),
nn.Linear(4*h_dim, h_dim),
nn.Dropout(drop_p),
)
self.ln1 = nn.LayerNorm(h_dim)
self.ln2 = nn.LayerNorm(h_dim)
def forward(self, x):
# Attention -> LayerNorm -> MLP -> LayerNorm
x = x + self.attention(x) # residual
x = self.ln1(x)
x = x + self.mlp(x) # residual
x = self.ln2(x)
return x
forward를 보며 설명해보곘습니다.
x = x + self.attention(x) # residual
input을 넣으면 위에서 한참 설명했던 attention을 진행합니다.
attention 결과물로 뭐가 나왔죠? (B, T, D)가 나왔습니다.
T는 sequence length, D는 dimensiond이었죠?
여기에 input을 더해주면 (B, T, D) + (B, T, D)가 되어서 (B, T, D)가 됩니다.
이것을 우리는 residual connection이라고 부릅니다.
이루 layer norm을 통해 정규화를 해줍니다.
그림에서는 add & norm이라고 표현되어 있습니다.
이후 Feed Forward Network를 통해 (B, T, D) -> (B, T, D)로 변환합니다.
또한 residual connection을 통해 (B, T, D) + (B, T, D)가 되어 (B, T, D)가 됩니다.
그림을 따라가면서 코드와 한번 다시 읽어보면서 이해해봅시다.
이해가 잘 될꺼예요 :)
이렇게 작성한 코드가 빨간색으로 표시된 Block입니다.
왼쪽에 Nx라고 되어있죠? 이렇게 Block을 여러번 통과시켜야 합니다.
다음 post에선 이렇게 만든 Block을 여러번 통과시키는 코드를 작성해보겠습니다!