[Python] 백준 #2887 행성 터널

이재원·2023년 10월 25일

Algorithm

목록 보기
26/29

📚문제: #2887 행성 터널(Platinum 5)

때는 2040년, 이민혁은 우주에 자신만의 왕국을 만들었다. 왕국은 N개의 행성으로 이루어져 있다. 민혁이는 이 행성을 효율적으로 지배하기 위해서 행성을 연결하는 터널을 만들려고 한다.

행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.

민혁이는 터널을 총 N-1개 건설해서 모든 행성이 서로 연결되게 하려고 한다. 이때, 모든 행성을 터널로 연결하는데 필요한 최소 비용을 구하는 프로그램을 작성하시오.

입력

첫째 줄에 행성의 개수 N이 주어진다. (1 ≤ N ≤ 100,000) 다음 N개 줄에는 각 행성의 x, y, z좌표가 주어진다. 좌표는 -109보다 크거나 같고, 109보다 작거나 같은 정수이다. 한 위치에 행성이 두 개 이상 있는 경우는 없다.

출력

첫째 줄에 모든 행성을 터널로 연결하는데 필요한 최소 비용을 출력한다.

예제 모음

(입력)
5
11 -15 -15
14 -5 -15
-1 -1 -5
10 -4 -1
19 -4 19

(출력)
4

아이디어 및 구현

  • 모든 행성에서 다른 모든 행성으로 가는 간선비용을 계산하여 그래프를 구성할 시 행성의 개수가 최대 100,000개 이므로 간선은 최대 50억개로 구성되며 이는 메모리초과를 유발한다.
  • 간선의 선택 후보군을 줄이기 위하여 정렬을 활용한다.
    • 각 행성의 x, y, z값을 기준으로 정렬한 뒤 인접한 행성끼리의 간선 정보를 그래프에 반영한다.
    • N이 행성의 개수일 때 최소 스패닝 트리를 구성하기 위한 간선의 후보군은 기존의 N(N-1)/2개에서 3(N-1)개로 감소한다. 최대 간선의 개수는 약 300,000개이다.
# 행성 좌표 셋
planet_loc = []

# 행성 (0번 ~ N-1번)
planet = [[] for _ in range(N)]

# 각 행성의 x, y, z좌표가 주어진다. i는 행성의 정보이다.
for i in range(N):

    x, y, z = map(int, sys.stdin.readline().split())

    planet_loc.append([x,y,z,i])

# x,y,z좌표에 대한 오름차순 정렬
planet_loc_x_sort = sorted(planet_loc, key=lambda t:t[0])
planet_loc_y_sort = sorted(planet_loc, key=lambda t:t[1])
planet_loc_z_sort = sorted(planet_loc, key=lambda t:t[2])

# 간선 추가(x, y, z좌표 별로 오름차순 정렬한 데이터를 기준으로)
for i in range(N-1):

    x_cost = abs(planet_loc_x_sort[i+1][0]-planet_loc_x_sort[i][0])
    
    n1, m1 = planet_loc_x_sort[i+1][-1], planet_loc_x_sort[i][-1]

    planet[n1].append((m1, x_cost))
    planet[m1].append((n1, x_cost))

    y_cost = abs(planet_loc_y_sort[i+1][1]-planet_loc_y_sort[i][1])

    n2, m2 = planet_loc_y_sort[i+1][-1], planet_loc_y_sort[i][-1]

    planet[n2].append((m2, y_cost))
    planet[m2].append((n2, y_cost))

    z_cost = abs(planet_loc_z_sort[i+1][2]-planet_loc_z_sort[i][2])

    n3, m3 = planet_loc_z_sort[i+1][-1], planet_loc_z_sort[i][-1]

    planet[n3].append((m3, z_cost))
    planet[m3].append((n3, z_cost))
  • 프림 알고리즘으로 N-1개의 간선으로 구성된 최소 스패닝 트리 비용을 계산한다.
