프로그래머스(Programmers) : 주사위 고르기 - python 풀이

JISU LIM·2024년 3월 19일

Algorithm Study Records

목록 보기
76/79
post-thumbnail

🔴 문제 개요

문제 원문 - 프로그래머스(Programmers)

주어지는 n개의 주사위를 A, B가 n//2개 씩 나눠 가졌을 때, A가 승리하는 확률이 가장 높도록 가져가는 주사위 조합을 구하면 되는 문제입니다.

제한 사항

  • 2 ≤ dice의 길이 = n ≤ 10
    • n은 2의 배수입니다.
    • dice[i]는 i+1번 주사위에 쓰인 6개의 수를 담고 있습니다.
    • dice[i]의 길이 = 6
    • 1 ≤ dice[i]의 원소 ≤ 100

주어지는 dice의 길이를 생각했을 때, 완전 탐색으로 접근하면 복잡하긴 해도 어렵지 않게 해결할 수 있을 것 같습니다.

🟠 Solution

🦴 Backborn

솔루션의 백본으로 코드의 흐름을 설명하겠습니다. 코드의 depth에 주의하여 아래 솔루션을 이해하셨으면 좋겠습니다.

def solution(dice: List[List[int]]) -> List[int]:
    n = len(dice)
    combinations = make_combinations(n, n // 2)  # 주사위 조합
    products = product(n // 2)  # 주사위 면의 조합

    answer_cnt = 0
    answer = None   # 명시적 초기화

    for A_picks in combinations:
        B_picks = set(range(n)) - A_picks

        A_picks_result = get_roll_result(dice, products, A_picks)  # A 주사위 결과
        B_picks_result = get_roll_result(dice, products, B_picks)  # B 주사위 결과

        wins_cnt = get_wins_cnt(A_picks_result, B_picks_result)  # A 주사위가 이긴 횟수

        if answer_cnt < wins_cnt:   # answer 업데이트
            answer_cnt = wins_cnt
            answer = sorted(list(map(lambda x: x + 1, A_picks)))    # 인덱스가 1부터 시작함, 오름차순

    return answer
  1. 완전 탐색을 위해 주사위 n개 중 n//2를 가져가는 모든 조합을 찾아야 합니다. 이를 통해 A와 B가 가져가는 주사위 조합마다 결과를 구하게 됩니다.(make_combinations)
  2. 또한 n//2개의 주사위를 굴렸을 때 나오는 모든 면의 조합을 구해야 합니다. 가령 3개의 주사위를 굴렸을 때 발생하는 [0, 0, 0] ~ [5, 5, 5] 경우에 대해 모두 계산해주어야 합니다. (product)
  3. 이제 (A, B의 모든 주사위 조합 + 굴리는 모든 면의 조합)에 따른 모든 결과를 계산해야 합니다.(get_roll_result)
  4. A, B의 결과를 확보 했다면 A의 각 결과 별로 B의 각 결과를 몇 번 이기는지를 계산해야 합니다.(get_win_cnt)
  5. 문제에는 확률이 가장 높은 경우를 계산하라 명시되어있지만, 생각해보면 확률이 높은 경우 == 이기는 횟수가 많은 경우이므로 이기는 횟수가 가장 많은 A 주사위 조합을 answer로 도출하면 됩니다.

1️⃣ 주사위 n개중 n//2를 뽑는 모든 조합 찾기 : make_combinations()

def make_combinations(n: int, k: int) -> List[Set]:
    """
    주사위 n개(0~n-1) 중 k개를 뽑는 조합 반환
    ex) 주사위 4개 중 2개를 뽑기
    [0, 1], [0, 2], ...,  [2, 3]
    """
    result = []

    def recur(tmp):
        if len(tmp) == k:
            result.append(set(tmp))
            return

        for i in range(n):
            if not tmp or i > tmp[-1]:
                recur(tmp + [i])

    recur([])

    return result

itertools의 combinations 를 활용하면 쉽게 구할 수 있지만, 이와 같이 backtracking으로 구해낼 수 있습니다. 조합을 계산하는 것이므로, 중복을 포함하지 않기 위한 조건이 필요합니다.

2️⃣ n//2개의 주사위를 굴리는 모든 경우의 수 찾기 : product()

def product(n: int) -> List[List[int]]:
    """
    주사위 n개를 굴리는 경우의 수(인덱스 조합) 반환
    ex) 주사위 3개인 경우
    [0, 0, 0] ~ [5, 5, 5] 반환
    """
    result = []

    def recur(tmp: List) -> None:
        if len(tmp) == n:
            result.append(tmp)
            return

        for i in range(6):
            recur(tmp + [i])

    recur([])

    return result

마찬가지로 itertools의 product를 활용한다면 따로 구현하지 않아도 됩니다. make_combination과 구현이 비슷하지만, 이 경우 가우시안 곱에 해당하므로 각 원소가 중복되어 조합을 구성하게 됩니다.

3️⃣ (모든 주사위 조합 + 굴리는 모든 경우의 수)에 따른 결과 구하기 : get_roll_result()

def get_roll_result(dice: List[List[int]], products: List[List[int]], picks: Set[int]) -> List[int]:
    """
    주사위 pick에 따른 점수의 합의 모든 경우의 수 반환
    ex) [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합
    """
    result = []
    for indice in products:
        tmp_sum = 0
        for idx, dice_idx in zip(indice, picks):
            tmp_sum += dice[dice_idx][idx]
        result.append(tmp_sum)

    return result

make_combinations에서의 특정 주사위 조합 picks를 굴렸을 때 나오는 모든 면의 조합 products에 대한 결과를 구할 것입니다. 예를 들어 [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합을 원소로 가지는 리스트를 반환합니다.

매개변수로 받는 자료형의 channel에 주의하여 코드를 구현해야 합니다.

4️⃣ A, B 주사위 결과에 따른 승리 횟수 구하기 : get_wins_cnt()

이제 여기서 A의 결과가 B의 결과를 몇 번 이기는 지를 계산하여 함수 바깥에서 이기는 횟수를 합산하면 됩니다. 여기서 아래와 같이 완전 탐색으로 이를 구하는 경우를 보겠습니다.


 def get_wins_cnt(A: List[int], B: List[int]) -> int:
     """
     모든 주사위 점수의 합에 따른 승리 횟수 반환(TLE 발생)
     """
     result = 0

     for a in A:
         for b in B:
             if a > b:
                 result += 1

     return result

이 경우 주사위가 4개일 때만 해도 (6 x 6) x (6 x 6) = 1296번이고, 주사위는 10개까지 주어질 수 있으므로 최대 6^10 이상의 연산을 수행하게 됩니다. 너무 많은 연산량이고, 실제 효율성을 위한 테스트케이스에서 TLE가 발생하는 것을 확인할 수 있습니다.

이때, 해당 함수의 목적인 A의 결과가 B의 결과를 몇 번 이기는 지를 구해야 함을 생각했을 때, 이분 탐색을 고려할 수 있습니다.

아래와 같이 정렬된 두 리스트 V, W가 있을 때,
V = [3, 4, 5, 6 ,7]
W = [1, 2, 3, 4, 5]

V의 첫 번째 원소 3이 W의 원소 중 몇 개를 이기는지 계산해야 합니다. 이때 W 리스트에 대한 3의 lower_bound에 해당하는 위치 인덱스3보다 작은 W의 요소 개수와 같음을 알 수 있습니다.

lower bound : 찾고자 하는 값 이상이 처음 나타나는 위치
출처: https://12bme.tistory.com/120 [길은 가면, 뒤에 있다.:티스토리]

lower bound는 bisect 라이브러리의 bisect_left를 활용하여 구현하거나, 아래와 같이 이분탐색을 활용해 구현할 수 있습니다.

def get_wins_cnt(A: List[int], B: List[int]) -> int:
    """
    모든 주사위 점수의 합에 따른 승리 횟수 반환(이분탐색 활용)
    """
    result = 0

    A.sort()
    B.sort()

    for a in A:
        # B 리스트에 대한 a의 lower bound(찾고자 하는 값 이상이 처음 나타나는 위치)
        # = a가 몇 개의 경우의 수를 이기는지 개수
        start, end = 0, len(B) - 1

        while start <= end:
            mid = (start + end) // 2
            if a > B[mid]:
                start = mid + 1
            else:
                end = mid - 1

        result += end

    return result

A의 n개의 원소에 대해 B를 이분탐색(logn)했으므로, O(n^2)의 복잡도를 O(nlogn)으로 개선하여 TLE를 피할 수 있었습니다.

🥳 전체 코드

from typing import List, Set


def make_combinations(n: int, k: int) -> List[Set]:
    """
    주사위 n개(0~n-1) 중 k개를 뽑는 조합 반환
    ex) 주사위 4개 중 2개를 뽑기
    [0, 1], [0, 2], ...,  [2, 3]
    """
    result = []

    def recur(tmp):
        if len(tmp) == k:
            result.append(set(tmp))
            return

        for i in range(n):
            if not tmp or i > tmp[-1]:
                recur(tmp + [i])

    recur([])

    return result


def product(n: int) -> List[List[int]]:
    """
    주사위 n개를 굴리는 경우의 수(인덱스 조합) 반환
    ex) 주사위 3개인 경우
    [0, 0, 0] ~ [5, 5, 5] 반환
    """
    result = []

    def recur(tmp: List) -> None:
        if len(tmp) == n:
            result.append(tmp)
            return

        for i in range(6):
            recur(tmp + [i])

    recur([])

    return result


def get_roll_result(dice: List[List[int]], products: List[List[int]], picks: Set[int]) -> List[int]:
    """
    주사위 pick에 따른 점수의 합의 모든 경우의 수 반환
    ex) [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합
    """
    result = []
    for indice in products:
        tmp_sum = 0
        for idx, dice_idx in zip(indice, picks):
            tmp_sum += dice[dice_idx][idx]
        result.append(tmp_sum)

    return result


def get_wins_cnt(A: List[int], B: List[int]) -> int:
    """
    모든 주사위 점수의 합에 따른 승리 횟수 반환(이분탐색 활용)
    """
    result = 0

    A.sort()
    B.sort()

    for a in A:
        # B 리스트에 대한 a의 lower bound(찾고자 하는 값 이상이 처음 나타나는 위치)
        # = a가 몇 개의 경우의 수를 이기는지 개수
        start, end = 0, len(B) - 1

        while start <= end:
            mid = (start + end) // 2
            if a > B[mid]:
                start = mid + 1
            else:
                end = mid - 1

        result += end

    return result


# def get_wins_cnt(A: List[int], B: List[int]) -> int:
#     """
#     모든 주사위 점수의 합에 따른 승리 횟수 반환(TLE 발생)
#     """
#     result = 0

#     for a in A:
#         for b in B:
#             if a > b:
#                 result += 1

#     return result


def solution(dice: List[List[int]]) -> List[int]:
    n = len(dice)
    combinations = make_combinations(n, n // 2)  # 주사위 조합
    products = product(n // 2)  # 주사위 면의 조합

    answer_cnt = 0
    answer = None   # 명시적 초기화

    for A_picks in combinations:
        B_picks = set(range(n)) - A_picks

        A_picks_result = get_roll_result(dice, products, A_picks)  # A 주사위 결과
        B_picks_result = get_roll_result(dice, products, B_picks)  # B 주사위 결과

        wins_cnt = get_wins_cnt(A_picks_result, B_picks_result)  # A 주사위가 이긴 횟수

        if answer_cnt < wins_cnt:   # answer 업데이트
            answer_cnt = wins_cnt
            answer = sorted(list(map(lambda x: x + 1, A_picks)))    # 인덱스가 1부터 시작함, 오름차순

    return answer

📚 고찰

주어지는 input이 작아 완전 탐색으로 풀이를 진행하려 했을 때, 효율성을 개선할 여지가 있는 구현부에 대해 인지하고 있어야, 발생하는 TLE에 유연하게 대처할 수 있습니다. 미리미리 이를 메모하는 습관과, 이번 이분 탐색처럼 구현이 복잡하지 않은 경우라면 바로바로 도입하는 것도 괜찮을 것 같습니다.

🙏 문제 접근 방법 및 코드에 대한 피드백과 질문은 환영입니다!

✏️ Algorithm Study
본 문제의 다른 풀이 및 피드백, 전체 문제 풀이 순서는 위 알고리즘 스터디 Repository에서도 확인 가능합니다.

profile
Grow Exponentially

0개의 댓글