문제 : 백준 1613번 역사
틀린 풀이 : 위상 정렬, DFS
틀린 이유 : 같은 연결 요소에 속해 있어도(연결되어 있어도) 선후 관계를 알 수 없는 사건 쌍이 존재 (예: 1→3, 2→3 일 때 1과 2는 연결되어 있지만 순서를 알 수 없음)
맞은 풀이 : Floyd-Warshall, DP
정리
굳이 거리를 계산할 필요 없을 것 같아 bool 타입으로 바꿔보았다
Floyd-Warshall 알고리즘
시간 복잡도는 O(N^3)으로, 배열 기반 다익스트라 알고리즘(O(N^2))을 모든 노드(N개)에 각각 적용한 것과 시간 복잡도 면에서 유사
import sys, bisect
def init():
    """Read the problem input from stdin and allocate the working tables.

    Expected layout: one line ``n k``, then ``k`` lines ``start end``, one
    line ``s``, then ``s`` lines ``start end``.  Tokens are separated by any
    run of whitespace, so repeated spaces, tabs or trailing blanks are
    accepted.

    Returns:
        A 10-tuple ``(n, k, order_list, s, pair_list, part_list,
        undirected_list, directed_list, sorted_list, index_list)``:

        * ``order_list`` -- the ``k`` ``(start, end)`` relations as read.
        * ``pair_list`` -- the ``s`` ``(start, end)`` queries as read.
        * ``directed_list`` -- adjacency lists holding ``start -> end``.
        * ``undirected_list`` -- adjacency lists holding both directions.
        * ``part_list`` -- ``n + 1`` zeros (0 means "not labelled yet").
        * ``index_list`` -- ``n + 1`` entries initialised to ``-1``.
        * ``sorted_list`` -- an empty list to be filled later.

        All per-node tables have ``n + 1`` slots so that nodes can be
        indexed directly by their 1-based number.
    """
    read_line = sys.stdin.readline

    def read_pair():
        # split() without a separator tolerates irregular whitespace.
        first, second = map(int, read_line().split())
        return first, second

    n, k = read_pair()
    order_list = [read_pair() for _ in range(k)]

    directed_list = [[] for _ in range(n + 1)]
    undirected_list = [[] for _ in range(n + 1)]
    for start, end in order_list:
        directed_list[start].append(end)
        undirected_list[start].append(end)
        undirected_list[end].append(start)

    s = int(read_line())
    pair_list = [read_pair() for _ in range(s)]

    part_list = [0] * (n + 1)
    index_list = [-1] * (n + 1)
    sorted_list = []
    return (n, k, order_list, s, pair_list, part_list,
            undirected_list, directed_list, sorted_list, index_list)
def dfs_ts(visited, curr_num):
    """Depth-first walk from ``curr_num`` that records nodes in post-order.

    Follows the edges in the module-level ``directed_list``.  A node is
    appended to the module-level ``sorted_list`` only after every successor
    that was still unvisited has been appended, and its position in
    ``sorted_list`` is stored in ``index_list``.

    Args:
        visited: list of booleans indexed by node number; updated in place.
        curr_num: node at which the walk starts.
    """
    visited[curr_num] = True
    # Each frame keeps the node together with an iterator over its
    # successors, so the walk resumes exactly where it left off.
    frames = [(curr_num, iter(directed_list[curr_num]))]
    while frames:
        node, successors = frames[-1]
        for succ in successors:
            if not visited[succ]:
                visited[succ] = True
                frames.append((succ, iter(directed_list[succ])))
                break
        else:
            # All successors handled: the node is finished.
            frames.pop()
            sorted_list.append(node)
            index_list[node] = len(sorted_list) - 1
def set_sorted_list():
    """Run ``dfs_ts`` from every node 1..n that has not been visited yet.

    Uses the module-level ``n``; the results are whatever ``dfs_ts`` writes
    into the module-level ``sorted_list`` and ``index_list``.
    """
    seen = [False] * (n + 1)
    for node in range(1, n + 1):
        if seen[node]:
            continue
        dfs_ts(seen, node)
def dfs(curr_num, part):
    """Label every node reachable from ``curr_num`` with ``part``.

    Traverses the module-level ``undirected_list`` and writes the label into
    the module-level ``part_list``; only nodes whose label is still 0 are
    entered.

    Args:
        curr_num: node at which the flood fill starts.
        part: label to assign.
    """
    part_list[curr_num] = part
    pending = [curr_num]
    while pending:
        node = pending.pop()
        for neighbour in undirected_list[node]:
            if part_list[neighbour] == 0:
                # Label on discovery so a node is never queued twice.
                part_list[neighbour] = part
                pending.append(neighbour)
