강화학습은 여러 분야에서 많은 성공을 거뒀고, DQN은 그 첫번째 사례이다. 이번 튜토리얼에서는 Tianshou를 이용하여 DQN Agent로 Cartpole 환경을 차근차근 학습해볼것이다. hyper-parameter, network 등의 specification 만 조율할 수 있는 다른 라이브러리들과 다르게, Tianshou는 code-level에서 손쉬운 구축이 가능하다.
먼저, 너의 agent가 상호작용 할 Environment를 만들어야한다. 우리는 openAI Gym의 env 양식을 그대로 따른다. Python 코드에서 그냥 Tianshou를 import 하고 환경을 만들면 된다.
import gym
import tianshou as ts
env = gym.make('CartPole-v0')
CartPole-v0은 이산적인 액션 공간을 가지는 간단한 환경이다. 액션 공간이 이산적인지, 연속적인지 먼저 알고, 그에 맞는 알고리즘을 적용해야 한다. 예를들어 DDPG는 연속적인 공간에만 사용 할 수 있는 반면에 다른 거의 모든 PG 알고리즘들은 양쪽 모두에 적용 가능하다.
만약 기존의 gym.Env
를 사용하는 경우 다음과 같이 하면 된다.
train_envs = gym.make('CartPole-v0')
test_envs = gym.make('CartPole-v0')
Tianshou 는 모든 알고리즘들에 대하여 병렬 샘플링 (parallel sampling)을 지원한다. 또한 네가지 타입의 Vectorized Environment Wrapper를 제공하는데, 종류와 사용방법은 다음과 같다.
DummyVectorEnv
SubprocVectorEnv
ShmemVectorEnv
RayVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
여기서 우리는 train_envs
에 10개, test_envs
에 100개의 환경을 만들었다.
Demonstration을 위해서 여기서는 두번째 code-block 을 사용한다.
만약 custom env 를 사용한다면,
seed
메소드를 제대로 설정해야 한다.def seed(self, seed): np.random.seed(seed)
그렇지 않으면 모든 env 에서의 결과값이 같을 수 있다.
Tianshou는 Pytorch로 짜여진 모든 Network 와 Optimizer 를 지원한다(당연히 input 과 output 은 Tianshou API 를 준수해야한다.) 예는 다음과 같다.
import torch, numpy as np
from torch import nn
class Net(nn.Module):
def __init__(self, state_shape, action_shape):
super().__init__()
self.model = nn.Sequential(
nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, np.prod(action_shape)),
)
def forward(self, obs, state=None, info={}):
if not isinstance(obs, torch.Tensor):
obs = torch.tensor(obs, dtype=torch.float)
batch = obs.shape[0]
logits = self.model(obs.view(batch, -1))
return logits, state
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
혹은, 우리가 만들어 놓은 common
, discrete
, continuous
MLP Network 를 사용할수도 있다. 자체 정의된 네트워크의 규칙은 다음과 같다.
input: Observation obs
( numpy.ndarray
, torch.Tensor
, dict 나 self-defined class), hidden state state
(RNN, LSTM 등을 위한), 그리고 환경에서 받은 다른 정보들 info
Output: logits
과 다음 hidden state state
. logit값은 torch.Tensor
대신 튜플이나 다른 변수가 될 수도 있다. 이건 정책이 어떻게 네트워크 출력값을 처리하느냐에 달렸다. 예를들어 PPO 에서는 네트워크의 리턴값이 가우시안 정책을 위한 (mu, sigma), state
형태일수도 있다.
여기서 logits 은 raw output 을 뜻한다. 지도학습의 regression/classification 모델에서 raw output을 logits 이라고 부르는데, 우리는 그 정의를 확장해서 모든 NN에서의 Raw output 을 logits 이라고 부른다.
우리는 정책을 정의하기 위해서 초매개변수들과 함께 위에서 정의된 net
과 optim
을 사용한다. Target network가 있는 DQN 정책을 정의하는 코드는 다음과 같다.
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
Collector 는 Tianshou 의 핵심 개념이다. 컬렉터는 정책이 다른 종류의 환경들과 쉽게 상호작용 하도록 한다. 각 스텝 마다, 컬렉터는 정책이 (최소한) 지정된 스텝이나 에피소드 수 만큼 행동하게 하고, 리플레이 버퍼에 그 데이터를 저장한다.
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)
Tianshou 는 세가지 트레이너 onpolicy_trainer()
, offpolicy_trainer()
, offline_trainer()
를 제공한다. 트레이너는 test collector의 정책이 stop condtion stop_fn
에 도달하면 자동으로 학습을 종료한다. DQN은 Off-Policy 알고리즘이므로, 여기서 우리는 offpolicy_trainer()
를 사용한다.
result = ts.trainer.offpolicy_trainer(
policy, train_collector, test_collector,
max_epoch=10, step_per_epoch=10000, step_per_collect=10,
update_per_step=0.1, episode_per_test=100, batch_size=64,
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
print(f'Finished training! Use {result["duration"]}')
)
각 매개변수의 의미는 다음과 같다.
max_epoch
: 학습 중 최대 에폭 갯수. 학습은 최대 에폭에 도달하기전에 끝날 수 있다.step_per_epoch
: 한 에폭당 transition의 갯수step_per_collect
: 네트워크를 업데이트 할 때, 컬렉터가 수집하는 transition 의 갯수. 예를들어, 위 코드는 열개의 transition 을 가지고 정책 네트워크를 한번 업데이트 한다.episode_per_test
: 정책 평가를 할 때 이용할 에피소드의 갯수batch_size
: 정책 네트워크에 먹일(?) 샘플 데이터의 배치 사이즈train_fn
: 이 함수는 현재 epoch 의 넘버와 step index 를 받아서 이번 에폭에서 학습을 시작하기 전에 몇몇 작업을 한다. 예를들어, 위 코드에서는 학습전에 epsilon 값을 0.1로 리셋 해주는 작업을 한다.test_fn
: 이 함수는 현재 epoch 의 넘버와 step index 를 받아서 이번 에폭에서 테스팅을 시작하기 전에 몇몇 작업을 한다. 예를들어 위 코드에서는 테스트 전에 epsilon 값을 0.05로 리셋 해주는 작업을 한다.stop_fn
: 이 함수는 테스팅 결과에서 할인되지 않은 리턴값의 평균을 받아서, 목표에 도달했는지에 대한 bool 값을 반환한다.logger
: 밑을 보세용트레이너는 로깅을 위해 TensorBoard
를 지원한다. 사용법은 다음과 같다
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
wirter = summaryWriter('log/dqn')
logger = basicLogger(writer)
로거를 트레이너에 넘겨주면, 학습 결과는 TensorBoard 에 다음과 같은 형식으로 기록된다.
{
'train_step': 9246,
'train_episode': 504.0,
'train_time/collector': '0.65s',
'train_time/model': '1.97s',
'train_speed': '3518.79 step/s',
'test_step': 49112,
'test_episode': 400.0,
'test_time': '1.38s',
'test_speed': '35600.52 step/s',
'best_reward': 199.03,
'duration': '4.01s'
}
위 기록은 대략 4초만에 카트폴에 대한 DQN의 학습이 끝난것을 보여준다. 100개가 넘는 연속되는 에피소드에서 리턴값의 평균은 199.03 이다.
정책이 torch.nn.Module
를 상속 받기 때문에, 정책을 저장하거나 불러오는 방법은 torch module 과 정확히 일치한다.
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
collector
는 렌더링을 지원한다. agent 의 퍼포먼스를 35FPS 환경에서 보는 코드는 다음과 같다.
policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
"나는 트레이너 안쓸거고, 내가 직접 커스터마이즈 할거임"
Tianshou 는 사용자 정의 학습 코드를 지원한다. 코드 예시는 다음과 같다.
# pre-collect at least 5000 transitions with random action before training
train_collector.collect(n_step=5000, random=True)
policy.set_eps(0.1)
for i in range(int(1e6)): # total step
collect_result = train_collector.collect(n_step=10)
# once if the collected episodes' mean returns reach the threshold,
# or every 1000 steps, we test it on test_collector
if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0:
policy.set_eps(0.05)
result = test_collector.collect(n_episode=100)
if result['rews'].mean() >= env.spec.reward_threshold:
print(f'Finished training! Test mean returns: {result["rews"].mean()}')
break
else:
# back to training eps
policy.set_eps(0.1)
# train policy with a sampled batch data from buffer
losses = policy.update(64, train_collector.buffer)
더 자세한 내용은 Cheat Sheet 를 참고하면 된다.
Quick start 가이드를 번역도 하면서 공부해봤는데, 대충 어떤식으로 쓰이는지는 가닥이 잡힌다. 다만 아직 컬렉터나 배치 등에 대한 확실한 개념은 잘 모르겠다. 뒤에 더 자세한 설명이 있으니 계속 해보장