Trajectory Transformer (+ Fast Trajectory Transformer)

임수정·2022년 8월 10일
0

https://github.com/Howuhh/faster-trajectory-transformer/tree/27923614dbb8512953e57b092f0aed6b2cf17045

Fast Trajectory Transformer

This is reimplementation of Trajectory Transformer, introduced in Offline Reinforcement Learning as One Big Sequence Modeling Problem paper.

이전 코드의 문제(Trajectory Transformer)

The original implementation has few problems with inference speed, namely quadratic attention during inference and sequential rollouts.

  • 2차원 attention 계산 + quadratic memory size(= 2차원 dense matrix) 필요성의 한계가 존재

(참고 : https://ai.googleblog.com/2020/10/rethinking-attention-with-performers.html)

  • memory caching 등과 같이 여러 방법이 제안됨.
  • 제일 일반적인 방법이 sparse attention
    • 모든 토큰 pair에 대해 유사도 계산하지 않고 선택된 pair 끼리만 유사돌르 계산
    • pair는 여러 방법으로 선택될 수 있음 (예시 -> manually, 최적화, 학습, randomized 등)
    • sparse attention은 graph와 edge로도 표현될 수 있음. 따라서 GNN에서도 영향을 받음. (참고 : Graph Attention Networks)
    • sparse attention은 일반적으로 implicitly하게 fully attention mechanism을 표현하기 위해 추가적인 layer가 필요함

The former slows down planning a lot, while the latter does not allow to do rollouts in parallel and utilize GPU to the full.

  • 이전 방식은 planning 시간이 너무 걸림.
  • fast trajectory transformer는 병렬 rollout을 허용하지 않고, GPU를 최대로 활용함(?)

Still, even after all changes, it is not that fast compared to traditional methods such as PPO or SAC/DDPG. However, the gains are huge, what used to take hours now takes a dozen minutes (25 rollouts, 1k steps each, for example). Training time remains the same, though.

  • 바꿨더라도 사실 PPO나 SAC/DDPG 만큼 빨라지지는 않았음.
  • 하지만 몇시간 걸릴 것을 12분 정도 걸리게 바꿨기 때문에 매우 이득이다.
  • 학습 시간은 동일함

정확히 뭘 바꿨나?

1. Attention caching
During beam search we're only predicting one token at a time. So with the naive implementation model will make a lot of unnecessary computations to recompute attention maps for full past context. However it is not necessary, as it was already computed when the previous token was predicted. All we need is to cache it!

  • beam search 할 때 token을 하나씩 예측한다. 기존 구현은 이전에 계산 했던 값도 다시 계산하기 때문에 불필요한 계산이 너무 많음.
  • 한 번 계산한 것은 저장해서 불필요한 계산을 줄임.

Actually, attention caching is a common thing in NLP field, but a lot of RL practitioners may not be familiar with NLP, so the code also can be educational.

  • NLP에서는 attention caching이 의미가 있지만, RL에서는 익숙하지 않을 것이기 때문에, 코드는 교육적이기도 합니다.

2. Vectorized rollouts
Vectorized environments allow batching beam search planning and select actions in parallel, which is a lot faster if you need to evaluate agent on number of episodes (or seeds) during training.

  • environments를 vetorize하여 beam search를 한 번에 할 수 있게하고, 병렬적으로 action을 선택하게 되어 훈련 중 에피소드(또는 시드) 수로 에이전트를 평가해야 하는 경우 훨씬 빠르다.
    (??위에랑 말이 좀 달라서 헷갈립니다..제가 잘못이해했거나..)

https://ai.googleblog.com/2020/10/rethinking-attention-with-performers.html

profile
유쾌하게, 열정적으로, 진심을 다해

0개의 댓글