이번에는 네이브 DQN의 문제를 해결하기 위해 나온 방법들을 이어서 알아보겠습니다!
Target network : to overcome the non-statoinary target problem
기존 Q네트워크를 업데이트 할 때 loss ft을 minimize하는 방향으로 진행합니다.
큐네트워크의 output인 큐벨류 값이 타겟값과 가까워지게 weight를 업데이트하는 건데, output으로 나온 Qvalue값을 target함수에서도 사용하고 있습니다.
그렇다는 말은 minibatch로 얻은 결과로 Q네트워크의 파라미터를 업데이트하는데, 업데이트할 때마다 target에 있는 값도 같이 업데이트하여 변화하게 되는 것 입니다.
만약 target값으로 다가가며 업데이트하는데 target이 움직이면 안 될 것 입니다.
이를 target이 고정되어 있지 않고 움직이기에 non-stationary target problem이라고 합니다.
이 문제를 해결하기 위해 고정시켜줍니다.
loss ft을 줄이는 방향으로 update하지만 target안에 있는 값은 업데이트를 하지 않는 것 입니다.
이제는 동일한 Q network이지만 복제했다고 생각하고 분리하는 것 입니다.
기존 네트워크는 Behavior Q-network라 부르고 파라미터는 Θ라고 하며, target 네트워크는 Target Qhat-network라고 하고 파라미터는 Θhat 이라고 하겠습니다.
Behavior Q-network 계속 업데이트를 진행하고 Target Qhat-network는 업데이트를 하지 않고 고정시켜둡니다.
1) Θ = Θhat
2) Θ update, Θhat fixed
3) Θ는 매 step마다 update, Θhat은 일정 step(ex 1000)마다 update
이제는 loss ft이 바뀝니다.
역전파를 실행할 때는 Θ에 대해서만 업데이트를 진행하게 됩니다. Θhat은 상수취급
이렇게 되면 Θhat은 업데이트 할 시기가 되면 그냥 Θ를 복제하면 되는 것 입니다!
이렇게 non-stationary 문제를 해결합니다.
DQN이 어떤식으로 작동하는지 그림으로 확인해 보겠습니다.
1) Q-network를 통해 현재 state에서의 각 action의 Q vlaue값을 얻고 highest q value 갖는 action선택하고 입실론 그리디 방법으로 action 선택
2) 환경에 가서 immediate reward와 next state 얻기
3) (st,at,rt+1,st+1) 데이터 new transition으로 Replay buffer에 저장
4) 미니배치만큼 transition을 샘플링
5) GD방법으로 behavior network에 있는 파라미터를 업데이트
이렇게 한 사이클을 돌면 behavior network에 있는 파라미터를 한 번 업데이트한 것
이를 반복하며 가끔씩 behavior network에 있는 파라미터를 copy해서 target nerwork에 복사
이제는 DQN의 수도코드를 확인해보겠습니다.
Preprocessing Φ : 전처리 과정에서 raw fram으로 바꾸고 흑백으로 처리해줍니다. 또한 좌우 필요없는 곳은 crop합니다. 최종 흑백 84x84가 됩니다.
이후 위에서 언급한 내용을 순차적으로 실행합니다.
마지막으로 DQN에서 사용되는 CNN에 대해 알아보겠습니다.
preprocessing된 하나의 state st : 84x84x4(4개프레임) input
: state space = 총 28,224 pixel이 있고, 흑백이기에 2^28224 종류
-> CNN을 거쳐 256dim으로 압축
-> FC를 통해 한 state에 대한 모든 action의 q값을 도출
여기서 maxpooling을 사용하지 않습니다!
애초에 maxpooling을 사용하는 이유는 translation invariance(약간의 움직여도 같은 것으로 인식) 때문인데, 게임을 할 때는 약간의 픽셀만 이동해도 완전히 다른 것이므로 사용하지 않습니다.
DQN과 네이브DQN의 차이입니다.
고려대학교 오승상 교수님 강화학습 강의 : https://www.youtube.com/watch?v=C-mfKSM0VFQ&list=PLvbUC2Zh5oJtYXow4jawpZJ2xBel6vGhC&index=18