RL에서 Normalize하는 이유

JTDK·2021년 6월 28일
0
post-custom-banner

Normalization?

Data Normalization 은 데이터의 범위를 사용자가 원하는 범위로 제한하는 것이다. 이미지 데이터의 경우 픽셀 정보를 0~255 사이의 값으로 가지는데, 이를 255로 나누어주면 0~1.0 사이의 값을 가지게 될 것이다.

v=vmax(vp,ϵ)v = {v\over max(∥v∥_p,ϵ)}

여기서 v∥v∥는 벡터간의 거리를 의미하고, ϵϵ 는 분모가 0이되는걸 방지해주는 아주 작은 값이다.
즉 분모의 의미는 v의 최댓값과 v의 최솟값간 거리와 ϵ의 값중 큰값이다.

우선 코드부터 보자

import torch.nn.functional as F

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=+0)
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic

흔한 액터-크리틱 이중 신경망 구조인데, 보다시피 신경망을 거치기전에 정규화부터 한다.

Why Normalize?

  1. 학습을 더 빠르게 한다.
  2. Local Optimum에 빠지는 가능성을 줄인다.

먼저 1번부터 살펴보면, input으로 큰 range안의 값을 넣으면, 역전파시 각 parameter들도 크게 움직일거라는게 직관적으로 이해가능하다. 반대로 0~1사이 값만 넣어주면 작은 범위내에서 parameter들이 움직이면서 optimal value를 찾기때문에 학습이 빨라진다.

2번도 1번과 같은 맥락인데, 큰 range의 값을 input으로 주면 loss function의 그래프가 더 가파르게 생겨지고, 아주 가파른 local optimum에 빠지게 되면 나오기가 힘들다.

구현

쉽다.. torch.nn.functional을 호출하고 forward함수에서 제일 처음에 인풋을 원하는 dim에 맞춰서 정규화하면됨.

torch.nn.functional.normalize(input, p=2.0, dim=1, eps=1e-12, out=None)

parameters

  • input – input tensor of any shape
  • p (float) – the exponent value in the norm formulation. Default: 2
  • dim (int) – the dimension to reduce. Default: 1
  • eps (float) – small value to avoid division by zero. Default: 1e-12
  • out (Tensor, optional) – the output tensor. If out is used, this operation won’t be differentiable.

여기서 p는 norm의 타입(?)이다. Default는 2인데, 이는 우리가 l2 norm을 가장 많이 사용하기 때문.

normalized = torch.nn.functional.normalize(raw_data, dim=0)
profile
RL, 퀀트 투자 공부 정리
post-custom-banner

0개의 댓글