문제의 초반에 주어진 조건인 정점의 개수가 개이고 간선의 개수가 개인 하나의 컴포넌트 = 트리의 정의와 같은 말입니다.
일반적인 그래프에서는 사이클이 존재할 수 있기 때문에 문제를 재귀적으로 구성해서 풀 수 없습니다. 하지만 트리에서는 이야기가 다릅니다. 트리는 간선에 방향을 준다면 DAG이기때문에 상태공간을 잘 정의해서 다이나믹 프로그래밍을 적용시킬 수 있습니다.
문제에 입력으로 주어지는 트리는 부모 자식 관계로 주어지는 것이 아니라 단순히 연결 관계로 주어지기 때문에 간선에 방향성을 주어서 DAG로 만들어야 합니다. 문제에서 번 정점이 루트라고 주었기 때문에 번 정점부터 그래프를 탐색하면서 새로운 정점을 만날때 마다 간에 간선을 추가해주면서 그래프를 만들 수 있습니다. 똑같은 정점을 여러 번 방문하지 않기 때문에 시간 복잡도는 입니다.
문제는 점화식 구성입니다. 경우의 수를 잘못 나눴다가는 중복되거나 빼먹을 수 있기 때문입니다. 점화식을 쉽게 구성하는 방법은 작은 크기의 데이터를 만들어보고 손으로 직접 해보는 것입니다. 다음과 같은 트리에서는 가지 방법이 가능합니다.
대충 점화식의 인자부터 넣으면서 식을 세워봅시다. 당연히 현재 몇 번째 정점인지는 넣어야겠죠 또, 우리는 오름차순으로 정점을 골라야 하므로 맨 마지막에 고른 정점의 번호도 저장을 해둡시다. 그러면 다음과 같습니다.
번째 정점을 루트로 하는 서브트리에서 가장 마지막으로 고른 정수가 일 때, 오름차순으로 정점을 선택하는 경우의 수.
이 점화식의 Base case부터 생각해봅시다. 트리에서 Base case라고 하면 당연히 리프노드인 경우겠죠. 경우의 수는 그 리프노드를 안 고르는 경우 가지, 고를 수 있다면 고르는 경우 가지입니다. 그런데 이러면 문제의 정의와 모순이 생깁니다. 우리는 오름차순으로 정점을 선택하는 경우의 수라고 정의했기 때문이죠. 예제 입력 1에서도 알 수 있듯이 한 번도 고르지 않은 경우는 경우의 수에 포함이 되지 않습니다.
어쩔수 없이 새로운 인자를 추가해줍시다. 는 번째 정점을 선택했는지 안했는지 여부입니다.
Base Case : 이라면 을 반환하고 그렇지 않다면 을 반환.
번째 정점을 루트로 하는 서브트리에서 가장 마지막으로 고른 정수가 이고, here번째 정점을 선택하는지 여부가 flag로 주어질 때, 오름차순으로 정점을 선택하는 경우의 수.
점화식은 간단하게 모든 자식들에 대해서 고를 수 있다면 고르는 경우, 고르지 않는 경우 두 가지로 나눌 수 있습니다.
(if )
하지만 이게 다가 아닙니다. 아직 빠진 경우의 수가 있습니다. 앞서 언급했듯 우리가 세운 점화식은 오름차순으로 정점을 선택하는 경우의 수입니다.
이 그림에서 세 번째 상황이죠. 번째 정점을 선택하고 더 이상 선택하지 않는 경우의 수도 하나 추가해줍니다. 당연히 가 이어야 하겠죠.
adj는 단순히 연결 관계만 나타내는 간선입니다. 이를 부모-자식 관계의 간선 children으로 바꿔주기 위해 그래프를 한 번 탐색해줍니다.
모듈러 연산자는 비싼 연산자입니다. 이를 MOD보다 커지면 MOD를 빼 주는 식으로 구현할 수 있습니다.
public class Main {
static int N;
static int[] S;
static ArrayList<ArrayList<Integer>> adj;
static ArrayList<ArrayList<Integer>> children;
static void dfs(int here, int prev) {
for (int there : adj.get(here)) if (there != prev) {
children.get(here).add(there);
dfs(there, here);
}
}
static int[][][] cache;
static final int MOD = 1000000007;
// here번째 정점을 루트로 하는 서브 트리에서
// 가장 마지막으로 고른 정점에 적힌 정수가 prev이고 (안 골랐어도 0)
// here번째 정점을 골랐는지 여부 flag가 주어질 때,
// 오름차순으로 정점을 고르는 경우의 수
static int dp(int here, int prev, int flag) {
// 리프 노드인 경우 here 정점을 고르는 경우만 경우를 하나 찾은 것
if (children.get(here).isEmpty()) return flag;
if (cache[here][prev][flag] != -1) return cache[here][prev][flag];
// here 정점을 선택하고 here의 자손으로부터는 안 고르는 경우 1
int sum = flag;
for (int there : children.get(here)) {
// there을 고를 수 있으면 고르는 경우
if (prev <= S[there]) sum += dp(there, S[there], 1);
if (sum >= MOD) sum -= MOD;
// 고르지 않는 경우
sum += dp(there, prev, 0);
if (sum >= MOD) sum -= MOD;
}
return cache[here][prev][flag] = sum;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
N = Integer.parseInt(br.readLine());
S = new int[N + 1];
StringTokenizer st = new StringTokenizer(br.readLine(), " ");
for (int i = 1; i <= N; i++) S[i] = Integer.parseInt(st.nextToken());
adj = new ArrayList<>();
for (int i = 0; i <= N; i++) adj.add(new ArrayList<>());
for (int i = 0; i < N - 1; i++) {
st = new StringTokenizer(br.readLine(), " ");
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
adj.get(u).add(v); adj.get(v).add(u);
}
children = new ArrayList<>();
for (int i = 0; i <= N; i++) children.add(new ArrayList<>());
dfs(1, 1);
cache = new int[N + 1][10][2];
for (int i = 0; i < cache.length; i++)
for (int j = 0; j < cache[i].length; j++)
Arrays.fill(cache[i][j], -1);
System.out.println((dp(1, 0, 0) + dp(1, S[1], 1)) % MOD);
}
}