[강화학습] 7. 역진자 태스크 Q러닝 구현

KBC·2024년 11월 4일
0

강화학습

목록 보기
7/13
post-thumbnail

Q러닝 구현

  • 구현할 클래스는 Agent, Brain, Environment
    • Agent : CartPole의 수레
      • update_Q_function : Q함수의 수정
      • get_action : 다음에 취할 행동을 결정
    • Brain : Agent 클래스의 두뇌 역할
      • bins, digitize_state : 관측한 상태 observation을 이산변수로 변환
      • update_Q_table : Q테이블 수정
      • decide_action : Q테이블을 이용해 행동을 결정
    • Environment : 실행 환경
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames):
    '''
    Displays a list of frames as a gif, with controls
    '''
    plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('movie_cartpole.gif')
    # display(display_animation(anim, default_mode='loop'))
  • 위는 애니메이션 코드이고 아래는 상수 정의
# 상수 정의
ENV = 'CartPole-v0' # 태스크 이름
NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수
GAMMA = 0.99 # 시간할인율
ETA = 0.5 # 학습률
MAX_STEPS = 200 # 1에피소드당 최대 단계 수
NUM_EPISODES = 1000 # 최대 에피소드 수
  • Agent 구현
class Agent :
    '''CartPole 에이전트 역할을 할 클래스, 봉 달린 수레'''
    def __init__(self, num_states, num_actions) :
        self.brain = Brain(num_states, num_actions) # 에이전트가 행동을 결정하는 두뇌 역할
    
    def update_Q_function(self, observation, action, reward, observation_next) :
        '''Q함수 수정'''
        self.brain.update_Q_table(
            observation, action, reward, observation_next
        )

    def get_action(self, observation, step) :
        '''행동 결정'''
        action = self.brain.decide_action(observation, step)
        return action
  • Brain 클래스 구현
class Brain :
    '''에이전트의 두뇌 역할을 하는 클래스, Q러닝을 실제 수행'''

    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions # 행동의 가짓수(왼쪽, 오른쪽)를 구함
        # Q 테이블을 생성. 줄 수는 상태를 구간수^4(변수의 수)가지 값 중 하나로 변환한 값, 
        # 열 수는 행동의 가짓수
        self.q_table = np.random.uniform(low=0, high=1, size=(
            NUM_DIZITIZED**num_states, num_actions
        ))

    def bins(self, clip_min, clip_max, num):
        '''관측된 상태(연속값)를 이산변수로 변환하는 구간을 계산'''
        return np.linspace(clip_min, clip_max, num + 1)[1:-1]
    
    def digitize_state(self, observation):
        '''관측된 상태 observation을 이산변수로 반환'''
        cart_pos, cart_v, pole_angle_pole_v = observation
        digitized = [
            np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),
            np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),
            np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),
            np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED))
        ]
        return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])
    
    def update_Q_table(self, observation, action, reward, observation_next):
        '''Q러닝으로 Q테이블을 수정'''
        state = self.digitize_state(observation) # 상태를 이산변수로 변환
        state_next = self.digitize_state(observation_next) # 다음 상태를 변환
        Max_Q_next = max(self.q_table[state_next][:])
        self.q_table[state, action] = self.q_table[state, action] + \
        ETA * (reward + GAMMA + Max_Q_next - self.q_table[state, action])

    def decide_action(self, observation, episode):
        '''Epsilon-Greedy 알고리즘 적용 서서히 최적행동의 비중을 늘림'''
        state = self.digitize_state(observation)
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.q_table[state][:])
        else :
            action = np.random.choice(self.num_actions) # 0, 1 두가지 행동 중 하나
        return action
  • Environment 클래스 구현
class Environment:
    '''CartPole을 실행하는 환경 역할을 하는 클래스'''

    def __init__(self):
        self.env = gym.make(ENV, render_mode = 'rgb_array') # 실행할 태스크를 설정
        num_states = self.env.observation_space.shape[0] # 태스크의 상태 변수 수
        num_actions = self.env.action_space.n # 가능한 행동 수
        self.agent = Agent(num_states, num_actions) # 에이전트 객체 생성

    def run(self):
        '''실행'''
        complete_episodes = 0 # 성공한 에피소드 수
        is_episode_final = False # 마지막 에피소드 여부
        frames = [] # 이미지 저장

        for episode in range(NUM_EPISODES) : # 에피소드만큼 반복
            observation, info = self.env.reset() # 환경 초기화

            if episode == NUM_EPISODES - 1 :
                is_episode_final = True

            for step in range(MAX_STEPS) : # 1 에피소드에 대해 반복
                if is_episode_final is True : # 마지막 에피소드이면 이미지 저장
                    frames.append(self.env.render())

                # 행동을 선택
                action = self.agent.get_action(observation, episode)

                # 행동 a_t를 실행해 s_{t+1}, r_{t+1}을 계산
                observation_next, _, done, _, _ = self.env.step(
                    action # reward, info는 사용하지 않으므로 _로 처리
                )

                # 보상을 부여
                if done : # 200단계를 넘거나 일정 각도 이상 기울면 done의 값이 True
                    if step < 195 :
                        reward = -1 # 봉이 쓰러지면 패널티 보상 -1
                        complete_episodes = 0 # 해당 에피소드 실패
                    else :
                        reward = 1 # 성공하면 보상 1
                        complete_episodes += 1 # 연속 성공 기록 업데이트
                else :
                    reward = 0 # 에피소드 중에는 보상 0

                # 다음 단계의 상태 observation_next로 Q함수를 수정
                self.agent.update_Q_function(
                    observation, action, reward, observation_next
                )

                # 다음 단계 상태 관측
                observation = observation_next

                # 에피소드 마무리
                if done :
                    print('{0} Episode: Finished after {1} time steps'.format(
                        episode, step + 1
                    ))
                    break

            if is_episode_final is True : # 마지막에서는 애니메이션 저장
                display_frames_as_gif(frames)
                break

            if complete_episodes >= 10 : # 10에피소드 이상 연속 성공
                print("10 연속 에피소드 성공")
                is_episode_final = True # 다음 에피소드가 마지막이 됨
  • run 메소드 호출
# main
cartpole_env = Environment()
cartpole_env.run()
  • 예시에서는 200에피소드 이상 쯤 되면 학습이 완료되어 꽤 버티는데 난 그렇게까진 안나왔다
  • 그래도 학습 후반부에는 꽤 버틴다
학습 학습
profile
AI, Security

0개의 댓글