[백준] 1167. 트리의 지름 (Python)

개미·2023년 2월 15일
0

알고리즘

목록 보기
1/12
post-custom-banner

📌 1167. 트리의 지름

https://www.acmicpc.net/problem/1167

풀이과정

1. 첫번째 시도

처음에 문제를 보고, 가장 긴 거리? 최단 경로 문제를 반대로 적용한건가? 라고 생각해서 플루워셜 알고리즘을 썼다. (지금 생각해보면 아무 논리도 없었음..) 그러다가 답이 맞지 않게 나온다는 것을 깨닫고 바로 다 지웠다.

2. 두번째 시도

DFS를 사용했다. (BFS 사용해도 된다)

n = int(input())

graph = [[] for _ in range(n+1)]
visited = [False]*(n+1) 
distance = [0]*(n+1)

for i in range(n):
  array = list(map(int, input().split()))[:-1]
  a = array[0]
  for j in range(1, len(array), 2):
    graph[a].append((array[j], array[j+1]))

def dfs(a, b):
  visited[a] = True
  for i in graph[a]:
    node, dist = i
    if visited[node] == False:
      distance[node] = max(distance[node], b+dist)
      dfs(node, max(distance[node], b+dist))

dfs(1, 0)
print(max(distance))

하지만 dfs(2,0)으로 바꿀 경우 다른 값이 나오는 문제가 발생했다. 따라서 for문을 돌면서 가장 큰 값을 도출하게 바꾸어주었다.

import sys
input = sys.stdin.readline
n = int(input())

graph = [[] for _ in range(n+1)]
visited = [False]*(n+1) 
distance = [0]*(n+1)

for i in range(n):
  array = list(map(int, input().split()))[:-1]
  a = array[0]
  for j in range(1, len(array), 2):
    graph[a].append((array[j], array[j+1]))

def dfs(a, b):
  visited[a] = True
  for i in graph[a]:
    node, dist = i
    if visited[node] == False:
      distance[node] = max(distance[node], b+dist)
      dfs(node, max(distance[node], b+dist))

result = 0
for i in range(1, n+1):
  dfs(i, 0)
  result = max(result, max(distance))
  visited = [False]*(n+1) 
  distance = [0]*(n+1)
print(result)

하지만 당연하게도 시간초과...

세번째 시도

효율적으로 코드를 완성하기 위해서는 수학적 개념을 하나 알고 있어야 한다.

노드 a에서 가장 먼 노드를 b라고 하면, b는 트리의 지름을 이루는 노드 중 하나이다.

증명은 https://blog.myungwoo.kr/112 참고

따라서 DFS를 실행시켜 가장 거리가 먼 노드를 찾은 후, 그 찾은 노드로 한번 더 DFS를 실행시켜야 한다.

💯 정답

import sys
input = sys.stdin.readline
n = int(input())

graph = [[] for _ in range(n+1)]
visited = [False]*(n+1) 
distance = [0]*(n+1)

for i in range(n):
  array = list(map(int, input().split()))[:-1]
  a = array[0]
  for j in range(1, len(array), 2):
    graph[a].append((array[j], array[j+1]))

def dfs(a, b):
  visited[a] = True
  for i in graph[a]:
    node, dist = i
    if visited[node] == False:
      distance[node] = max(distance[node], b+dist)
      dfs(node, max(distance[node], b+dist))

result = 0
idx = 1
dfs(1,0)
for i in range(1, n+1):
  if result < distance[i]:
    result = distance[i]
    idx = i

visited = [False]*(n+1) 
distance = [0]*(n+1)
dfs(idx,0)
print(max(distance))
profile
개발자
post-custom-banner

0개의 댓글