해당 문제는 3가지 단계로 나눠야합니다.
해당 단계에 맞춰서 설명해보겠습니다.
이렇게 flow를 알고, 시간복잡도를 계산해보자.
1초가 약 1억번을 계산해야하는데,
만약 주사위가 10개라면,
주사위를 나눠가지는 경우의 수가 10C5이므로, 252
주사위 1개당 6면이므로 6^5, 7776
2527776 = 1,959,552인데
이건 B도 동일하게 나오므로
1,959,5521,959,552을 하면 계산 수가 너무 크게 나온다.
그러므로 마지막 n^2을 어떻게든 줄여야한다.
일단 이 배경지식을 알고 시작해보자.
주사위 조합 구하기
조합을 이용하는 방법은 유명한 prev_permutation이 있다.
해당 내용은 구글링으로 찾아보자.
vector<vector<int>> combi;
d_num = dice.size();
for(int i=0;i<d_num;i++){
d_idx.push_back(i+1);
}
for(int i=0;i<d_num/2;i++){
temp.push_back(1);
}
for(int i=0;i<d_num/2;i++){
temp.push_back(0);
}
do{
vector<int>tmp;
for(int i=0;i<d_num;i++){
if(temp[i]==1) tmp.push_back(d_idx[i]);
}
combi.push_back(tmp);
}while(prev_permutation(temp.begin(),temp.end()));
이렇게 하면 combi 벡터에
{1,2},{1,3},,,이런식으로 담기게 된다.
2번 주사위를 던져서 나올 수 있는 합의 경우 구하기
만약 A가 주사위를 1,2,3번 을 가져갔다고 치자.
그럼
만약 주사위 1개에 3면이 있다 치면
1번 1,2,3
2번 4,5,6
3번 7,8,9
가 있으면
1+4+7, 1+4+8, 1+4+9,,, 이런식으로 계산을 해줘야한다.
이것은 백트레킹과 비슷하다.
void findDiceCombi(int depth, int sum, vector<int>&A_D, vector<int>&combi,vector<vector<int>>&dice){
if(depth == combi.size()){
A_D.push_back(sum);
return;
}
for(int i=0;i<dice[combi[depth]-1].size();i++){
sum+=dice[combi[depth]-1][i];
findDiceCombi(depth+1,sum,A_D,combi,dice);
sum-=dice[combi[depth]-1][i];
}
}
depth가 0부터 시작해서 combi.size()만큼 증가했으면 1,2,3번 주사위의 각 면을 한번씩은 더한거다.
그리고 for문을 보면 sum에다가 dice값들을 더해주고 재귀함수를 통해서 계산하였다.
이렇게 되면, A_D라는 vector에 1,2,3주사위를 뽑았을경우 나올 수 있는 경우의 수들이 담기게 된다.
int start = 0;
int end = combi.size()-1;
while(start<end){
vector<int>A_D;
vector<int>B_D;
findDiceCombi(0,0,A_D,combi[start],dice);
findDiceCombi(0,0,B_D,combi[end],dice);
sort(A_D.begin(),A_D.end());
sort(B_D.begin(),B_D.end());
solution의 while문을 보면 start와 end를 사용하고 있다.
그 이유는
combi가 현재 {1,2},{1,3},{1,4},{2,3},{2,4},{3,4} 이런상태인데
A가 1,2를 가져가면 자동적으로 B는 마지막인 3,4를 가져갈 수 밖에 없다.
A가 1,3을 가져가면 B는 마지막에서 두번째인 2,4를 가져갈 수 밖에 없기 때문이다.
그리고 이분탐색을 위한 정렬을 해준다.
int win1 =0;
int win2 =0;
for(auto n : A_D){
int win = lower_bound(B_D.begin(),B_D.end(),n) - B_D.begin();
if(win>=0) win1+=win;
}
for(auto n : B_D){
int win = lower_bound(A_D.begin(),A_D.end(),n) - A_D.begin();
if(win>=0) win2+=win;
}
if(win1>win2&& win1>maxWin){
answer = combi[start];
maxWin = win1;
}else if(win2 > win1 && win2 > maxWin){
answer = combi[end];
maxWin = win2;
}
start++;
end--;
앞에서 n^2을 줄여야하는데 이러면 nlogn으로 줄일 수 있다. 왜냐하면,
A_D: 1,1,2,3,4,5
B_D: 4,4,4,7,8,9
이렇게 있다고 치면, A가 A_D를 골랐을 경우 1은 아무것도 이길 수 가없다. 최소 5를 골라야 4,4,4를 이길 수 있다.
그러므로, A_D가 B_D를 이길수있는 경우의 수를 구하고, B_D가 A_D를 이길 수 있는 경우를 각각 구한다음에,
maxWin을 갱신해준다.
전체코드
#include <string>
#include <vector>
#include <algorithm>
#include <iostream>
using namespace std;
int d_num=0;
vector<int> d_idx;
vector<int> temp;
vector<vector<int>> combi;
int maxWin =0;
void findDiceCombi(int depth, int sum, vector<int>&A_D, vector<int>&combi,vector<vector<int>>&dice){
if(depth == combi.size()){
A_D.push_back(sum);
return;
}
for(int i=0;i<dice[combi[depth]-1].size();i++){
sum+=dice[combi[depth]-1][i];
findDiceCombi(depth+1,sum,A_D,combi,dice);
sum-=dice[combi[depth]-1][i];
}
}
vector<int> solution(vector<vector<int>> dice) {
vector<int> answer;
d_num = dice.size();
for(int i=0;i<d_num;i++){
d_idx.push_back(i+1);
}
for(int i=0;i<d_num/2;i++){
temp.push_back(1);
}
for(int i=0;i<d_num/2;i++){
temp.push_back(0);
}
do{
vector<int>tmp;
for(int i=0;i<d_num;i++){
if(temp[i]==1) tmp.push_back(d_idx[i]);
}
combi.push_back(tmp);
}while(prev_permutation(temp.begin(),temp.end()));
int start = 0;
int end = combi.size()-1;
while(start<end){
vector<int>A_D;
vector<int>B_D;
findDiceCombi(0,0,A_D,combi[start],dice);
findDiceCombi(0,0,B_D,combi[end],dice);
sort(A_D.begin(),A_D.end());
sort(B_D.begin(),B_D.end());
int win1 =0;
int win2 =0;
for(auto n : A_D){
int win = lower_bound(B_D.begin(),B_D.end(),n) - B_D.begin();
if(win>=0) win1+=win;
}
for(auto n : B_D){
int win = lower_bound(A_D.begin(),A_D.end(),n) - A_D.begin();
if(win>=0) win2+=win;
}
if(win1>win2&& win1>maxWin){
answer = combi[start];
maxWin = win1;
}else if(win2 > win1 && win2 > maxWin){
answer = combi[end];
maxWin = win2;
}
start++;
end--;
}
return answer;
}
이번문제는 아래의 블로그를 보고 풀었다.