https://www.acmicpc.net/problem/16198
처음에는 곱한 에너지의 크기가 큰 순으로 정렬하는 우선순위 큐를 하나 만드는 쪽으로 접근을 하였다. 처음에 짠 코드는 다음과 같다.
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.PriorityQueue;
public class Main {
static boolean[] isSelected;
static int[] arr; // 값
static int[] left; // 유효한 왼쪽 인덱스
static int[] right; // 유효한 오른쪽 인덱스
static class Node implements Comparable<Node>{
int m;
int l;
int r;
Node (int m, int l, int r) {
this.m = m;
this.l = l;
this.r = r;
}
@Override
public int compareTo(Node o) {
return arr[o.l] * arr[o.r] - arr[this.l] * arr[this.r];
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
int n = Integer.parseInt(br.readLine());
String[] inputs = br.readLine().split(" ");
PriorityQueue<Node> pq = new PriorityQueue<>();
arr = new int[n];
left = new int[n];
right = new int[n];
for (int i = 0; i < n; i++) {
arr[i] = Integer.parseInt(inputs[i]);
}
for (int i = 1; i < n - 1; i++) {
pq.add(new Node(i, i - 1, i + 1));
left[i] = i - 1;
right[i] = i + 1;
}
int sum = 0;
while (!pq.isEmpty()) {
Node node = pq.remove();
int lPt = node.l;
int rPt = node.r;
int mPt = node.m;
// pq에 넣어둔 본인 기준 좌우가 현재 상태에서 본인 기준 좌우가 다르면 건너뛴다.
// (기존에 넣어둔 좌우 중 누군가는 사라졌다는 의미)
if (left[mPt] != lPt || right[mPt] != rPt) {
continue;
}
sum += arr[lPt] * arr[rPt];
if (lPt != 0) {
right[lPt] = right[mPt];
pq.add(new Node(lPt, left[lPt], right[lPt]));
}
if (rPt != n - 1) {
left[rPt] = left[mPt];
pq.add(new Node(rPt, left[rPt], right[rPt]));
}
}
bw.write(Integer.toString(sum));
bw.flush();
bw.close();
}
}
예제는 다 정답을 출력하였지만 제출하자마자 틀렸다고 떴다. 곰곰히 반례를 생각해보았다.
반례
3
3 1 2 2 3 3
이와 같이 입력 값이 들어왔을 때 처음에 본인 기준 양 옆의 값의 곱이 6인 부분이 많다. 이럴 경우 일단 먼저 들어온 중앙 값에 대해서 삭제가 이루어지는데, 정답이라는 보장이 없다.
결론적으로 모두 따져봐야 한다는 뜻이다. 경우의 수는 최대 2^8 - 1, 경우의 수 하나 당 최댓값을 구하는데 넉넉잡아 10으로 쳤을 때 시간 내에 해결 가능하다고 판단하였다.
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.PriorityQueue;
public class Main {
static boolean[] isSelected;
static int[] arr; // 값
static int[] left; // 유효한 왼쪽 인덱스
static int[] right; // 유효한 오른쪽 인덱스
static boolean[] isDead;
static int n;
static int answer = 0;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
n = Integer.parseInt(br.readLine());
String[] inputs = br.readLine().split(" ");
arr = new int[n];
left = new int[n];
right = new int[n];
isDead = new boolean[n];
for (int i = 0; i < n; i++) {
arr[i] = Integer.parseInt(inputs[i]);
}
for (int i = 1; i < n - 1; i++) {
left[i] = i - 1;
right[i] = i + 1;
}
dfs(0, 0);
bw.write(Integer.toString(answer));
bw.flush();
bw.close();
}
static void dfs(int depth, int sum) {
if (depth == n - 2) {
if (sum > answer) {
answer = sum;
}
return;
}
for (int i = 1; i < n - 1; i++) {
if (!isDead[i]) {
isDead[i] = true;
int l = left[i];
int r = right[i];
int value = arr[l] * arr[r];
if (l != 0) {
right[l] = right[i];
}
if (r != n - 1) {
left[r] = left[i];
}
// 호출
dfs(depth + 1, sum + value);
right[l] = i;
left[r] = i;
isDead[i] = false;
}
}
}
}