[RL] SARSA

fine86·2023년 4월 20일

Reinforcement Learning

목록 보기
7/8

해당 글은 강화 학습의 개념 전반에 대해 순차적으로 다룰 예정입니다.

이번 포스팅에서는 Temporal Difference Algorithm 기반의 prediction한 정보를 통해 Policy Iteration을 하기 위한 SARSA Algorithm에 대해 설명하겠습니다.

TD Control

Basic Idea

앞선 포스팅들에서 반복해서 설명했듯이, TD는 MC보다 학습에 장점이 많은 알고리즘이다. 그렇다보니 Control의 과정도 TD를 활용하는 것이 MC Control보다 훨씬 효율적이라는 아이디어가 나왔고, 그렇게 만들어진 것이 TD Control 방식이다. TD Control의 기본적인 아이디어는 MC Control과 마찬가지로 TD에도 액션 가치 함수 q(s,a)q(s, a)를 적용하여 이 결과를 이용해 정책을 업데이트하는 것이다. 또한, 업데이트 방식도 MC Control 때 사용했던 ϵ\epsilon-greedy policy improvement 방식을 사용하여 policy iteration 과정에서 local optima에 빠지는 것을 방지한다. 따라서 이러한 일련의 과정들을 매 시점마다 수행하여 최종적으로 최적 정책으로 수렴하도록 학습을 진행하는 것이 TD Control이다. TD Control은 그 방식에 따라 SARSA와 Q-learning으로 나눌 수 있는데, 두 가지 방식을 나누는 기준은 해당 방식이 On-Policy 방식인지, 아니면 Off-Policy 방식인지에 달려있다. 해당 개념에 대한 설명은 Q-learning을 다루는 다음 포스팅에서 깊이 다룰 예정이므로, 지금은 prediction 차원에서의 MC와 TD의 관계처럼 MC Control을 TD의 방식으로 변환한 Control 알고리즘이 SARSA라는 사실만 파악하고 넘어가도록 하자.

SARSA

TD에서의 가치 함수 업데이트 방식을 상기시키기 위해 업데이트에 사용된 식을 가져왔다. 아래 식에서 알 수 있듯이 Prediction 단계에서 TD의 목표는 다음 상태에서의 가치 함수 V(St+1)V(S_{t+1})를 이용해 현재 상태의 가치 V(St)V(S_t)를 업데이트 하는 것이었다.

V(St)V(St)+α(Rt+1+γV(St+1)V(St))V(S_t) \leftarrow V(S_t) + \alpha(R_{t+1} + \gamma V(S_{t+1}) - V(S_t))

하지만 MC Control에 대해 다룰 때에도 설명했듯 최적 정책을 획득하기 위해 정책을 업데이트하는 과정에서는 액션에 대한 정보가 있어야 하기 때문에 기존의 TD에서 상태 가치 함수 V(St)V(S_t) 대신 액션에 대한 정보를 포함하고 있는 액션 가치 함수 Q(St,At)Q(S_t, A_t)를 사용하는 것이 효율적이다. 액션 가치 함수를 업데이트하기 위해 사용되는 식은 위 식을 변환하여 아래와 같이 나타낼 수 있다.

Q(St,At)Q(St,At)+α(Rt+1+γQ(St+1,At+1)Q(St,At))Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha(R_{t+1} + \gamma Q(S_{t+1}, A_{t+1})-Q(S_t, A_t))

결국, Q(St,At)Q(S_t, A_t)를 업데이트 하기 위해서는 TD target을 알기 위해 액션 AtA_t에 대한 보상 Rt+1R_{t+1}과 그 결과 이동한 다음 상태 St+1S_{t+1}와, 그 상태에서 선택할 액션 At+1A_{t+1}에 대한 정보를 알아야 하는 것이다. 따라서 해당 식을 계산하기 위해 알아야 할 에피소드의 transit은 <St,At,Rt+1,St+1,At+1><S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}>이 됨을 알 수 있다. (SARSA라는 명칭은 transit의 구성 요소가 되는 값들을 나열한 것이다.)

n-Step SARSA

Prediction 단계에서 TD를 다룰 때 우리는 일반적인 형태의 TD (1-Step TD) 외에도 n-Step TD, TD(λ\lambda) 등의 변형된 형태의 TD를 다뤘다. 따라서 TD를 기반으로 고안된 SARSA 역시 각각의 변형에 대응하는 형태가 존재한다. 우선은 그중에서 n-Step SARSA에 대해 다뤄보자.

n-Step TD와 마찬가지로, n-Step SARSA의 목적은 현재 시점 tt에서 n-Step 만큼 미래의 시점에 도달한 상태 St+nS_{t+n}에서의 액션 가치 함수 Q(St+n,At+n)Q(S_{t+n}, A_{t+n})를 기반으로 Q(St,At)Q(S_t, A_t)를 업데이트 하는 방식이다. 따라서, 이 때의 target을 풀어서 나타내면 아래 식과 같이 표현할 수 있다.

qt(n)=Rt+1+γRt+2+...+γn1Rt+n+γnQ(St+n,At+n)q_t^{(n)} = R_{t+1} + \gamma R_{t+2} +...+\gamma^{n-1}R_{t+n} + \gamma^nQ(S_{t+n}, A_{t+n})

