[C++] 백준 2236: 굉장한 학생

Cyan·2024년 7월 14일
0

코딩 테스트

목록 보기
157/166

백준 2236: 굉장한 학생

문제 요약

N명의 학생이 참여하여 세 번의 시험을 치렀다. N명의 학생들은 세 번의 시험에 모두 응시하였다. 조교는 각각의 시험에서 같은 등수의 학생이 한 명도 없도록 성적을 매겼다.

A라는 학생이 B라는 학생보다 세 번의 시험에서 모두 성적이 좋다면, A가 B보다 '대단하다'고 한다. 또, C라는 학생보다 '대단한' 학생이 한 명도 없으면, C를 '굉장하다'고 한다.

세 번의 시험에서 각 학생의 성적이 주어졌을 때, '굉장한' 학생의 수를 구하는 프로그램을 작성하시오.

문제 분류

  • 자료 구조
  • 세그먼트 트리

문제 풀이

정말 어렵다. 인터넷에 여러 글을 참고하여 겨우 풀었다.

이 문제의 요는 '굉장한 학생'의 정의에 있다. 어떠한 학생이 '굉장한 학생'인가 아닌가는, 그 학생보다 '대단한 학생'이 있는지 파악하면 된다. 즉, 어떤 학생이 그 학생보다 '대단한 학생'이 존재하지 않으면 그 학생은 '굉장한 학생'이 된다.

여기서 '대단한 학생'은 어떻게 찾는가?
우선 '높은 등수'라는 것은, 석차를 의미하는 수가 낮다는 것이고, '낮은 등수'라고 하면, 석차를 의미하는 수 높다는 것이다. 1등은 가장 높은 등수이고, n등은 가장 낮은 등수이다. '높다', '낮다'의 의미를 헷갈리지 않도록 주의한다.

우선 '굉장한 학생'은 세 과목의 등수가 하나라도 자기보다 높은 학생이 있어서는 안 된다. 즉, 자기보다 낮은 등수의 학생에 대해서는 고려하지 않아도 된다.
첫 번째 과목의 등수 수열을 aia_i, 두 번째 과목의 등수 수열을 bib_i, 세 번째 등수 수열을 cic_i라 하자. ii는 등수이고, aia_iii등을 한 학생의 번호이다.

우리는 가령 aia_i가 '굉장한 학생인지 알고 싶다. 그러면 0<j<i0 < j < i인 모든 aja_j에 대해서만 생각하면 된다. 즉, ii보다 낮은 등수는 고려하지 않아도 된다. 이미 ii의 등수가 더 높기 때문이다. 우리의 목표는 모든 과목에서 자신보다 높은 등수를 가진 학생이 있는가 없는가 파악하는 것이기 때문에, 자기보다 높은 등수의 학생들에만 관심을 기울이면 된다.

이번에는 aia_i번 학생이 bb수열에서는 몇 등을 했는지 알아내야 한다. 이 등수를 bibi라고 하면, bbi=aib_{bi} = a_ibibi를 찾으면 된다. 그리고 역시 0<bj<bi0 < bj < bi인 구간을 생각해본다. 거의 다 왔다!

이제 (0,bi)(0, bi)인 구간의 학생 들이 cic_i에서는 몇 등을 했는지 알아보면 되는데, 모두 알아볼 필요는 없고 최솟값만 알아내면 된다. aia_i번 학생의 cic_i에서의 등수 cici, 즉 cci=aic_{ci} = a_icici(0,bi)(0, bi) 구간의 학생들의 cc등수의 최솟값 cjcj를 구해서, cicicjcj를 비교하면 된다.

말로 풀어쓰면 굉장히 어려운데, 직접 예시로 생각해보면 할만하다.

예제 입력
2 5 3 8 10 7 1 6 9 4
1 2 3 4 5 6 7 8 9 10
3 8 7 10 5 4 1 2 6 9
에 대해 3번 학생을 생각해보자. 3번 학생은 '굉장한 학생'인가? 첫 번째 과목에서 3보다 등수가 높은 학생은 2, 5뿐이다. 이제 이 25가 나머지 두 과목에서도 높은 등수를 가져가는지 확인하면 된다. 즉, 두 번째, 세 번째 과목에서도 3보다 낮은 등수의 학생들에 대해서는 아예 고려하지 않는다.

두 번째 과목에서는 어떨까? 우선 3의 등수는 3(=bi)이다. 3(=bi)등보다 높은 구간, 즉 (0,3(=bi))(0, 3(=bi))의 등수 구간에 있는 번호들이 세 번째 과목에서 가장 높은 등수, 즉 최솟값을 구하면 된다. (0,3)(0,3)구간에 존재하는 학생들 중에 1은 고려하지 않고, 2만 남았는데, 세 번째 과목에서 2의 등수는 8이므로 (0,3)(0,3)구간에서의 세 번째 과목의 등수 최솟값은 8이다. 여기서 3의 등수는 1이고, 1 < 8 이므로 3번 학생은 '굉장한 학생'이라고 볼 수 있다.

