[강화학습] 6. 다변수, 연속값 상태 표현

KBC·2024년 11월 4일
0

강화학습

목록 보기
6/13

CartPole의 상태

  • 미로에서
    • 상태 : 에이전트가 어느 칸에 위치했는지 변수 하나로 나타냄(0 ~ 8 단순 이산값)
  • 역진자 태스크에서는 더 복잡하게 상태가 정의되어야함
  • 역진자에서의 상태
    • 수레의 위치 : 2.42.4-2.4 \sim 2.4
    • 수레의 속도 : -\infty \sim \infty
    • 봉의 각도 : 41.8°41.8°-41.8\degree \sim41.8\degree
    • 봉의 각속도 : -\infty \sim \infty

      원래 상태가 가질 수 있는 것은 연속값이지만, 상한과 하한부터 이산값으로 변환하여 표현

상태의 이산변수 변환 구현

# 구현에 사용할 패키지 임포트
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym

# 상수 정의
ENV = 'CartPole-v0' # 태스크 이름
NUM_DIZITIZED = 6 # 각 상태를 이산변수로 변환할 구간 수

#  CartPole 실행
env = gym.make(ENV) # 실행할 태스크 설정
observation = env.reset() # 환경 초기화

# 이산값으로 만들 구간 계산
def bins(clip_min, clip_max, num) :
    '''관측된 상태(연속값)를 이산값으로 변환하는 구간을 계산'''
    return np.linspace(clip_min, clip_max, num + 1)[1:-1]
  • np.linspace는 각 구간 경곗값으로 이뤄진 수열을 생성하는 명령이다
    • 예를 들어 np.linspace(-2.4,2.4,6+1)을 실행하면
      [2.4,1.6,0.8,0.,0.8,1.6,2.4][-2.4,-1.6,-0.8,0.,0.8,1.6,2.4]를 결과로 얻는다
      이 리스트에서 첫 번째 요소와 마지막 요소를 뺀 부분 리스트를 사용할 것이므로 [1:1][1:-1] 사용
  • 그 다음으로 bins에 구했던 구간값에 따라 연속변수이산변수로 변환하는 함수를 구현
def digitize_state(observation) :
    '''관측된 상태(observation 변수)를 이산값으로 변환'''
    cart_pos, cart_v, pole_angle, pole_v = observation
    digitized = [
        np.digitize(cart_pos, bins=bins(-2.4, 2.4, NUM_DIZITIZED)),
        np.digitize(cart_v, bins=bins(-3.0, 3.0, NUM_DIZITIZED)),
        np.digitize(pole_angle, bins=bins(-0.5, 0.5, NUM_DIZITIZED)),
        np.digitize(pole_v, bins=bins(-2.0, 2.0, NUM_DIZITIZED))
    ]
    return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])
  • np.digitize는 상태변수의 리스트를 bins에 정의된 구간값에 따라 이산값으로 변환
  • digitize_state의 반환값은 상태 4개 변수를 모두 합쳐 0부터 1295사이의 값으로 변환
  • DIGITIZED = 6이라면 6진수로 계산
    • 예를 들어 (수레위치, 수레속도, 봉의 각도, 봉의 각속도) = (1, 2, 3, 4)
      1×60+2×61+3×62+4×63=9851\times6^0+2\times6^1+3\times 6^2+4\times6^3=985
profile
AI, Security

0개의 댓글