시간 제한 | 메모리 제한 | 제출 | 정답 | 맞은 사람 | 정답 비율 |
---|---|---|---|---|---|
2 초 | 512 MB | 105 | 42 | 39 | 52.000% |
크기가 N×M인 격자판의 각 칸에 정수가 하나씩 들어있다. 이 격자판에서 칸 K개를 선택할 것이고, 선택한 칸에 들어있는 수를 모두 더한 값의 최댓값을 구하려고 한다. 단, 선택한 두 칸이 인접하면 안된다. r행 c열에 있는 칸을 (r, c)라고 했을 때, (r-1, c), (r+1, c), (r, c-1), (r, c+1)에 있는 칸이 인접한 칸이다.
첫째 줄에 N, M, K가 주어진다. 둘째 줄부터 N개의 줄에 격자판에 들어있는 수가 주어진다.
선택한 칸에 들어있는 수를 모두 더한 값의 최댓값을 출력한다.
비트마스킹을 이용한 DP 문제임을 확인하고, 테이블을 어떻게 놓을지 고민했다.
가로로 아래로 내려가면서, 테이블을 채워넣는 것이 좋아보였다.
그래서 테이블을 dp[i][j][k] = i번째 줄까지 j개를 썼는데 i번째줄의 상태가 k일 때 최댓값
이렇게 놓고 푸니까 식이 바로 나왔다.
다만 전처리 해야되는 부분이 많았다.
우선, K가 가로로 붙어있지 않는 bit임을 확인하는 작업이 필요했다.
그리고, i번째 행에서 K를 뽑았을 때의 값도 필요했다.
마지막으로 K가 몇개의 bit로 이루어져있는지도 확인해야했다.
귀찮은 데이터 전처리만 해주면, 점화식은 간단하게 도출된다.
dp[i][j][k] = max(dp[i][j][k], dp[i - 1][j - k의 Bit][l(한번 더 비트 순회)] + sum[i][k]
행은 1부터 시작하고, 열은 0부터 시작하는 arr 배열에 값을 담아주었다.
2^(M) - 1 까지 순회하면서 유효한 비트인지 확인해주고 kSet에 담았다.
여기엔 bitset을 활용해서 간편하게 구현했다.
이 작업을 하는 과정에서 k의 비트수도 kBit 배열에 담아주었다.
그후 행과 kSet을 순회하면서 arr[i][k] 값을 찾아주었다.
비트연산자 &를 사용해서 선택한 칸의 값들을 더해주었다.
이후에는 dp 테이블을 INT_MIN으로 초기화해주었다.
마지막으로 dp 테이블을 채우는 과정이다.
i를 N까지 돌면서, j는 K까지 돈다.
k는 kSet에 있는 유효한 비트들을 뽑는다.
다만, j가 k의 비트수보다 작은 경우는 계산을 하지 않는다.
다시 kSet의 수들을 l에 넣어서 i - 1번째 테이블을 순회한다.
dp[i - 1]j - kBit[k][l] + sum[i][k]의 값이 기존 값보다 크면, 바꿔주면 된다.
j가 K인 경우에 ans 값을 max를 이용해 넣어주었다.
#include <bits/stdc++.h>
using namespace std;
#define ll long long int
#define FUP(i, a, b) for(int i = a; i <= b; i++)
#define FDOWN(i, a, b) for(int i = a; i >= b; i--)
#define MS(a, b) memset(a, b, sizeof(a))
#define ALL(v) v.begin(), v.end()
#define CIN(a) cin >> a;
#define CIN2(a, b) cin >> a >> b
#define CIN3(a, b, c) cin >> a >> b >> c
#define COUT(a) cout << a
#define COUT2(a, b) cout << a << ' ' << b
#define COUT3(a, b, c) cout << a << ' ' << b << ' ' << c
#define ENDL cout << '\n'
int dy[4] = { -1, 1, 0, 0 };
int dx[4] = { 0, 0, 1, -1 };
int N, M, K, dp[11][51][1024], arr[11][10], sum[11][1024], kBit[1024], ans = INT_MIN;
set<int> kSet;
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
MS(kBit, 0);
MS(sum, 0);
CIN3(N, M, K);
FUP(i, 1, N)
{
FUP(j, 0, M - 1)
{
CIN(arr[i][j]);
}
}
int m = (1 << M) - 1;
// k 전처리 (붙어있는 bit 없게)
FUP(k, 0, m)
{
bitset<10> tmp = bitset<10>(k);
bool ok = true;
FUP(i, 0, tmp.size() - 2)
{
if (tmp[i] && tmp[i + 1]) ok = false;
if (tmp[i]) kBit[k]++;
}
if (tmp[tmp.size() - 1]) kBit[k]++;
if (ok) kSet.insert(k);
}
// COUT(kSet.size()); 10일때 144
// arr 전처리 (K를 뽑았을 때 i번째의 합)
FUP(i, 0, N)
{
for(int k : kSet)
{
FUP(j, 0, M - 1)
{
if (k & (1 << j)) sum[i][k] += arr[i][j];
}
}
}
FUP(i, 0, N)
{
FUP(j, 0, K)
{
FUP(k, 0, m)
dp[i][j][k] = INT_MIN;
}
}
dp[0][0][0] = 0;
FUP(i, 1, N)
{
FUP(j, 0, K)
{
for(int k : kSet)
{
if (j < kBit[k]) continue;
for(int l : kSet)
{
if (!(k & l) && dp[i - 1][j - kBit[k]][l] != INT_MIN)
dp[i][j][k] = max(dp[i][j][k], dp[i - 1][j - kBit[k]][l] + sum[i][k]);
}
if (j == K) ans = max(ans, dp[i][j][k]);
}
}
}
COUT(ans);
return 0;
}