[Python] 백준 #17472 다리 만들기 2

이재원·2023년 10월 30일

Algorithm

목록 보기
27/29

📚문제: #17142 다리 만들기 2(Gold 1)

섬으로 이루어진 나라가 있고, 모든 섬을 다리로 연결하려고 한다. 이 나라의 지도는 N×M 크기의 이차원 격자로 나타낼 수 있고, 격자의 각 칸은 땅이거나 바다이다.섬은 연결된 땅이 상하좌우로 붙어있는 덩어리를 말하고, 아래 그림은 네 개의 섬으로 이루어진 나라이다. 색칠되어있는 칸은 땅이다.

다리는 바다에만 건설할 수 있고, 다리의 길이는 다리가 격자에서 차지하는 칸의 수이다. 다리를 연결해서 모든 섬을 연결하려고 한다. 섬 A에서 다리를 통해 섬 B로 갈 수 있을 때, 섬 A와 B를 연결되었다고 한다. 다리의 양 끝은 섬과 인접한 바다 위에 있어야 하고, 한 다리의 방향이 중간에 바뀌면 안된다. 또, 다리의 길이는 2 이상이어야 한다.

다리의 방향이 중간에 바뀌면 안되기 때문에, 다리의 방향은 가로 또는 세로가 될 수 밖에 없다. 방향이 가로인 다리는 다리의 양 끝이 가로 방향으로 섬과 인접해야 하고, 방향이 세로인 다리는 다리의 양 끝이 세로 방향으로 섬과 인접해야 한다.

섬 A와 B를 연결하는 다리가 중간에 섬 C와 인접한 바다를 지나가는 경우에 섬 C는 A, B와 연결되어있는 것이 아니다.

아래 그림은 섬을 모두 연결하는 올바른 2가지 방법이고, 다리는 회색으로 색칠되어 있다. 섬은 정수, 다리는 알파벳 대문자로 구분했다.
image
다리의 총 길이: 13, D는 2와 4를 연결하는 다리이고, 3과는 연결되어 있지 않다.
image
다리의 총 길이: 9 (최소)

나라의 정보가 주어졌을 때, 모든 섬을 연결하는 다리 길이의 최솟값을 구해보자.

입력

첫째 줄에 지도의 세로 크기 N과 가로 크기 M이 주어진다. 둘째 줄부터 N개의 줄에 지도의 정보가 주어진다. 각 줄은 M개의 수로 이루어져 있으며, 수는 0 또는 1이다. 0은 바다, 1은 땅을 의미한다.

출력

모든 섬을 연결하는 다리 길이의 최솟값을 출력한다. 모든 섬을 연결하는 것이 불가능하면 -1을 출력한다.

예제 모음

(입력1)
7 8
0 0 0 0 0 0 1 1
1 1 0 0 0 0 1 1
1 1 0 0 0 0 0 0
1 1 0 0 0 1 1 0
0 0 0 0 0 1 1 0
0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1

(출력1)
9

(입력2)
7 8
0 0 0 1 1 0 0 0
0 0 0 1 1 0 0 0
1 1 0 0 0 0 1 1
1 1 0 0 0 0 1 1
1 1 0 0 0 0 0 0
0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1

(출력2)
10

아이디어 및 구현

  • 주어진 지도에서 BFS 알고리즘을 사용해 섬으로 이루어진 나라를 구분한다.
# 섬을 구분하는 BFS 알고리즘
def BFS(start, num):
    
    # 4방향
    dr = [-1, 1, 0, 0]
    dc = [0, 0, -1, 1]
    
    # 시작지점
    start_r, start_c = start
    
    # 시작 노드 방문 처리
    graph[start_r][start_c] = num
    
    # 큐
    queue = deque([start])
    
    # 큐가 빌 때까지 반복
    while queue:
        
        cur_r, cur_c = queue.popleft()
        
        for i in range(4):
            
            move_r, move_c = cur_r + dr[i], cur_c + dc[i]
            
            # 영역을 벗어나지 않을 때
            if 0 <= move_r <= N-1 and 0 <= move_c <= M-1:
                
                if graph[move_r][move_c] == '*':
                    
                    # 방문처리
                    graph[move_r][move_c] = num
                    
                    # 큐에 추가
                    queue.append([move_r, move_c])
  • 프림 알고리즘을 활용하여 최소 스패닝 트리를 구성한다.
