[ BOJ / Python ] 1240번 노드사이의 거리

황승환·2022년 2월 2일
0

Python

목록 보기
146/498


이번 문제는 깊이우선탐색으로 해결하려 하였지만 시간초과가 발생하여서 너비우선탐색으로 해결하게 되었다. 우선 깊이우선탐색으로 풀이는 다음과 같다.

def dfs(cur, cnt):
    if cur==b:
        return cnt
    visited[cur]=True
    for i in range(n+1):
        if tree[cur][i]>0 and visited[i]==False:
            visited[i]=True
            cnt=dfs(i, cnt+tree[cur][i])
    return cnt
n, m=map(int, input().split())
tree=[[0]*(n+1) for _ in range(n+1)]
for _ in range(n-1):
    s, e, d=map(int, input().split())
    tree[s][e]=d
    tree[e][s]=d
for _ in range(m):
    a, b=map(int, input().split())
    visited=[False for _ in range(n+1)]
    print(dfs(a, 0))

파이참에서는 원하는 값을 잘 출력해주었지만 시간초과가 발생하였고, 다른 사람들의 풀이를 찾아보니 모두 BFS로 접근하여 해결한 것을 확인할 수 있었다. 방문처리와 동시에 시작 노드로부터의 거리를 저장하는데 사용할 리스트 visited를 이용하여 너비우선탐색으로 거리를 측정하였다. 너비우선탐색의 원리대로 큐가 완전히 비게 되면 반복이 종료되고 만약 그 전에 도착점에 도달하면 반복이 종료된다.

  • 큐를 사용하기 위해 deque를 가져온다.
  • bfs 함수를 start, end를 인자로 갖도록 하여 선언한다.
    -> 방문처리와 거리계산에 사용할 리스트 visited를 -1 n+1개로 채운다.
    -> visited[start]를 0으로 갱신한다.
    -> 큐로 사용할 deque q를 선언한다.
    -> q에 start를 넣는다.
    -> q가 빌때까지 반복하는 while문을 돌린다.
    --> q의 가장 앞의 수를 지우고 그 값을 임시변수 cur에 저장한다.
    --> 만약 cur이 end와 같다면 반복문을 종료한다.
    --> tree[cur]을 순회하는 next, dist에 대한 for문을 돌린다.
    ---> 만약 visited[next]가 -1보다 클 경우, 반복을 계속 진행한다.
    ----> visited[next]를 visited[cur]+dist로 갱신한다.
    ----> q에 next를 넣어준다.
    -> visited[end]를 반환한다.
  • n과 m을 입력받는다.
  • tree를 n+1개의 리스트로 선언한다.
  • n-1번 반복하는 for문을 돌린다.
    -> a, b, d를 입력받는다.
    -> tree[a]에 (b, d)를 넣는다.
    -> tree[b]에 (a, d)를 넣는다.
  • m번 반복하는 for문을 돌린다.
    -> s, e를 입력받는다.
    -> bfs(s, e)를 출력한다.

Code

from collections import deque
def bfs(start, end):
    visited = [-1] * (n + 1)
    visited[start]=0
    q=deque()
    q.append(start)
    while q:
        cur=q.popleft()
        if cur==end:
            break
        for next, dist in tree[cur]:
            if visited[next]>-1:
                continue
            visited[next]=visited[cur]+dist
            q.append(next)
    return visited[end]
n, m=map(int ,input().split())
tree=[[] for _ in range(n+1)]
for _ in range(n-1):
    a, b, d=map(int, input().split())
    tree[a].append((b, d))
    tree[b].append((a, d))
for _ in range(m):
    s, e=map(int, input().split())
    print(bfs(s, e))

profile
꾸준함을 꿈꾸는 SW 전공 학부생의 개발 일기

0개의 댓글