강화학습 - Model Free(2)

BSH·2023년 5월 22일
0

강화학습_basic

목록 보기
6/12

Model Free(1)에서는 MC(Monte Carlo)학습을 알아보았는데 이번엔 TD(Temporal Difference)학습에 대해서 알아보겠습니다.

MC의 경우에는 업데이트를 하기위해 에피소드가 끝나야 한다는 문제가 있습니다. 다시말해 반드시 종료하는 MDP이어야만 사용할 수 있다는 것이죠. 실제 환경에서는 종료 조건이 없는 MDP도 존재합니다. 반면 TD는 에피소드가 끝나기 전에 업데이트하는 방법으로 종료하지 않는 MDP에서도 학습을 할 수 있습니다.
책에서는 "추측을 추측으로 업데이트 하자"라고 말하네요

이론적 배경

MC에 대한 근거는 기댓값에 있습니다. 샘플을 모으면 모을수록 평균은 수렴하게 됩니다.

vπ(st)=Eπ[Gt]v_{\pi}(s_{t})=\mathbb{E_{\pi}}[G_{t}]

통계학 용어로 GtG_{t}vπ(st)v_{\pi}(s_{t})의 불편추정량(unbiased estimate)라고 합니다

편향되지 않은 추정량이란 뜻이죠. MC는 대수의 법칙에 의해 수렴이 보장되어있습니다. TD의 경우는 어떤지 보겠습니다.

vπ(st)=Eπ[rt+1+γvπ(st+1)]v_{\pi}(s_{t})=\mathbb{E_{\pi}}[r_{t+1}+\gamma v_{\pi}(s_{t+1})]

rt+1+γvπ(st+1)r_{t+1}+\gamma v_{\pi}(s_{t+1})는 TD target이라고 부르며 이 값을 모아서 평균내면 원하는 vπ(st)v_{\pi}(s_{t})를 얻을 수 있고 곧 TD 학습입니다. 그리고 벨만 기대 방정식이 이 TD 학습에 근간이 되는 식입니다. 앞의 수식은 당연히 vπ(s)v_{\pi}(s)를 모르는 상태에서 쓰기 때문에 편차가 발생할 수 있습니다. 다음으로 학습 알고리즘을 보겠습니다.

TD 학습 알고리즘

MC에서 봤던 그리드월드를 그대로 가져오겠습니다.
TD도 MC와 마찬가지로 테이블을 초기화 하는 것부터 시작합니다.

먼저 MC와 TD의 식을 비교해보겠습니다.

MC:V(st):=V(st)+α(GtV(st)TD:V(st):=V(st)+α(rt+1+γV(st+1V(st))MC:V(s_{t}):=V(s_{t})+\alpha(G_{t}-V(s_{t})\\ TD:V(s_{t}):=V(s_{t})+\alpha(r_{t+1}+\gamma V(s_{t+1}-V(s_{t}))

GtG_{t}rt+1+γV(st+1)r_{t+1}+\gamma V(s_{t+1})로 바뀌었고 종료시점이 아닌 하나의 단계를 진행하면 바로 업데이트 합니다.

코드를 보면서 이해하면 더 쉽습니다.
그리드 월드 환경 코드와 에이전트 코드는 동일합니다.

import random

class GridWorld:
    # 환경에 해당하는 클래스
    def __init__(self, n):
        self.n = n - 1
        self.r = 0
        self.c = 0
        
    def step(self, a):
        # action을 받아서 상태 변이를 일으키며 보상을 정해주는 함수
        if a == 0:
            self.move_right()
        elif a == 1:
            self.move_left()
        elif a == 2:
            self.move_up()
        elif a == 3:
            self.move_down()
        
        reward = -1
        done = self.is_done()
        return (self.r, self.c), reward, done
    
    def move_right(self):
        self.c += 1
        if self.c > self.n:
            self.c = self.n
    
    def move_left(self):
        self.c -= 1
        if self.c < 0:
            self.c = 0
    
    def move_up(self):
        self.r -= 1
        if self.r < 0:
            self.r = 0
    
    def move_down(self):
        self.r += 1
        if self.r > self.n:
            self.r = self.n
            
    def is_done(self):
        # 에피소드 끝났는지 판결
        if self.r == self.n and self.c == self.n:
            return True
        else:
            return False
    
    def get_state(self):
        return (self.r, self.c)
    
    def reset(self):
        self.r, self.c = 0, 0
        return (self.r, self.c)

class Agent:
    def __init__(self):
        pass
    
    def select_action(self):
        coin_list = [0, 1, 2, 3]
        action = random.choice(coin_list)
        return action

TD code

def main():
    length = 5
    env = GridWorld(n=length)
    agent = Agent()
    data = [[0]*length for _ in range(length)]
    gamma = 1.0
    alpha = 0.01 # MC에 비해 큰 값 사용 (학습 변동성이 작기 때문에 크게 해주어야 함)
    
    for k in range(50000):
        # 에피소드 5만번 진행
        done = False
        while not done:
            r, c = env.get_state()
            action = agent.select_action()
            (r_prime, c_prime), reward, done = env.step(action)

            # 하나의 step이 진행될 때 마다 바로 데이블의 데이터를 업데이트 해줌
            data[r][c] = data[r][c] + alpha * (reward + gamma * data[r_prime][c_prime] - data[r][c])

        env.reset()
    
    for row in data:
        for e in row:
            print(f"{e:>10.2f}", end=" ")
        print()

            
main()

MC에 비해 학습코드가 간결해졌습니다. while문 안에 다른 loop를 반복하지 않습니다. 그리고 업데이트 폭이 매우 커졌습니다. 이는 TD가 MC에 비해 학습 변동성이 작기 때문에 그렇습니다.

학습 결과

  -104.35    -102.20     -98.73     -94.71     -92.73 
  -102.62     -99.88     -95.10     -89.63     -86.74 
   -98.78     -94.99     -87.00     -78.64     -71.10 
   -94.50     -89.88     -79.35     -66.26     -45.77 
   -91.96     -85.48     -71.10     -50.19       0.00 
profile
컴공생

0개의 댓글