상근이는 영화 DVD 수집가이다. 상근이는 그의 DVD 콜렉션을 쌓아 보관한다.
보고 싶은 영화가 있을 때는, DVD의 위치를 찾은 다음 쌓아놓은 콜렉션이 무너지지 않게 조심스럽게 DVD를 뺀다. 영화를 다 본 이후에는 가장 위에 놓는다.
상근이는 DVD가 매우 많기 때문에, 영화의 위치를 찾는데 시간이 너무 오래 걸린다. 각 DVD의 위치는, 찾으려는 DVD의 위에 있는 영화의 개수만 알면 쉽게 구할 수 있다. 각 영화는 DVD 표지에 붙어있는 숫자로 쉽게 구별할 수 있다.
각 영화의 위치를 기록하는 프로그램을 작성하시오. 상근이가 영화를 한 편 볼 때마다 그 DVD의 위에 몇 개의 DVD가 있었는지를 구해야 한다.
자료 구조
세그먼트 트리
입력으로 주어지는 DVD의 번호에 대해, 해당 번호가 몇 번째 위치에 있는지 먼저 구해야 하는데, ary[]
로 해결했다. ary[idx]
는 idx
번호의 DVD에 대한 위치이다. ary[]
는 초기에 ary[i] = n - i;
의 역전된 값으로 초기화한다. 즉, 리프노드의 0
번 위치는 DVD의 맨 밑의 위치이고, n - 1
번 위치는 맨 위의 위치로 초기화된다.
이제 DVD의 위치 상태를 구현하면 되는데, 세그먼트 트리
를 이용하면 된다.
우선 기본적으로 n
개의 DVD에 m
개의 쿼리이므로 세그먼트 트리의 총 리프노드의 개수는 n - m
개가 될 것이다.
우선 가장 왼쪽에 모든 DVD를 정렬하여 놓는다. 즉, 세그먼트 트리의 리프노드의0
~n
까지는 1
로 초기화하고, n + 1
~n + m
까지는 0
으로 초기화한다. 1
이 곧 해당 위치에 DVD가 있음을 이야기한다.
이제 i
번째의 쿼리를 가정하고,in
번호의 DVD를 꺼내본다면, in
의 위치를 ary[in]
을 통해 받아 온다. 이후 ary[in]
위치를 0
으로 업데이트 한 뒤, 해당 위치부터 n + m - 1
까지의 합이 곧 위에 쌓여 있는 DVD의 개수가 된다.
이후 in
의 위치, 즉 ary[in]
를 n + i
로 치환해준다.
즉, 의 구간은 기본적으로 DVD가 쌓여있는 구간이고, 의 m
개의 구간은 쿼리가 진행되면서 DVD가 쌓이는 구간이다.
어떠한 구간 사이의 거리를 구할 때 세그먼트 트리를 응용하는 문제였다.
#include <stdio.h>
#include <iostream>
#include <algorithm>
using namespace std;
int seg[800000], t, n, m, ary[100001];
int construct(int l, int r, int idx)
{
if (l == r) {
if (r < n) seg[idx] = 1;
else seg[idx] = 0;
}
else {
int mid = (l + r) / 2;
seg[idx] = construct(l, mid, idx * 2 + 1) + construct(mid + 1, r, idx * 2 + 2);
}
return seg[idx];
}
int update(int l, int r, int idx, int loc, int val)
{
if (loc < l || loc > r) return seg[idx];
if (l == r) {
if (l == loc) seg[idx] = val;
}
else {
int mid = (l + r) / 2;
seg[idx] = update(l, mid, idx * 2 + 1, loc, val) + update(mid + 1, r, idx * 2 + 2, loc, val);
}
return seg[idx];
}
int sum(int start, int end, int l, int r, int idx)
{
if (r < start || l > end) return 0;
if (start <= l && r <= end) return seg[idx];
int mid = (l + r) / 2;
return sum(start, end, l, mid, idx * 2 + 1) + sum(start, end, mid + 1, r, idx * 2 + 2);
}
int main()
{
int in, in2;
cin >> t;
while (t--) {
scanf("%d%d", &n, &m);
construct(0, n + m - 1, 0);
for (int i = 1; i <= n; i++)
ary[i] = n - i;
for (int i = 0; i < m; i++) {
scanf("%d", &in);
in2 = ary[in];
update(0, n + m - 1, 0, in2, 0);
printf("%d ", sum(in2, n + m - 1, 0, n + m - 1, 0));
update(0, n + m - 1, 0, n + i, 1);
ary[in] = n + i;
}
printf("\n");
}
return 0;
}