문제 링크
https://www.acmicpc.net/problem/16566
N,M 이 400만으로 크다. O(NlogN) 정도의 알고리즘을 생각해야 한다.
상대가 낼 카드보다 큰 카드 중 가작 작은 녀석을 내면 되는 문제이다.
나의 카드들을 정렬하고, 상대의 카드가 나오면 그것에 맞춰서 적절한 카드를 주면 시간을 줄일 수 있을 것이다.
백준 10775 : 공항 문제와 거의 똑같다
https://velog.io/@sunkyuj/python-%EB%B0%B1%EC%A4%80-10775-%EA%B3%B5%ED%95%AD
풀이 선 요약
- 내 카드 정렬
- 상대의 모든 카드에 대해 반복
2-1. 상대 카드보다 큰 카드 중 가작 작은 카드의idx
찾기 (이분탐색)
2-2. 그 카드가 가리키는 disjoint set을 출력
2-3. 선택된 카드가 다음 disjoint set을 가리키도록union
n, m, k = map(int, input().split())
my = list(map(int, input().split()))
root = [i for i in range(m + 1)]
enemy = list(map(int, input().split()))
my.sort()
우선 값을 입력받고, 내 카드를 정렬한다.
def find(v):
if v == root[v]:
return v
root[v] = find(root[v])
return root[v]
def union(v1, v2): # union v1 to v2 (v1 -> v2)
r1 = find(v1)
r2 = find(v2)
root[r1] = r2
def ub(key):
s, e = 0, m - 1
while s <= e:
mid = (s + e) // 2
if my[mid] <= key:
s = mid + 1
else:
e = mid - 1
return s # return idx
다음은 쓸 함수들을 정의했다. Union-Find 함수들과 upper bound 를 찾아주는 함수이다.
for num in enemy:
idx = ub(num)
choice_idx = find(idx)
print(my[choice_idx])
union(choice_idx, choice_idx + 1)
마지막으로 본 알고리즘이다. 생각보다 매우 짧다...
그림으로 어떻게 돌아가는지 알아보자
우선 나의 카드 목록이다. 저 화살표는 root
가 자기 자신을 가리키고 있다는 것이다.
다음은 상대방이 낼 카드를 순서대로 나타낸 것이다
4 1 1 3 8
먼저 4를 상대방이 냈으니 우리는 4보다 큰 카드 중 가장 작은 카드인 5를 낼 것이다. 마찬가지로 상대가 1을 냈을 때 우리는 2를 낼 것이다.
그러면 root[5]
는 7을, root[2]
는 3을 가리키게 된다 (다음 disjoint set과 union
).
여기가 포인튼데, 상대방의 세 번째 숫자인 1을 내면, 우리는 그것 보다 큰 수 중 제일 작은 2를 내야 한다.
하지만 이때 그냥 2를 내는 것이 아닌, 2가 가리키는 disjoint set을 내는 것이다.
find
연산을 통해 갈 수 있는 데까지 가는 것이다!
만약 3 역시 다른곳을 가리키고 있다면 계속해서 나아간다.
이러한 방식으로 문제를 해결할 수 있다.
import sys
sys.setrecursionlimit(10 ** 8) # pypy 제출시 삭제!
input = lambda: sys.stdin.readline().rstrip()
# in_range = lambda y,x: 0<=y<n and 0<=x<m
def find(v):
if v == root[v]:
return v
root[v] = find(root[v])
return root[v]
def union(v1, v2): # union v1 to v2 (v1 -> v2)
r1 = find(v1)
r2 = find(v2)
root[r1] = r2
def ub(key):
s, e = 0, m - 1
while s <= e:
mid = (s + e) // 2
if my[mid] <= key:
s = mid + 1
else:
e = mid - 1
return s # return idx
n, m, k = map(int, input().split())
my = list(map(int, input().split()))
root = [i for i in range(m + 1)]
enemy = list(map(int, input().split()))
my.sort()
for num in enemy:
idx = ub(num)
choice_idx = find(idx)
print(my[choice_idx])
union(choice_idx, choice_idx + 1)