def prim(start):

    # MST 비용
    total = 0

    # 간선 선택 횟수
    edge = 0

    # 우선순위 큐
    q = []

    # 포함된 행성 리스트
    mst = set()

    # 시작노드 처리
    mst.add(start)

    # 시작노드와 붙어있는 간선을 큐에 추가 (비용, 노드)
    for neighbor in planet[start]:

        heappush(q, (neighbor[1], neighbor[0]))
    
    # 큐가 빌 때까지 반복
    while q:

        # 비용과 행성꺼내기
        cost, cur = heappop(q)

        # 리스트에 이미 포함되어있으면 skip
        if cur in mst:

            continue
        
        else:

            # 전체 비용에 누적
            total += cost

            # 간선 선택 1회 추가
            edge += 1

            # MST에 노드 추가
            mst.add(cur)

            # N-1개의 간선이 획득되면 최소 스패닝 트리 만족, 종료
            if edge == N-1:

                break

            # cur의 이웃들을 살펴봅니다.
            for neighbor in planet[cur]:
                
                # mst 구성되지 않은 노드는 후보에 추가
                if neighbor[0] not in mst:

                    heappush(q, (neighbor[1], neighbor[0]))

    # 전체 간선 비용
    print(total)

전체 코드

import sys
from heapq import heappush, heappop

def prim(start):

    # MST 비용
    total = 0

    # 간선 선택 횟수
    edge = 0

    # 우선순위 큐
    q = []

    # 포함된 행성 리스트
    mst = set()

    # 시작노드 처리
    mst.add(start)

    # 시작노드와 붙어있는 간선을 큐에 추가 (비용, 노드)
    for neighbor in planet[start]:

        heappush(q, (neighbor[1], neighbor[0]))
    
    # 큐가 빌 때까지 반복
    while q:

        # 비용과 행성꺼내기
        cost, cur = heappop(q)

        # 리스트에 이미 포함되어있으면 skip
        if cur in mst:

            continue
        
        else:

            # 전체 비용에 누적
            total += cost

            # 간선 선택 1회 추가
            edge += 1

            # MST에 노드 추가
            mst.add(cur)

            # N-1개의 간선이 획득되면 최소 스패닝 트리 만족, 종료
            if edge == N-1:

                break

            # cur의 이웃들을 살펴봅니다.
            for neighbor in planet[cur]:
                
                # mst 구성되지 않은 노드는 후보에 추가
                if neighbor[0] not in mst:

                    heappush(q, (neighbor[1], neighbor[0]))

    # 전체 간선 비용
    print(total)

# 행성의 개수 N이 주어집니다.
N = int(sys.stdin.readline().rstrip())

# 행성 좌표 셋
planet_loc = []

# 행성 (0번 ~ N-1번)
planet = [[] for _ in range(N)]

# 각 행성의 x, y, z좌표가 주어진다.
for i in range(N):

    x, y, z = map(int, sys.stdin.readline().split())

    planet_loc.append([x,y,z,i])

# x,y,z좌표에 대한 오름차순 정렬
planet_loc_x_sort = sorted(planet_loc, key=lambda t:t[0])
planet_loc_y_sort = sorted(planet_loc, key=lambda t:t[1])
planet_loc_z_sort = sorted(planet_loc, key=lambda t:t[2])

# 간선 추가(x, y, z좌표 별로 오름차순 정렬한 데이터를 기준으로)
for i in range(N-1):

    x_cost = abs(planet_loc_x_sort[i+1][0]-planet_loc_x_sort[i][0])
    
    n1, m1 = planet_loc_x_sort[i+1][-1], planet_loc_x_sort[i][-1]

    planet[n1].append((m1, x_cost))
    planet[m1].append((n1, x_cost))

    y_cost = abs(planet_loc_y_sort[i+1][1]-planet_loc_y_sort[i][1])

    n2, m2 = planet_loc_y_sort[i+1][-1], planet_loc_y_sort[i][-1]

    planet[n2].append((m2, y_cost))
    planet[m2].append((n2, y_cost))

    z_cost = abs(planet_loc_z_sort[i+1][2]-planet_loc_z_sort[i][2])

    n3, m3 = planet_loc_z_sort[i+1][-1], planet_loc_z_sort[i][-1]

    planet[n3].append((m3, z_cost))
    planet[m3].append((n3, z_cost))

# Prim Algorithm Execute
prim(0)

0개의 댓글