
주어지는 n개의 주사위를 A, B가 n//2개 씩 나눠 가졌을 때, A가 승리하는 확률이 가장 높도록 가져가는 주사위 조합을 구하면 되는 문제입니다.
주어지는 dice의 길이를 생각했을 때, 완전 탐색으로 접근하면 복잡하긴 해도 어렵지 않게 해결할 수 있을 것 같습니다.
솔루션의 백본으로 코드의 흐름을 설명하겠습니다. 코드의 depth에 주의하여 아래 솔루션을 이해하셨으면 좋겠습니다.
def solution(dice: List[List[int]]) -> List[int]:
n = len(dice)
combinations = make_combinations(n, n // 2) # 주사위 조합
products = product(n // 2) # 주사위 면의 조합
answer_cnt = 0
answer = None # 명시적 초기화
for A_picks in combinations:
B_picks = set(range(n)) - A_picks
A_picks_result = get_roll_result(dice, products, A_picks) # A 주사위 결과
B_picks_result = get_roll_result(dice, products, B_picks) # B 주사위 결과
wins_cnt = get_wins_cnt(A_picks_result, B_picks_result) # A 주사위가 이긴 횟수
if answer_cnt < wins_cnt: # answer 업데이트
answer_cnt = wins_cnt
answer = sorted(list(map(lambda x: x + 1, A_picks))) # 인덱스가 1부터 시작함, 오름차순
return answer
make_combinations)product)get_roll_result)get_win_cnt) def make_combinations(n: int, k: int) -> List[Set]:
"""
주사위 n개(0~n-1) 중 k개를 뽑는 조합 반환
ex) 주사위 4개 중 2개를 뽑기
[0, 1], [0, 2], ..., [2, 3]
"""
result = []
def recur(tmp):
if len(tmp) == k:
result.append(set(tmp))
return
for i in range(n):
if not tmp or i > tmp[-1]:
recur(tmp + [i])
recur([])
return result
itertools의 combinations 를 활용하면 쉽게 구할 수 있지만, 이와 같이 backtracking으로 구해낼 수 있습니다. 조합을 계산하는 것이므로, 중복을 포함하지 않기 위한 조건이 필요합니다.
def product(n: int) -> List[List[int]]:
"""
주사위 n개를 굴리는 경우의 수(인덱스 조합) 반환
ex) 주사위 3개인 경우
[0, 0, 0] ~ [5, 5, 5] 반환
"""
result = []
def recur(tmp: List) -> None:
if len(tmp) == n:
result.append(tmp)
return
for i in range(6):
recur(tmp + [i])
recur([])
return result
마찬가지로 itertools의 product를 활용한다면 따로 구현하지 않아도 됩니다. make_combination과 구현이 비슷하지만, 이 경우 가우시안 곱에 해당하므로 각 원소가 중복되어 조합을 구성하게 됩니다.
def get_roll_result(dice: List[List[int]], products: List[List[int]], picks: Set[int]) -> List[int]:
"""
주사위 pick에 따른 점수의 합의 모든 경우의 수 반환
ex) [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합
"""
result = []
for indice in products:
tmp_sum = 0
for idx, dice_idx in zip(indice, picks):
tmp_sum += dice[dice_idx][idx]
result.append(tmp_sum)
return result
make_combinations에서의 특정 주사위 조합 picks를 굴렸을 때 나오는 모든 면의 조합 products에 대한 결과를 구할 것입니다. 예를 들어 [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합을 원소로 가지는 리스트를 반환합니다.
매개변수로 받는 자료형의 channel에 주의하여 코드를 구현해야 합니다.
이제 여기서 A의 결과가 B의 결과를 몇 번 이기는 지를 계산하여 함수 바깥에서 이기는 횟수를 합산하면 됩니다. 여기서 아래와 같이 완전 탐색으로 이를 구하는 경우를 보겠습니다.
def get_wins_cnt(A: List[int], B: List[int]) -> int:
"""
모든 주사위 점수의 합에 따른 승리 횟수 반환(TLE 발생)
"""
result = 0
for a in A:
for b in B:
if a > b:
result += 1
return result
이 경우 주사위가 4개일 때만 해도 (6 x 6) x (6 x 6) = 1296번이고, 주사위는 10개까지 주어질 수 있으므로 최대 6^10 이상의 연산을 수행하게 됩니다. 너무 많은 연산량이고, 실제 효율성을 위한 테스트케이스에서 TLE가 발생하는 것을 확인할 수 있습니다.
이때, 해당 함수의 목적인 A의 결과가 B의 결과를 몇 번 이기는 지를 구해야 함을 생각했을 때, 이분 탐색을 고려할 수 있습니다.
아래와 같이 정렬된 두 리스트 V, W가 있을 때,
V = [3, 4, 5, 6 ,7]
W = [1, 2, 3, 4, 5]
V의 첫 번째 원소 3이 W의 원소 중 몇 개를 이기는지 계산해야 합니다. 이때 W 리스트에 대한 3의 lower_bound에 해당하는 위치 인덱스가 3보다 작은 W의 요소 개수와 같음을 알 수 있습니다.
lower bound : 찾고자 하는 값 이상이 처음 나타나는 위치
출처: https://12bme.tistory.com/120 [길은 가면, 뒤에 있다.:티스토리]
lower bound는 bisect 라이브러리의 bisect_left를 활용하여 구현하거나, 아래와 같이 이분탐색을 활용해 구현할 수 있습니다.
def get_wins_cnt(A: List[int], B: List[int]) -> int:
"""
모든 주사위 점수의 합에 따른 승리 횟수 반환(이분탐색 활용)
"""
result = 0
A.sort()
B.sort()
for a in A:
# B 리스트에 대한 a의 lower bound(찾고자 하는 값 이상이 처음 나타나는 위치)
# = a가 몇 개의 경우의 수를 이기는지 개수
start, end = 0, len(B) - 1
while start <= end:
mid = (start + end) // 2
if a > B[mid]:
start = mid + 1
else:
end = mid - 1
result += end
return result
A의 n개의 원소에 대해 B를 이분탐색(logn)했으므로, O(n^2)의 복잡도를 O(nlogn)으로 개선하여 TLE를 피할 수 있었습니다.
from typing import List, Set
def make_combinations(n: int, k: int) -> List[Set]:
"""
주사위 n개(0~n-1) 중 k개를 뽑는 조합 반환
ex) 주사위 4개 중 2개를 뽑기
[0, 1], [0, 2], ..., [2, 3]
"""
result = []
def recur(tmp):
if len(tmp) == k:
result.append(set(tmp))
return
for i in range(n):
if not tmp or i > tmp[-1]:
recur(tmp + [i])
recur([])
return result
def product(n: int) -> List[List[int]]:
"""
주사위 n개를 굴리는 경우의 수(인덱스 조합) 반환
ex) 주사위 3개인 경우
[0, 0, 0] ~ [5, 5, 5] 반환
"""
result = []
def recur(tmp: List) -> None:
if len(tmp) == n:
result.append(tmp)
return
for i in range(6):
recur(tmp + [i])
recur([])
return result
def get_roll_result(dice: List[List[int]], products: List[List[int]], picks: Set[int]) -> List[int]:
"""
주사위 pick에 따른 점수의 합의 모든 경우의 수 반환
ex) [0, 1]번 주사위가 각각 [2, 3]번 면이 나왔을 때의 합
"""
result = []
for indice in products:
tmp_sum = 0
for idx, dice_idx in zip(indice, picks):
tmp_sum += dice[dice_idx][idx]
result.append(tmp_sum)
return result
def get_wins_cnt(A: List[int], B: List[int]) -> int:
"""
모든 주사위 점수의 합에 따른 승리 횟수 반환(이분탐색 활용)
"""
result = 0
A.sort()
B.sort()
for a in A:
# B 리스트에 대한 a의 lower bound(찾고자 하는 값 이상이 처음 나타나는 위치)
# = a가 몇 개의 경우의 수를 이기는지 개수
start, end = 0, len(B) - 1
while start <= end:
mid = (start + end) // 2
if a > B[mid]:
start = mid + 1
else:
end = mid - 1
result += end
return result
# def get_wins_cnt(A: List[int], B: List[int]) -> int:
# """
# 모든 주사위 점수의 합에 따른 승리 횟수 반환(TLE 발생)
# """
# result = 0
# for a in A:
# for b in B:
# if a > b:
# result += 1
# return result
def solution(dice: List[List[int]]) -> List[int]:
n = len(dice)
combinations = make_combinations(n, n // 2) # 주사위 조합
products = product(n // 2) # 주사위 면의 조합
answer_cnt = 0
answer = None # 명시적 초기화
for A_picks in combinations:
B_picks = set(range(n)) - A_picks
A_picks_result = get_roll_result(dice, products, A_picks) # A 주사위 결과
B_picks_result = get_roll_result(dice, products, B_picks) # B 주사위 결과
wins_cnt = get_wins_cnt(A_picks_result, B_picks_result) # A 주사위가 이긴 횟수
if answer_cnt < wins_cnt: # answer 업데이트
answer_cnt = wins_cnt
answer = sorted(list(map(lambda x: x + 1, A_picks))) # 인덱스가 1부터 시작함, 오름차순
return answer
주어지는 input이 작아 완전 탐색으로 풀이를 진행하려 했을 때, 효율성을 개선할 여지가 있는 구현부에 대해 인지하고 있어야, 발생하는 TLE에 유연하게 대처할 수 있습니다. 미리미리 이를 메모하는 습관과, 이번 이분 탐색처럼 구현이 복잡하지 않은 경우라면 바로바로 도입하는 것도 괜찮을 것 같습니다.
🙏 문제 접근 방법 및 코드에 대한 피드백과 질문은 환영입니다!
✏️ Algorithm Study
본 문제의 다른 풀이 및 피드백, 전체 문제 풀이 순서는 위 알고리즘 스터디 Repository에서도 확인 가능합니다.