사이클 찾기 문제여서 union-find 알고리즘을 쓰려고 했는데 시간초과가 났다.
노드 개수는 100,000개이지만,
테스트케이스가 여러 개 주어지기 때문에 여기서 시간초과가 나는 듯 하다.
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**9)
def union_parent(x,y):
px = find_parent(x)
py = find_parent(y)
# 더 작은 숫자가 부모가 되도록
if px<py:
parent[y] = px
else:
parent[x] = py
def find_parent(x):
if parent[x]!=x:
parent[x] = find_parent(parent[x])
return parent[x]
for _ in range(int(input())):
n = int(input())
answer = n
parent = [i for i in range(n+1)] # 1번부터 시작
for n1,n2 in enumerate(list(map(int,input().split()))):
n1 += 1
if n1 == n2:
answer -= 1
else:
# 루트 노드가 서로 같다면 사이클이 발생한 것
if find_parent(n1) == find_parent(n2):
answer -= parent.count(parent[n1])
else:
union_parent(n1,n2)
print(answer)
- 애초에 union-find를 이용한 사이클 판별은 그래프가 '무방향'일 때 쓰는 것이다.
- 그래프가 방향 그래프일 때는, dfs로 사이클을 판별해야 한다.
import sys
input = sys.stdin.readline
sys.setrecursionlimit(111111)
def dfs(node):
global answer
visited[node]=True
cycle.append(node)
nxt = graph[node]
if visited[nxt]: ########## 이 부분 생략하면 시간초과
if nxt in cycle:
answer -= (len(cycle)-cycle.index(nxt))
return
if not visited[nxt]:
dfs(nxt)
for _ in range(int(input())): # 테스트케이스
n = int(input())
graph = [-1]+list(map(int,input().split()))
visited = [False]*(n+1) # 1번부터 시작
answer = n
for x in range(1,n+1):
if not visited[x]:
cycle = []
dfs(x)
print(answer)
if nxt in cycle:
answer -= (len(cycle)-cycle.index(nxt))
return
if visited[nxt]:
if nxt in cycle:
answer -= (len(cycle)-cycle.index(nxt))
return