https://codeforces.com/contest/1942/problem/E
딱 보자마자, 게임이론과 조합을 사용하는 것 같다는 생각이 들었다.
적절한 간격을 결정하고, 그것을 기반으로 조합을 구하면 될 것 같았다. 하지만, 규칙을 찾는 게 잘되지 않았다.
그래서 에디토리얼을 봤는데, (a,b)의 사이가 모두 짝수이면 무조건 후에 시작하는 사람이 이긴다는 것이다.
이론은 이랬다. 먼저 시작하는 사람이 둘 사이의 간격을 무조건 홀수로 만들게 된다. 이미 짝수니까 어쩔 수 없다.
그러면 반대편의 소를 또 당겨서 짝수로 만들어준다.
0도 짝수이고, 계속해서 먼저 시작한 애의 절대적인 위치의 합은 줄어들거나 커지게 되니, 무조건 질 수밖에 없는 것이다.
또한, 그 반대의 경우를 봐보자. 전부 다 even이 아닌 경우, 이 경우 진짜 이해가 안 됐는데, 결국 찾아냈다. 문제를 잘못 이해한 것이다. k 개의 소를 골라서 한 번에 움직이는 거였는데, 나는 k 번째 소를 골라서 한 칸 움직이는 줄 알았다. 이것을 알았다면 이 문제도 풀 수 있었을까? 풀지는 못했을 것 같은데, 이와 비슷한 접근했을 것 같다. 조금 아쉽다.
그래서, 이러한 가정과 함께 먼저 시작하는 이가 이기는 경우는 여집합을 활용하면 되는 것이다.
즉, 후에 시작하는 사람이 이기는 경우의 수를 구하고 모든 경우의 수에서 후에 시작하는 사람이 이기는 경우의 수를 빼주면 되는 것이다.
모든 경우의 수 - 후에 시작하는 사람이 이기는 경우의 수 (모든 간격이 짝수인 경우)
결과적으로 위를 구하기 위해서 간격의 총합을 짝수로 두고 시작할 수 있다. 그러므로 for 문이 나오게 되는 것이다. 이제 수식만 세우면 된다. 저 수식이 근데 진짜 이해가 안 갔다. 하지만, 결국 이해하게 되었는데, 이해한 바는 다음과 같다.
결과적으로 나오게 된 수식은 2 * (s/2+n-1) C (n-1) * (l-s-n) C (n) 이다.
하나씩 뜯어보자. 2를 곱한 이유는 시작점을 a로 두느냐 b를 두느냐 두 가지의 경우가 있기 때문이다. (s/2+n-1) C (n-1) 이거를 분석하면 중복 조합으로 풀어내야 한다. 중복 조합은 nHr이라고 했을 때 (n+r-1) C (r) -> (n+r-1) C (n-1) 이라고 정의할 수 있다. 그러면 간단하게 nHr를 구하면 되는 것이다. 잘 생각해 보자. 각각의 (a,b) 짝 들은 2를 하나씩 고를 수 있다, 물론 중복해서 고를 수 있다. 즉, n 개의 짝들을 s/2번 고를 수 있는 것이다.
이렇게 되면 (n) H (s/2) -> (s/2+n-1) C (n) 로 유도할 수 있다. 자 그러면 이제 (l-s-n) C (n) 을 구할 수 있는데, 이거는 굉장히 쉽다. 이미 간격을 정해놓았으니 각각의 짝들은 하나의 요소라고 볼 수 있다. 다만, s+2의 폭을 가진 요소가 되는 것이다. 그러니 얘네들을 이제 여분의 남은 공간에 분배하기만 하면 된다.
이미 정해진 요소들은 총 l-s-2 * n의 면적을 차지하게 된다. 그러니 선택할 수 있는 영역은 l-s-n이 되는 것이다. 그러면 끝났다. 선택할 수 있는 영역에서 n 개를 선택하면 된다. 그렇기 때문에 (l-s-n) C (n) 이 되는 것이다. 그리고 이거를 곱해주면 된다. 왜? 각각의 간격을 정하는 경우의 수 * 실제로 놓는 경우의 수이기 때문이다. 그러면 이제 전체 경우의 수 - 구해준 후에 시작한 플레이어가 이기는 경우의 수를 빼주면 된다. 전체 경우의 수는 아주 간단하게 (l) C (n) 이다. l 개의 영역 중 n 개를 순서 상관없이 고르면 되기 때문이다.
이렇게 해서 문제를 풀 수 있었다. 구현은 정말 간단하나 이해하는 게 참 힘든 문제였다.
O(l)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <deque>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
const int MOD = 998244353;
vector<long long> fact = {1};
long long pow_mod(long long x, long long p) {
if (p == 0) {
return 1;
}
if (p % 2 == 0) {
long long y = pow_mod(x, p / 2);
return (y * y) % MOD;
}
return (x * pow_mod(x, p - 1)) % MOD;
}
long long inv(long long x) {
return pow_mod(x, MOD - 2);
}
long long cnk(long long n, long long k) {
if (n < 0 || k < 0) {
return 0;
}
if (n < k) {
return 0;
}
long long res = fact[n];
res = (res * inv(fact[k])) % MOD;
res = (res * inv(fact[n - k])) % MOD;
return res;
}
void solve() {
int l, n;
cin >> l >> n;
long long allEven = 0;
for (int i = 0; i <= l; i += 2) {
allEven += 2 * cnk(i / 2 + n - 1, n - 1) % MOD * cnk(l - i - n, n) % MOD;
allEven %= MOD;
}
cout << ((2 * cnk(l, 2 * n) % MOD - allEven + MOD) % MOD) << "\n";
}
int main(void) {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
freopen("e.input.txt", "r", stdin);
for (int i = 1; i <= 1e6; i++) {
fact.push_back((fact.back() * i) % MOD);
}
int T;
cin >> T;
while (T-- > 0) {
solve();
}
return 0;
}