def set_part_list():
    """Give each group of mutually connected nodes its own label.

    Scans nodes 1..n (module-level ``n``); every node whose entry in
    ``part_list`` is still 0 starts a new ``dfs`` flood fill with the next
    unused label, counting up from 1.
    """
    next_label = 1
    for node in range(1, n + 1):
        if part_list[node] != 0:
            continue
        dfs(node, next_label)
        next_label += 1
# Cache of "everything reachable from this node", filled lazily by
# _reachable_from.  Assumes directed_list is not modified after the first
# query has been answered.
_reachable_cache = {}


def _reachable_from(source):
    """Return the set of nodes reachable from ``source`` via directed_list."""
    reached = _reachable_cache.get(source)
    if reached is None:
        reached = set()
        pending = [source]
        while pending:
            node = pending.pop()
            for succ in directed_list[node]:
                if succ not in reached:
                    reached.add(succ)
                    pending.append(succ)
        _reachable_cache[source] = reached
    return reached


def check(pair):
    """Return the known order of the two events in ``pair``.

    The previous implementation answered from the connected-component label
    and the post-order index.  That is wrong whenever two events are in the
    same component without a directed path between them (e.g. with edges
    1 -> 3 and 2 -> 3, events 1 and 2 are connected but incomparable).  The
    answer is now derived from directed reachability only.

    Args:
        pair: ``(start, end)`` node numbers.

    Returns:
        ``-1`` if ``end`` is reachable from ``start`` in ``directed_list``,
        ``1`` if ``start`` is reachable from ``end``,
        ``0`` if neither is reachable from the other.
    """
    start, end = pair
    if end in _reachable_from(start):
        return -1
    if start in _reachable_from(end):
        return 1
    return 0
# Driver for the DFS-based attempt: read the input, build the post-order
# index and the component labels, then print one answer line per query.
(
    n,
    k,
    order_list,
    s,
    pair_list,
    part_list,
    undirected_list,
    directed_list,
    sorted_list,
    index_list,
) = init()
set_sorted_list()
set_part_list()
sys.stdout.writelines(f'{check(query)}\n' for query in pair_list)
import sys
def init():
    """Read the problem input from stdin and allocate the reachability table.

    Expected layout: one line ``n k``, then ``k`` relation lines, one line
    ``s``, then ``s`` query lines.  Tokens are separated by any run of
    whitespace, so repeated spaces, tabs or trailing blanks are accepted.

    Returns:
        A 6-tuple ``(n, k, order_list, s, pair_list, dist_list)`` where
        ``order_list`` and ``pair_list`` are lists of integer tuples (one
        per input line) and ``dist_list`` is an ``(n + 1) x (n + 1)`` table
        of ``False`` with independent rows, indexable by 1-based node number.
    """
    read_line = sys.stdin.readline

    def read_ints():
        # split() without a separator tolerates irregular whitespace.
        return tuple(map(int, read_line().split()))

    n, k = read_ints()
    order_list = [read_ints() for _ in range(k)]
    s = int(read_line())
    pair_list = [read_ints() for _ in range(s)]
    dist_list = [[False] * (n + 1) for _ in range(n + 1)]
    return n, k, order_list, s, pair_list, dist_list
# Driver for the Floyd-Warshall solution.  dist_list[a][b] is True when a
# chain of input relations leads from a to b (boolean transitive closure).
n, k, order_list, s, pair_list, dist_list = init()

for start, end in order_list:
    dist_list[start][end] = True

# The pivot gets its own name so that it does not overwrite ``k`` (the
# number of relations returned by init()).
for mid in range(1, n + 1):
    from_mid = dist_list[mid]
    for i in range(1, n + 1):
        row = dist_list[i]
        if not row[mid]:
            # i cannot reach the pivot, so nothing in this row can change.
            continue
        for j in range(1, n + 1):
            if from_mid[j]:
                row[j] = True

# Exactly one line per query: -1 when start reaches end, 1 when end reaches
# start, 0 when neither does.  (Previously a query for which both directions
# were True produced no output at all, shifting every later answer.)
answers = []
for start, end in pair_list:
    if dist_list[start][end]:
        answers.append('-1\n')
    elif dist_list[end][start]:
        answers.append('1\n')
    else:
        answers.append('0\n')
sys.stdout.write(''.join(answers))