[C++] 백준 3653: 영화 수집

Cyan·2024년 6월 20일
0

코딩 테스트

목록 보기
155/166

백준 3653: 영화 수집

문제 요약

상근이는 영화 DVD 수집가이다. 상근이는 그의 DVD 콜렉션을 쌓아 보관한다.

보고 싶은 영화가 있을 때는, DVD의 위치를 찾은 다음 쌓아놓은 콜렉션이 무너지지 않게 조심스럽게 DVD를 뺀다. 영화를 다 본 이후에는 가장 위에 놓는다.

상근이는 DVD가 매우 많기 때문에, 영화의 위치를 찾는데 시간이 너무 오래 걸린다. 각 DVD의 위치는, 찾으려는 DVD의 위에 있는 영화의 개수만 알면 쉽게 구할 수 있다. 각 영화는 DVD 표지에 붙어있는 숫자로 쉽게 구별할 수 있다.

각 영화의 위치를 기록하는 프로그램을 작성하시오. 상근이가 영화를 한 편 볼 때마다 그 DVD의 위에 몇 개의 DVD가 있었는지를 구해야 한다.

문제 분류

  • 자료 구조
  • 세그먼트 트리

문제 풀이

입력으로 주어지는 DVD의 번호에 대해, 해당 번호가 몇 번째 위치에 있는지 먼저 구해야 하는데, ary[]로 해결했다. ary[idx]idx번호의 DVD에 대한 위치이다. ary[]는 초기에 ary[i] = n - i;의 역전된 값으로 초기화한다. 즉, 리프노드의 0번 위치는 DVD의 맨 밑의 위치이고, n - 1번 위치는 맨 위의 위치로 초기화된다.

이제 DVD의 위치 상태를 구현하면 되는데, 세그먼트 트리를 이용하면 된다.
우선 기본적으로 n개의 DVD에 m개의 쿼리이므로 세그먼트 트리의 총 리프노드의 개수는 n - m개가 될 것이다.
우선 가장 왼쪽에 모든 DVD를 정렬하여 놓는다. 즉, 세그먼트 트리의 리프노드의0~n까지는 1로 초기화하고, n + 1~n + m까지는 0으로 초기화한다. 1이 곧 해당 위치에 DVD가 있음을 이야기한다.

이제 i번째의 쿼리를 가정하고,in번호의 DVD를 꺼내본다면, in의 위치를 ary[in]을 통해 받아 온다. 이후 ary[in]위치를 0으로 업데이트 한 뒤, 해당 위치부터 n + m - 1까지의 합이 곧 위에 쌓여 있는 DVD의 개수가 된다.
이후 in의 위치, 즉 ary[in]n + i로 치환해준다.

즉, [0,n1][0, n - 1]의 구간은 기본적으로 DVD가 쌓여있는 구간이고, [n,n+m1][n, n + m - 1]m개의 구간은 쿼리가 진행되면서 DVD가 쌓이는 구간이다.

어떠한 구간 사이의 거리를 구할 때 세그먼트 트리를 응용하는 문제였다.

풀이 코드

#include <stdio.h>
#include <iostream>
#include <algorithm>

using namespace std;

int seg[800000], t, n, m, ary[100001];

int construct(int l, int r, int idx)
{
	if (l == r) {
		if (r < n) seg[idx] = 1;
		else seg[idx] = 0;
	}
	else {
		int mid = (l + r) / 2;
		seg[idx] = construct(l, mid, idx * 2 + 1) + construct(mid + 1, r, idx * 2 + 2);
	}
	return seg[idx];
}

int update(int l, int r, int idx, int loc, int val)
{
	if (loc < l || loc > r) return seg[idx];
	if (l == r) {
		if (l == loc) seg[idx] = val;
	}
	else {
		int mid = (l + r) / 2;
		seg[idx] = update(l, mid, idx * 2 + 1, loc, val) + update(mid + 1, r, idx * 2 + 2, loc, val);
	}
	return seg[idx];
}

int sum(int start, int end, int l, int r, int idx)
{
	if (r < start || l > end) return 0;
	if (start <= l && r <= end) return seg[idx];
	int mid = (l + r) / 2;
	return sum(start, end, l, mid, idx * 2 + 1) + sum(start, end, mid + 1, r, idx * 2 + 2);
}

int main()
{
	int in, in2;
	cin >> t;
	while (t--) {
		scanf("%d%d", &n, &m);
		construct(0, n + m - 1, 0);
		for (int i = 1; i <= n; i++)
			ary[i] = n - i;

		for (int i = 0; i < m; i++) {
			scanf("%d", &in);
			in2 = ary[in];
			update(0, n + m - 1, 0, in2, 0);
			printf("%d ", sum(in2, n + m - 1, 0, n + m - 1, 0));
			update(0, n + m - 1, 0, n + i, 1);
			ary[in] = n + i;
		}
		printf("\n");
	}
	return 0;
}

0개의 댓글