Leetcode 743 - Network Delay Time

이두현·2021년 12월 29일
0

Leetcode 743

Dijkstra

언어: python3

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        """Return the time for a signal sent from ``k`` to reach all ``n`` nodes.

        Dijkstra's algorithm on a binary heap (``heapq``) with lazy deletion:
        a node may be pushed several times, but only its first pop is used.

        Args:
            times: directed edges ``[u, v, w]``; a signal goes from ``u`` to
                ``v`` in ``w`` time units.
            n: number of nodes, labelled ``1..n``.
            k: source node.

        Returns:
            The largest shortest-path distance from ``k``, or ``-1`` if some
            node in ``1..n`` is unreachable.
        """
        # Adjacency list: node -> [(destination, cost), ...]
        graph = collections.defaultdict(list)
        for u, v, w in times:
            graph[u].append((v, w))

        # dist[i] is the best known distance to node i; -1 means "not reached".
        # Slot 0 is unused because nodes are labelled from 1.
        dist = [-1] * (n + 1)
        dist[k] = 0
        settled = set()  # nodes whose shortest distance is final
        pq = [(0, k)]    # heap of (distance from k, node)

        while pq:
            d, node = heapq.heappop(pq)
            # A node can sit in the heap once per improvement; only its first
            # pop carries the final distance, so later (stale) pops are skipped
            # instead of rescanning the node's edges.
            if node in settled:
                continue
            settled.add(node)
            for adj, cost in graph[node]:
                if adj in settled:
                    continue
                # Rejected variant: pushing on every edge even when the distance
                # does not improve floods the heap with useless duplicates and
                # ran out of memory:
                #     if dist[adj] == -1:
                #         dist[adj] = d + cost
                #     else:
                #         dist[adj] = min(dist[adj], d + cost)
                #     heapq.heappush(pq, (dist[adj], adj))
                # Instead, update and push only when the distance strictly improves.
                candidate = d + cost
                if dist[adj] == -1 or candidate < dist[adj]:
                    dist[adj] = candidate
                    heapq.heappush(pq, (candidate, adj))

        # A node still at -1 (ignoring unused slot 0) was never reached.
        if -1 in dist[1:]:
            return -1
        return max(dist)
  • 주석 처리된 부분처럼 구현하면 거리가 줄어들지 않아도 매번 heapq에 push하므로, 이미 추가된 노드가 불필요하게 중복으로 쌓여 memory 에러(메모리 초과)가 난다.
  • 그 아래 코드처럼 거리가 실제로 줄어드는 경우에만 dist를 갱신하고 heapq에 push하도록 조건을 추가하였다.

faster version

class Solution(object):
    def networkDelayTime(self, times, n, k):
        """Return the time for a signal sent from ``k`` to reach all ``n`` nodes.

        Dijkstra's algorithm with lazy deletion. Distances are kept in a plain
        dict that holds an entry only for nodes that have been reached.

        :type times: List[List[int]]  -- directed edges [u, v, w]
        :type n: int                  -- number of nodes, labelled 1..n
        :type k: int                  -- source node
        :rtype: int  -- largest shortest distance from k, or -1 if any node
                        is unreachable
        """
        graph = collections.defaultdict(list)
        for u, v, w in times:
            graph[u].append((v, w))

        # Plain dict, not defaultdict(int): a stray read of an unreached node
        # must not silently insert 0, because len(dist) and max(...) below
        # rely on dist holding exactly the reached nodes.
        dist = {k: 0}
        Q = [(0, k)]  # heap of (dist, node)

        while Q:
            time, node = heapq.heappop(Q)
            # Stale heap entry: a shorter distance was recorded after this push.
            if dist[node] < time:
                continue
            for v, w in graph[node]:
                candidate = time + w
                if v not in dist or candidate < dist[v]:
                    dist[v] = candidate
                    heapq.heappush(Q, (candidate, v))

        # Every node was reached iff each of the n labels has an entry.
        if len(dist) == n:
            return max(dist.values())
        return -1
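  • 차이점: dist를 크기 n+1의 배열로 미리 선언하지 않고 dictionary로 관리하며, 마지막에 dictionary의 크기가 n인지 확인하여 모든 노드에 도달 가능한지 판단한다. 또한 heap에서 꺼낸 시간이 이미 기록된 최단 거리보다 크면 오래된 항목으로 보고 건너뛰므로(if dist[node] < time: continue) 불필요한 탐색이 줄어든다.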
profile
0100101

0개의 댓글