SALSA (code)

Seungsoo Lee·2022년 7월 7일
0

RL

목록 보기
13/15

date: 2021-10-18 22:00:00


(https://github.com/rlcode/reinforcement-learning-kr-v2) 의 코드를 참고하였습니다.***

이 코드를 이해하려면 앞선 SALSA 포스트를 읽어주세요.

설명은 주석으로 해두었습니다.


agent.py


if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))
		
    # 1000개의 에피소드를 진행함.
    for episode in range(1000):
        # 게임 환경과 상태를 초기화
        state = env.reset()
        # 현재 상태에 대한 행동을 선택 (epsilon greedy policy 에 의한 action 선택임)
        action = agent.get_action(state)

        while True:
            env.render()

            # 행동을 취한 후 다음상태 보상 에피소드의 종료 여부를 받아옴
            next_state, reward, done = env.step(action)
            # 다음 상태에서의 다음 행동 선택 (다음 행동 또한 epsilon greedy policy 에 의한 action 선택임)
            next_action = agent.get_action(next_state)
            # <s,a,r,s',a'>로 큐함수를 업데이트
            agent.learn(state, action, reward, next_state, next_action)

            # 반복을 위해 다시 next 변수들을 state, action에 초기화 시킴(다시 반복)
            state = next_state
            action = next_action

            # 모든 큐함수를 화면에 표시
            env.print_value_all(agent.q_table)

            if done:
                break

SARSAgent 클래스의 초기화 부분을 보겠다

class SARSAgent:
    def __init__(self, actions):
        self.actions = actions
        # (stepsize가 작기 때문에 reward와 penalty를 각각 100, -100으로 설정함)
        self.step_size = 0.01
        self.discount_factor = 0.9
        self.epsilon = 0.1
        # 0을 초기값으로 가지는 큐함수 테이블 생성
        self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])

learn 함수에 대해서 보겠다.

# <s, a, r, s', a'>의 샘플로부터 큐함수를 업데이트
    def learn(self, state, action, reward, next_state, next_action):
    		# 변수들을 str으로 변환해줌
        state, next_state = str(state), str(next_state)
        # 매개변수로 받아온 현재 state의 action에 해당하는 q값을 얻어와서 cuuent_q에 넣음
        current_q = self.q_table[state][action]
        # 매개변수로 받아온 다음 state와 action에 해당하는 q값을 얻어와서 next_current_q에 넣음
        next_state_q = self.q_table[next_state][next_action]
        
        # 업데이트의 목표를 크기를 계산
        td = reward + self.discount_factor * next_state_q - current_q
        new_q = current_q + self.step_size * td
        # q_table에 해당 state에서 해당 action을 하였을때의 새로 업데이트된 q값을 넣는다.
        self.q_table[state][action] = new_q

get_action 함수에 대해서 보겠다.

# 입실론 탐욕 정책에 따라서 행동을 반환
    def get_action(self, state):
    		# 설정된 epsilon이 0.1 이라서 10% 확률로 무작위 행동을 함
        if np.random.rand() < self.epsilon:
            # 무작위 행동 반환
            action = np.random.choice(self.actions)
        else:
            # 큐함수에 따른 행동 반환
            # 최대 q값을 하는 행동을 반환
            state = str(state)
            q_list = self.q_table[state]
            action = arg_max(q_list)
        return action

제가 올린 글에서 잘못된 부분이 있으면 제 메일로 연락주세요!

Reference : https://github.com/rlcode/reinforcement-learning-kr-v2


이승수의 저작물인 이 저작물은(는) 크리에이티브 커먼즈 저작자표시-비영리-동일조건변경허락 4.0 국제 라이선스에 따라 이용할 수 있습니다.

0개의 댓글