이번에는 7번 학생에 대해 빠르게 판별해보자. 우선 첫 번째 과목에서 7번 학생보다 높은 등수의 학생들만을 고려하고서, 7번 학생의 두 번째 과목의 등수는 7(=bi)이다. (0,7)(0, 7)구간에 대해 세 번째 과목 등수의 가장 높은 등수(최솟값)을 구하면 다음과 같다.
1은 미고려, 28, 31, 4는 미고려, 55, 6은 미고려. 이 중 최솟값은 1이다. 여기서 7의 세 번째 과목의 등수는 3이므로, 3 < 1은 거짓이므로 7번 학생은 '굉장한 학생'이 아니다.

정리하자면, 어떤 학생(aa)보다 첫 번째 과목에서 더 높은 등수의 학생들에 대해서만 고려하여, 두 번째 과목에서 aa보다 높은 등수를 가진 학생들에 대해 세 번째 과목에서의 가장 높은 등수(최솟값)이 aa의 등수보다 낮다면 aa는 '굉장한 학생'이 되는 것이다.

이를 세그먼트 트리로 구축하는데, 리프 노드는 두 번째 과목의 등수, 값은 세 번째 과목의 등수이다. 즉, 세그먼트 트리의 구간을 두 번째 과목의 등수로 표현하여, 해당 구간에서 세 번째 과목 등수의 최솟값을 구해낸다. 다시 말해서, 두 번째 과목의 등수를 가진 학생이 세 번째 과목의 등수로 어떻게 바뀌었는지 표현된다.

세그먼트 트리는 기본적으로 MAX의 상수로 초기화된다. 아직 탐색하지 않은 다른 학생들을 배제하기 위함이다. 첫 번째 과목의 가장 높은 등수의 학생부터 탐색하며 트리를 업데이트해 나간다.

나는 이 문제를 풀 때에 등수와 학생 번호를 0번부터 세어서 해결하였다.

첫 과목의 등수 입력은 예제 입력처럼 i번 등수에 ary[i]번 학생으로 받아왔지만, 두 번째, 세 번째 입력은 i번 학생의 ary[i]번 등수로 입력받았다.

이는 i번째 학생의 등수를 O(1)O(1)로 얻어오기 위함이다.

입력은 2차원 배열로 받아 0번 행은 첫 번째 과목, 1번 행은 두 번째 과목, 2번 행은 세 번째 과목에 대한 입력이고, 입력 형식은 위에서 말한 바와 같다.

이후 ary[2][ary[0][i]]Min(0, n - 1, 0, 0, ary[1][ary[0][i]] - 1)을 비교하여, ary[2][ary[0][i]]가 더 작다면 ary[0][i]번 학생은 '굉장한 학생'이다.

풀이 코드

#include <stdio.h>
#include <iostream>
#include <algorithm>
#define MAX 500001

using namespace std;

int ary[3][500001], seg[2000000];

int construct(int l, int r, int idx)
{
	if (l == r) seg[idx] = MAX;
	else {
		int mid = (l + r) / 2;
		seg[idx] = min(construct(l, mid, idx * 2 + 1), construct(mid + 1, r, idx * 2 + 2));
	}
	return seg[idx];
}

int update(int l, int r, int idx, int loc, int val)
{
	if (l > loc || r < loc) return seg[idx];
	if (l == r) {
		if (l == loc) seg[idx] = val;
	}
	else {
		int mid = (l + r) / 2;
		seg[idx] = min(update(l, mid, idx * 2 + 1, loc, val), update(mid + 1, r, idx * 2 + 2, loc, val));
	}
	return seg[idx];
}

int Min(int l, int r, int idx, int L, int R)
{
	if (L > R) return MAX;
	if (r < L || l > R) return MAX;
	if (L <= l && r <= R) return seg[idx];
	int mid = (l + r) / 2;
	return min(Min(l, mid, idx * 2 + 1, L, R), Min(mid + 1, r, idx * 2 + 2, L, R));
}

int main()
{
	int n, in, cnt = 0;
	cin >> n;
	construct(0, n - 1, 0);

	for (int i = 0; i < n; i++) {
		scanf("%d", &ary[0][i]);
		ary[0][i]--;
	}
	for (int i = 1; i < 3; i++) {
		for (int j = 0; j < n; j++) {
			scanf("%d", &in);
			ary[i][in - 1] = j;
		}
	}
    
	for (int i = 0; i < n ; i++) {
		in = ary[0][i];
		update(0, n - 1, 0, ary[1][in], ary[2][in]);
		int mm = Min(0, n - 1, 0, 0, ary[1][in] - 1);
		if (ary[2][in] < mm)
			cnt++;
	}
	printf("%d", cnt);
	
	return 0;
}

0개의 댓글