우리는 target과 현재의 Q(St,At)Q(S_t, A_t)의 차이에 해당하는 error 값을 통해 Q(St,At)Q(S_t, A_t)를 업데이트해야 하므로 결국 n-Step SARSA에서의 식은 다음과 같이 나타낼 수 있다. 결국, n-Step SARSA 방식으로 Policy Iteration을 수행하기 위해서는 t+nt+n 시점까지 에피소드를 진행하여 target인 qt(n)q_t^{(n)}를 계산하는 과정을 거쳐야 하는 것이다.

Q(St,At)Q(St,At)+α(qt(n)Q(St,At))Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha(q_t^{(n)} - Q(S_t, A_t))

SARSA(λ)

앞의 n-Step SARSA의 경우를 통해, 우리는 SARSA(λ\lambda) 역시 TD(λ\lambda) 기반의 TD Control 방식이라는 사실을 유추할 수 있다. n-Step SARSA의 과정이 n-Step TD와 거의 차이가 나지 않는 것처럼, SARSA(λ\lambda) 역시 TD(λ\lambda)와 그 궤를 같이 한다. 알다시피 TD(λ\lambda)는 1-Step부터 \infin-Step까지의 TD-target Gt(n)G_t^{(n)}의 기댓값 GtλG_t^\lambda를 TD-target으로 사용하여 가치 함수를 업데이트 하는 방식이다. 지금까지의 경험을 토대로 우리는 SARSA(λ\lambda) 역시 기존 TD(λ\lambda)에서 가치 함수만 액션 가치 함수로 변형해주면 된다는 점을 예측할 수 있다. 따라서 SARSA(λ\lambda)에서의 target qtλq_t^\lambda는 아래와 같이 나타낼 수 있다.

qtλ=(1λ)n=1λn1qt(n)q_t^\lambda=(1-\lambda) \sum_{n=1}^\infin \lambda^{n-1}q_t^{(n)}

그리고 이렇게 계산한 target을 기존의 식에 대입한 결과는 다음과 같다. 사용한 식은 Forward-view TD(λ\lambda)를 기반으로 수정한 Forward-view SARSA(λ\lambda) 식이다.

Q(St,At)Q(St,At)+α(qtλQ(St,At))Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha(q_t^\lambda-Q(S_t, A_t))

Forward-view SARSA(λ\lambda)까지 다뤘으니, 이제 거의 다 왔다. 전에도 말했지만 Forward-view 방식의 경우 MC와 마찬가지로 에피소드가 끝나야 업데이트를 진행할 수 있기 때문에, 마지막으로 Backward-view 방식까지 생각해줘야 한다. Backward-view SARSA(λ\lambda)도 Backward-view TD(λ\lambda)와 별 차이는 없다. eligibility trace를 고려하여 Q(s,a)Q(s, a)를 업데이트할 때 사용할 가중치 et(s,a)e_t(s, a)를 계산하여 업데이트에 반영하면 되는 것이다. 가중치를 계산하는 식은 다음과 같다.

e0(s,a)=0                                                      et(s,a)=γλet1(s,a)+1(St=s,At=a)e_0(s, a)=0\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \\\ e_t(s, a)=\gamma \lambda e_{t-1}(s, a) + \boldsymbol{1} (S_t=s, A_t=a)

우선 첫 번째 식은 episode가 시작하는 시점, 즉 t=0t=0인 시점에서 가중치를 초기화 하기 위한 동작이다. 이는 해당 episode에서 agent가 거쳐간 <s,a><s, a>에 대해서만 error값을 반영하기 위한 동작이다. 특정 episode에서 error가 발생한다 하더라도, episode의 시작과 함께 초기화가 되었기 때문에 해당 episode에서 거쳐가지 않은 <s,a><s, a>의 가중치 et(s,a)e_t(s, a)는 0이기 때문에 이 때의 가치 함수 Q(s,a)Q(s, a)는 업데이트 되지 않는 것이다. 또한 두 번째 식은 기존의 Backward-view TD(λ\lambda)에서의 두 식(St=sS_t=sStsS_t \neq s인 각각의 경우)을 하나의 식으로 표현한 것으로, St=sS_t=sAt=aA_t=a를 모두 만족하는 경우에는 et(s,a)e_t(s, a)에 1을 추가로 더하고, 그렇지 않은 경우에는 1을 더하지 않는 동작을 수행한다. 따라서, episode를 끝내지 않고 policy improvement를 수행하기 위해서는 다음과 같이 Backward-view SARSA(λ\lambda) 방식으로 학습을 진행하면 되는 것이다.

이렇게 지금까지 TD의 Control 방식인 SARSA에 대해 알아보았다. 간단히 정리하자면 SARSA는 결국 TD를 통한 prediction 결과를 ϵ\epsilon-greedy policy improvement 방식으로 반복해서 업데이트하는 policy iteration 과정이고, 그 과정을 수월하게 진행하기 위해 기존의 TD를 액션에 대한 정보를 포함한 Q(s,a)Q(s, a)에 대한 식으로 변환해준 것 뿐이다. 따라서 지난 포스팅에서 ϵ\epsilon-greedy에 대해 제대로 이해했고, 또 TD에 대해 잘 공부했다면 어렵지 않게 이해할 수 있었을 것이다. 그리고 SARSA를 제대로 이해했다면, 다음 포스팅에서 다룰, 강화 학습에서 가장 중요한 개념인 Q-learning을 이해하는 것도 어렵지 않을 것이다.

profile
좀 더 스마트하게 살고 싶은 리눅스, 로보틱스 개발자

0개의 댓글