def prim(start):
    
    # 전체비용
    total = 0
    
    # 간선 개수
    edge = 0
    
    # 우선순위 큐
    q = []
    
    # MST
    mst = set()
    
    # 시작노드 처리
    mst.add(start)
    
    # 시작노드와 붙어있는 간선을 후보군에 추가 (비용, 이웃노드)
    for neighbor in island[start]:
        
        heappush(q, (neighbor[1], neighbor[0]))
        
    # 큐가 빌 때까지 반복
    while q:
        
        cost, cur = heappop(q)
        
        # MST에 이미 포함된 노드이면 skip
        if cur in mst:
            
            continue
        
        else:
            
            # 전체비용 누적
            total += cost
            
            # 간선개수 누적
            edge += 1
            
            # mst에 추가
            mst.add(cur)
            
            # 간선이 섬의개수-1개이면 종료
            if edge == island_cnt-2:
                
                break
            
            # 현재노드와 이어진 간선을 후보군에 추가
            for neighbor in island[cur]:
                
                # mst에 포함된 노드가 아닐 때
                if neighbor[0] not in mst:
                    
                    heappush(q, (neighbor[1], neighbor[0]))

    # 전체비용 출력, 그래프의 간선이 (섬의개수-1)개이면 최소 스패닝 트리가 구성된 것. 
    if edge == island_cnt-1:
        
        print(total)
    
    # 간선이 (섬의개수-1)개이면 최소 스패닝 트리가 구성되지 않은 것
    else:
        
        print(-1)
  • 구분된 섬 사이의 거리를 모두 구해서 그래프 관계에 추가한다. 이 때 특정 두 섬 사이의 동일한 거리를 가지는 간선은 중복 제거를 통해 1개만 유지시킨다.
# 4방향
di = [-1, 1, 0, 0]
dj = [0, 0, -1, 1]

# 섬 사이의 거리를 모두 구해서 그래프 관계에 추가한다.
for i in range(N):
    
    for j in range(M):
        
        # 섬일 때
        if graph[i][j] > 0:
            
            # 4방향 조사
            for k in range(4):
                
                move_i, move_j = i + di[k], j + dj[k]
                
                dist = 0
                
                # 다른 섬과의 길이를 구하기 위해 반복
                while True:
                    
                    # 주어진 범위를 벗어나지 않을 때
                    if 0 <= move_i <= N-1 and 0 <= move_j <= M-1:
                        
                        # 바다일 때 거리를 1 추가
                        if graph[move_i][move_j] == 0:
                            
                            dist += 1
                            
                            # 한칸 이동
                            move_i += di[k]
                            move_j += dj[k]
                        
                        # 같은 섬일 때는 skip
                        elif graph[move_i][move_j] == graph[i][j]:
                            
                            break
                        
                        # 다른 섬을 만났을 때
                        elif graph[move_i][move_j] != graph[i][j]:
                            
                            # 거리가 1이면 break, 필요없는 간선
                            if dist == 1:
                                
                                break
                            
                            else:
                                
                                # 양방향 간선을 추가
                                island[graph[i][j]].append((graph[move_i][move_j], dist))
                                island[graph[move_i][move_j]].append((graph[i][j], dist))
                            
                            break
                            
                    # 영역을 벗어나면 
                    else:
                    
                        break

# 간선 중복 제거
for i in range(1, island_cnt):

    island[i] = list(set(island[i]))

전체 코드

import sys
from heapq import heappush, heappop
from collections import deque

# 섬을 구분하는 BFS 알고리즘
def BFS(start, num):
    
    # 4방향
    dr = [-1, 1, 0, 0]
    dc = [0, 0, -1, 1]
    
    # 시작지점
    start_r, start_c = start
    
    # 시작 노드 방문 처리
    graph[start_r][start_c] = num
    
    # 큐
    queue = deque([start])
    
    # 큐가 빌 때까지 반복
    while queue:
        
        cur_r, cur_c = queue.popleft()
        
        for i in range(4):
            
            move_r, move_c = cur_r + dr[i], cur_c + dc[i]
            
            # 영역을 벗어나지 않을 때
            if 0 <= move_r <= N-1 and 0 <= move_c <= M-1:
                
                if graph[move_r][move_c] == '*':
                    
                    # 방문처리
                    graph[move_r][move_c] = num
                    
                    # 큐에 추가
                    queue.append([move_r, move_c])
                    
