[RL] Easy21 assignment by David Silver

Minseo Jeong · May 15, 2025


UCL CS "Reinforcement Learning" 과제 기반
Sutton & Barto의 Blackjack 예제를 변형한 Easy21 게임에서
Monte-Carlo Control, Sarsa(λ), Function Approximation 등을 실습하는 과제


| Easy21 Game Rules Summary

| Item | Description |
| --- | --- |
| Deck | Infinite deck of cards valued 1–10; black with probability 2/3, red with probability 1/3 |
| Initial setup | Player and dealer each start with one black card |
| Actions | Hit (draw another card) or Stick (stop) |
| Card sum | Black cards add (+), red cards subtract (−) |
| Termination | Sum < 1 or > 21 → bust |
| Dealer rule | Sticks on 17 or more, otherwise hits |
| Outcome | A bust ends the game immediately; otherwise scores are compared |

Rewards: Win = +1, Lose = −1, Draw = 0
Discount factor: γ = 1 (i.e., no discounting of future rewards)


| 1. Implementing the Easy21 Environment (10 points)

  • Core function: step(s, a)

    • Input: current state s = (dealer's first card, player's sum), action a
    • Output: next state s′, reward r
  • When Stick is chosen, the dealer's entire play is executed automatically

  • Do not compute the MDP transition probabilities explicitly (model-free)


| 2. Applying Monte Carlo Control (15 points)

Settings

  • Initial value function: 0
  • Step size: $\alpha_t = \frac{1}{N(s_t, a_t)}$
  • Exploration rate: $\epsilon_t = \frac{N_0}{N_0 + N(s_t)}$, with $N_0 = 100$
  • Use an ε-greedy policy; the resulting incremental update is shown below
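
Combining these settings, each episode updates every visited state-action pair toward the episode return $G_t$ (with γ = 1 and a reward only at the end, $G_t$ is simply the final reward):

$$
N(s_t, a_t) \leftarrow N(s_t, a_t) + 1, \qquad
Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \frac{1}{N(s_t, a_t)}\bigl(G_t - Q(s_t, a_t)\bigr)
$$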

Required output

  • Optimal value function:
    $V^*(s) = \max_a Q^*(s, a)$
  • Visualize as a 3D plot (x: dealer card, y: player sum, z: value)

| 3. Applying Sarsa(λ) (15 points)

Settings

  • λ ∈ {0, 0.1, ..., 1}
  • Step size / exploration rate: same as above
  • Run 1,000 episodes
  • Use the MC Control result as the ground-truth $Q^*$

Required output

  • MSE (mean-squared error) against $Q^*$ as a function of λ (a helper for computing this is sketched below)
  • For λ = 0 and λ = 1, additionally plot the per-episode learning curve
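
A minimal sketch of how this MSE could be computed, assuming the dict-of-dicts tabular Q layout used in the code later in this post; compute_mse is a hypothetical helper name, not part of the assignment:

def compute_mse(Q, q_star):
    """Mean-squared error between a learned Q and the MC ground-truth Q*."""
    errors = []
    for state, actions in q_star.items():
        for action, q_true in actions.items():
            q_est = Q[state][action] if state in Q else 0.0
            errors.append((q_est - q_true) ** 2)
    return sum(errors) / len(errors)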

| 4. Function Approximation + Sarsa(λ) (15 points)

Settings

  • Feature vector: coarse coding (36-dimensional binary vector)

Feature construction

| Dimension | Intervals |
| --- | --- |
| dealer | {[1–4], [4–7], [7–10]} (3 intervals) |
| player | {[1–6], [4–9], ..., [16–21]} (6 intervals) |
| action | hit / stick (2) |

  • Total features: 3 × 6 × 2 = 36

Approximation function

  • $Q(s, a) = \phi(s, a)^\top \theta$ (the corresponding gradient update is shown below)
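
With this linear form the gradient of $Q$ with respect to $\theta$ is just $\phi(s, a)$, so the gradient Sarsa(λ) update used in the code below reduces to (γ = 1; the next-state term is dropped at episode end):

$$
\delta_t = r_{t+1} + \phi(s_{t+1}, a_{t+1})^\top \theta - \phi(s_t, a_t)^\top \theta, \qquad
e_t = \lambda\, e_{t-1} + \phi(s_t, a_t), \qquad
\theta \leftarrow \theta + \alpha\, \delta_t\, e_t
$$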

Hyperparameters

  • ε = 0.05 (fixed)
  • α = 0.01 (fixed)

Required output

  • MSE as a function of λ
  • Learning curves for λ = 0 and λ = 1

| 5. Analysis and Discussion (5 points)

Key questions

  • What are the pros and cons of bootstrapping in Easy21?
  • In which game, Blackjack or Easy21, does bootstrapping help more?
  • What are the pros and cons of function approximation?
  • Is there a better approximation scheme than the one proposed?

| Monte Carlo Control Implementation Structure

Example directory structure

easy21/
├── easy21_env.py          ← environment definition (step, reset, etc.)
├── monte_carlo_control.py ← MC Control algorithm
├── plot_utils.py          ← visualization utilities (3D plot)
└── main.py                ← entry point and integration

Easy21 environment code (easy21_env.py)

import numpy as np
import random

class Easy21Env:
    def __init__(self):
        self.reset()

    def draw_card(self):
        value = np.random.randint(1, 11)
        color = 'black' if np.random.rand() < (2/3) else 'red'
        return value if color == 'black' else -value

    def reset(self):
        self.dealer = np.random.randint(1, 11)
        self.player = np.random.randint(1, 11)
        return (self.dealer, self.player)

    def step(self, state, action):
        dealer, player = state

        if action == 'hit':
            player += self.draw_card()
            if player < 1 or player > 21:
                return (None, -1, True)  # bust
            return ((dealer, player), 0, False)

        # Stick: dealer plays
        while dealer < 17:
            dealer += self.draw_card()
            if dealer < 1 or dealer > 21:
                return (None, 1, True)  # dealer bust → player wins

        # Compare scores
        if player > dealer:
            return (None, 1, True)
        elif player < dealer:
            return (None, -1, True)
        else:
            return (None, 0, True)
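
As a quick sanity check, a minimal sketch (not part of the assignment code) that plays one episode with uniformly random actions using the Easy21Env class above:

import random

from easy21_env import Easy21Env

env = Easy21Env()
state = env.reset()
done = False
while not done:
    action = random.choice(['hit', 'stick'])
    next_state, reward, done = env.step(state, action)
    print(state, action, '->', next_state, 'reward:', reward)
    state = next_state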

Monte Carlo Control implementation (monte_carlo_control.py)

from collections import defaultdict
import numpy as np
import random

class MonteCarloControl:
    def __init__(self, env, N0=100):
        self.env = env
        self.N0 = N0
        self.Q = defaultdict(lambda: {'hit': 0.0, 'stick': 0.0})
        self.N = defaultdict(lambda: {'hit': 0, 'stick': 0})

    def epsilon(self, state):
        total_N = sum(self.N[state].values())
        return self.N0 / (self.N0 + total_N)

    def generate_episode(self):
        episode = []
        state = self.env.reset()
        done = False
        while not done:
            eps = self.epsilon(state)
            if random.random() < eps:
                action = random.choice(['hit', 'stick'])
            else:
                action = max(self.Q[state], key=self.Q[state].get)

            next_state, reward, done = self.env.step(state, action)
            episode.append((state, action, reward))
            state = next_state
        return episode

    def train(self, num_episodes=100000):
        for _ in range(num_episodes):
            episode = self.generate_episode()
            G = episode[-1][2]  # return: γ = 1 and only the terminal transition carries reward
            visited = set()
            for state, action, _ in episode:
                if (state, action) not in visited:
                    self.N[state][action] += 1
                    alpha = 1 / self.N[state][action]
                    self.Q[state][action] += alpha * (G - self.Q[state][action])
                    visited.add((state, action))

Value function visualization (plot_utils.py)

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def plot_value_function(Q):
    Z = np.zeros((10, 21))  # dealer 1~10, player 1~21
    for dealer in range(1, 11):
        for player in range(1, 22):
            state = (dealer, player)
            if state in Q:
                Z[dealer-1][player-1] = max(Q[state]['hit'], Q[state]['stick'])

    X, Y = np.meshgrid(np.arange(1, 22), np.arange(1, 11))
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z)
    ax.set_xlabel('Player Sum')
    ax.set_ylabel('Dealer Showing')
    ax.set_zlabel('Value')
    plt.show()

Example run (main.py)

from easy21_env import Easy21Env
from monte_carlo_control import MonteCarloControl
from plot_utils import plot_value_function

env = Easy21Env()
agent = MonteCarloControl(env)
agent.train(num_episodes=500000)
plot_value_function(agent.Q)

| Sarsa(λ) with Tabular Representation

Example filename: sarsa_lambda.py

from collections import defaultdict
import numpy as np
import random

class SarsaLambdaAgent:
    def __init__(self, env, lambda_=0.5, N0=100):
        self.env = env
        self.lambda_ = lambda_
        self.N0 = N0
        self.Q = defaultdict(lambda: {'hit': 0.0, 'stick': 0.0})
        self.N = defaultdict(lambda: {'hit': 0, 'stick': 0})
        self.E = defaultdict(lambda: {'hit': 0.0, 'stick': 0.0})  # eligibility traces

    def epsilon(self, state):
        total = sum(self.N[state].values())
        return self.N0 / (self.N0 + total)

    def choose_action(self, state):
        if random.random() < self.epsilon(state):
            return random.choice(['hit', 'stick'])
        return max(self.Q[state], key=self.Q[state].get)

    def train(self, num_episodes=1000):
        for _ in range(num_episodes):
            self.E = defaultdict(lambda: {'hit': 0.0, 'stick': 0.0})  # reset traces
            state = self.env.reset()
            action = self.choose_action(state)

            done = False
            while not done:
                next_state, reward, done = self.env.step(state, action)
                next_action = self.choose_action(next_state) if not done else None

                # update counts
                self.N[state][action] += 1
                alpha = 1 / self.N[state][action]

                # TD error
                q_predict = self.Q[state][action]
                q_target = reward + (0 if done else self.Q[next_state][next_action])
                delta = q_target - q_predict

                # update traces
                self.E[state][action] += 1

                # update every state-action pair in proportion to its eligibility trace
                for s in self.Q:
                    for a in ['hit', 'stick']:
                        self.Q[s][a] += alpha * delta * self.E[s][a]
                        self.E[s][a] *= self.lambda_  # decay traces (γ = 1)

                state, action = next_state, next_action
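
For the required learning curves at λ = 0 and λ = 1, one simple option is to train one episode at a time and record the MSE against the MC ground truth after each episode. A sketch under that assumption; learning_curve is a hypothetical helper, and compute_mse refers to the helper sketched earlier in this post:

from easy21_env import Easy21Env
from sarsa_lambda import SarsaLambdaAgent

def learning_curve(q_star, lambda_, num_episodes=1000):
    """Per-episode MSE of tabular Sarsa(λ) against the MC ground-truth q_star."""
    env = Easy21Env()
    agent = SarsaLambdaAgent(env, lambda_=lambda_)
    curve = []
    for _ in range(num_episodes):
        agent.train(num_episodes=1)  # one episode per call; traces are reset inside train
        curve.append(compute_mse(agent.Q, q_star))  # compute_mse: helper sketched earlier
    return curve

# Usage (q_star is the Q table from a long MC Control run, see main.py above):
# curves = {lam: learning_curve(q_star, lam) for lam in (0.0, 1.0)}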

| Function Approximation + Sarsa(λ)

Example filename: fa_sarsa_lambda.py

import numpy as np
import random

class FASarsaLambdaAgent:
    def __init__(self, env, lambda_=0.5, alpha=0.01, epsilon=0.05):
        self.env = env
        self.lambda_ = lambda_
        self.alpha = alpha
        self.epsilon = epsilon
        self.weights = np.zeros(36)  # 36 features
        self.eligibility = np.zeros(36)

    def get_features(self, state, action):
        dealer, player = state
        features = np.zeros(36)
        dealer_bins = [(1, 4), (4, 7), (7, 10)]
        player_bins = [(1, 6), (4, 9), (7, 12), (10, 15), (13, 18), (16, 21)]
        action_index = 0 if action == 'hit' else 1

        for i, (d_min, d_max) in enumerate(dealer_bins):
            if d_min <= dealer <= d_max:
                for j, (p_min, p_max) in enumerate(player_bins):
                    if p_min <= player <= p_max:
                        idx = i * 6 * 2 + j * 2 + action_index
                        features[idx] = 1
        return features

    def q_value(self, state, action):
        features = self.get_features(state, action)
        return np.dot(self.weights, features)

    def choose_action(self, state):
        if random.random() < self.epsilon:
            return random.choice(['hit', 'stick'])
        q_hit = self.q_value(state, 'hit')
        q_stick = self.q_value(state, 'stick')
        return 'hit' if q_hit > q_stick else 'stick'

    def train(self, num_episodes=1000):
        for _ in range(num_episodes):
            self.eligibility = np.zeros(36)
            state = self.env.reset()
            action = self.choose_action(state)

            done = False
            while not done:
                next_state, reward, done = self.env.step(state, action)
                next_action = self.choose_action(next_state) if not done else None

                phi = self.get_features(state, action)
                q_now = np.dot(self.weights, phi)

                if done:
                    delta = reward - q_now
                else:
                    phi_next = self.get_features(next_state, next_action)
                    q_next = np.dot(self.weights, phi_next)
                    delta = reward + q_next - q_now

                self.eligibility = self.lambda_ * self.eligibility + phi  # accumulating traces (γ = 1)
                self.weights += self.alpha * delta * self.eligibility

                state, action = next_state, next_action
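
Because the MSE for this agent is also measured against the tabular MC $Q^*$, the linear approximation can first be evaluated on every state-action pair. A small sketch of this; q_table_from is a hypothetical helper, not part of the assignment:

def q_table_from(agent):
    """Evaluate Q(s, a) = φ(s, a)ᵀθ for every tabular state-action pair."""
    Q = {}
    for dealer in range(1, 11):      # dealer's first card: 1..10
        for player in range(1, 22):  # player sum: 1..21
            state = (dealer, player)
            Q[state] = {a: agent.q_value(state, a) for a in ('hit', 'stick')}
    return Q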

학습 & 시각화 코드 예시

import matplotlib.pyplot as plt

def plot_mse_vs_lambda(errors, lambdas):
    plt.plot(lambdas, errors)
    plt.xlabel("Lambda (λ)")
    plt.ylabel("Mean Squared Error")
    plt.title("MSE vs Lambda in Sarsa(λ)")
    plt.grid(True)
    plt.show()
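
Finally, a possible driver that ties the pieces together: build the ground-truth $Q^*$ with MC Control, sweep λ over {0, 0.1, ..., 1} with tabular Sarsa(λ), and plot the MSE curve. This is a sketch assuming it lives in the same script as plot_mse_vs_lambda above; compute_mse is the hypothetical helper sketched earlier, and the same loop works for FASarsaLambdaAgent:

from easy21_env import Easy21Env
from monte_carlo_control import MonteCarloControl
from sarsa_lambda import SarsaLambdaAgent

env = Easy21Env()

# Ground-truth Q* from a long Monte-Carlo Control run
mc = MonteCarloControl(env)
mc.train(num_episodes=500000)
q_star = mc.Q

lambdas = [round(0.1 * i, 1) for i in range(11)]
errors = []
for lam in lambdas:
    agent = SarsaLambdaAgent(env, lambda_=lam)
    agent.train(num_episodes=1000)
    errors.append(compute_mse(agent.Q, q_star))  # compute_mse: helper sketched earlier

plot_mse_vs_lambda(errors, lambdas)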