A와 B가 n개의 주사위를 가지고 승부를 합니다. 주사위의 6개 면에 각각 하나의 수가 쓰여 있으며, 주사위를 던졌을 때 각 면이 나올 확률은 동일합니다. 각 주사위는 1 ~ n의 번호를 가지고 있으며, 주사위에 쓰인 수의 구성은 모두 다릅니다.
A가 먼저 n / 2개의 주사위를 가져가면 B가 남은 n / 2개의 주사위를 가져갑니다. 각각 가져간 주사위를 모두 굴린 뒤, 나온 수들을 모두 합해 점수를 계산합니다. 점수가 더 큰 쪽이 승리하며, 점수가 같다면 무승부입니다.
A는 자신이 승리할 확률이 가장 높아지도록 주사위를 가져가려 합니다.
다음은 n = 4인 예시입니다.
주사위 | 구성 |
---|---|
#1 | [1, 2, 3, 4, 5, 6] |
#2 | [3, 3, 3, 3, 4, 4] |
#3 | [1, 3, 3, 4, 4, 4] |
#4 | [1, 1, 4, 4, 5, 5] |
예를 들어 A가 주사위 #1, #2를 가져간 뒤 6, 3을 굴리고, B가 주사위 #3, #4를 가져간 뒤 4, 1을 굴린다면 A의 승리입니다. (6 + 3 > 4 + 1)
A가 가져가는 주사위 조합에 따라, 주사위를 굴린 1296가지 경우의 승패 결과를 세어보면 아래 표와 같습니다.
A의 주사위 | 승 | 무 | 패 |
---|---|---|---|
#1, #2 | 596 | 196 | 504 |
#1, #3 | 560 | 176 | 560 |
#1, #4 | 616 | 184 | 496 |
#2, #3 | 496 | 184 | 616 |
#2, #4 | 560 | 176 | 560 |
#3, #4 | 504 | 196 | 596 |
A가 승리할 확률이 가장 높아지기 위해선 주사위 #1, #4를 가져가야 합니다.
주사위에 쓰인 수의 구성을 담은 2차원 정수 배열 dice가 매개변수로 주어집니다. 이때, 자신이 승리할 확률이 가장 높아지기 위해 A가 골라야 하는 주사위 번호를 오름차순으로 1차원 정수 배열에 담아 return 하도록 solution 함수를 완성해 주세요. 승리할 확률이 가장 높은 주사위 조합이 유일한 경우만 주어집니다.
dice | result |
---|---|
[[1, 2, 3, 4, 5, 6], [3, 3, 3, 3, 4, 4], [1, 3, 3, 4, 4, 4], [1, 1, 4, 4, 5, 5]] | [1, 4] |
[[1, 2, 3, 4, 5, 6], [2, 2, 4, 4, 6, 6]] | [2] |
[[40, 41, 42, 43, 44, 45], [43, 43, 42, 42, 41, 41], [1, 1, 80, 80, 80, 80], [70, 70, 1, 1, 70, 70]] | [1, 3] |
N 의 범위가 10 밖에 안되기 때문에, 완전 탐색으로 문제를 풀 수 있다.
주사위를 가져가는 방법은 조합으로 가져올 수 있다.
이게 최대 10개 중 5개를 가져오는 방법이므로
그리고, 서로 던진 주사위의 합을 비교하여 승리 횟수를 구한다.
여기서 던진 주사위 합의 최대 연산 횟수는 이다.
각각 던져 얻어낸 주사위 합 7,776 개의 결과 A, 7,776 개의 결과 B 에서 A가 B 보다 큰 것을 찾아낸다.
이 방법에서 단순히 2중 반복문을 통해서 구했는데, 으로 시간 초과가 발생했다.
그래서 다음 방법으로 bisect 를 사용했다.
A 의 합 배열에서 각 원소들을 뽑아 키로 만들고, 그것을 통해서 1번의 반복문을 돌린다.
그리고, A에 bisect_right 를 하여 i 번째 수가 몇 개를 가지고 있는지 알아내고, B에 bisect_right를 하여 i 보다 작은 수가 몇 개 가지고 있는지 알아낸다.
그리고 서로 곱한 것을 더하면 총 승리 수를 얻을 수 있다.
다음 문제 풀 때, 이분 탐색을 적극적으로 활용해야겠다.
그리고 문제를 잘 읽자...
import itertools
from collections import defaultdict
from bisect import bisect_left, bisect_right
def get_sums(comb, dice, N):
sums = []
for i in range(6):
d1 = dice[comb[0]]
if N > 1:
d2 = dice[comb[1]]
for j in range(6):
if N > 2:
d3 = dice[comb[2]]
for k in range(6):
if N > 3:
d4 = dice[comb[3]]
for l in range(6):
if N > 4:
d5 = dice[comb[4]]
for m in range(6):
sums.append(d1[i] + d2[j] + d3[k] + d4[l] + d5[m])
else:
sums.append(d1[i] + d2[j] + d3[k] + d4[l])
else:
sums.append(d1[i] + d2[j] + d3[k])
else:
sums.append(d1[i] + d2[j])
else:
sums.append(d1[i])
return sums
def solution(dice):
answer = []
N = len(dice)
combs = list(itertools.combinations([i for i in range(N)], N // 2))
results = defaultdict(list)
A = combs
B = combs
max_wins = 0
for a in A:
if results[a]:
continue
for b in B:
is_duplicate = False
for _ in a:
if _ in b:
is_duplicate = True
break
if is_duplicate:
continue
a_sums = get_sums(a, dice, len(a))
b_sums = get_sums(b, dice, len(b))
a_sums.sort()
b_sums.sort()
a_keys = list(set(a_sums))
a_keys.sort()
wins = 0
last_num = 0
for num in a_keys:
a_count = bisect_right(a_sums, num) - bisect_right(a_sums, last_num)
b_count = bisect_left(b_sums, num)
wins += a_count * b_count
last_num = num
if max_wins < wins:
max_wins = wins
answer = a
print(max_wins)
ans = []
for i in answer:
ans.append(i+1)
print(ans)
return ans