def prim(start):
    
    # 전체비용
    total = 0
    
    # 간선 개수
    edge = 0
    
    # 우선순위 큐
    q = []
    
    # MST
    mst = set()
    
    # 시작노드 처리
    mst.add(start)
    
    # 시작노드와 붙어있는 간선을 후보군에 추가 (비용, 이웃노드)
    for neighbor in island[start]:
        
        heappush(q, (neighbor[1], neighbor[0]))
        
    # 큐가 빌 때까지 반복
    while q:
        
        cost, cur = heappop(q)
        
        # MST에 이미 포함된 노드이면 skip
        if cur in mst:
            
            continue
        
        else:
            
            # 전체비용 누적
            total += cost
            
            # 간선개수 누적
            edge += 1
            
            # mst에 추가
            mst.add(cur)
            
            # 간선이 섬의개수-1개이면 종료
            if edge == island_cnt-1:
                
                break
            
            # 현재노드와 이어진 간선을 후보군에 추가
            for neighbor in island[cur]:
                
                # mst에 포함된 노드가 아닐 때
                if neighbor[0] not in mst:
                    
                    heappush(q, (neighbor[1], neighbor[0]))

    # 전체비용 출력
    if edge == island_cnt-1:
        
        print(total)
    
    else:
        
        print(-1)
    
# N, M이 주어진다.
N, M = map(int, sys.stdin.readline().split())

# 지도
graph = []

# N개의 줄에 지도의 정보가 주어진다. 0은 바다, 1은 땅
for _ in range(N):
    
    graph.append(list(map(int, sys.stdin.readline().split())))

# 섬을 '*'로 표현
for i in range(N):
    
    for j in range(M):
        
        if graph[i][j] == 1:
            
            graph[i][j] = '*'
            
# 그래프를 순회하며 섬을 구분
island_cnt = 1

for i in range(N):
    
    for j in range(M):
        
        # 섬을 발견하면 BFS 실행
        if graph[i][j] == '*':
            
            BFS([i,j], island_cnt)
            
            # 섬 번호 증가
            island_cnt += 1

# 섬 생성 (1번 ~ N번)
island = [[] for _ in range(island_cnt)]

# 섬 사이의 거리를 모두 구해서 그래프 관계에 추가한다.

# 4방향
di = [-1, 1, 0, 0]
dj = [0, 0, -1, 1]

for i in range(N):
    
    for j in range(M):
        
        # 섬일 때
        if graph[i][j] > 0:
            
            for k in range(4):
                
                move_i, move_j = i + di[k], j + dj[k]
                
                dist = 0
                
                # 다른 섬과의 길이를 구하기 위해 반복
                while True:
                    
                    if 0 <= move_i <= N-1 and 0 <= move_j <= M-1:
                        
                        # 바다일 때 거리를 1 추가
                        if graph[move_i][move_j] == 0:
                            
                            dist += 1
                            
                            # 한칸 이동
                            move_i += di[k]
                            move_j += dj[k]
                        
                        # 같은 섬일 때는 skip
                        elif graph[move_i][move_j] == graph[i][j]:
                            
                            break
                        
                        # 다른 섬을 만났을 때
                        elif graph[move_i][move_j] != graph[i][j]:
                            
                            # 거리가 1이면 break
                            if dist == 1:
                                
                                break
                            
                            else:
                                
                                # 간선을 추가
                                island[graph[i][j]].append((graph[move_i][move_j], dist))
                                island[graph[move_i][move_j]].append((graph[i][j], dist))
                            
                            break
                            
                    # 영역을 벗어나면 
                    else:
                    
                        break

# 간선 중복 제거
for i in range(1, island_cnt):

    island[i] = list(set(island[i]))

island_cnt -= 1

# 프림 알고리즘 실행
prim(1)

0개의 댓글