https://www.acmicpc.net/problem/17472
MST 응용 문제로 BFS, DFS만 섞어주면 된다.
섬을 찾아 구분해주고 다리를 모두 만들어보면 된다. N, M이 작기 때문에 시간은 여유롭다.
다리를 만들 때는 섬에 속한 모든 칸 (i, j)에서 상하좌우 네 방향으로 바다를 따라 곧게 다리를 이어 본다.
다리가 다른 섬에 도달하고 길이가 2 이상이면 간선으로 저장해 두었다가, 이 간선들로 MST를 만들어 답을 낸다.
모든 섬을 연결하지 못할 수도 있으므로, MST에 포함된 섬의 수가 전체 섬의 수와 같은지 확인하고 다르면 -1을 출력한다.
코드에서 landMark는 각 섬의 고유번호라고 생각하면 된다.
입력에서 0과 1을 썼으므로 편의상 2부터 숫자를 매겼다.
Java는 프림 알고리즘, Python은 크루스칼 알고리즘으로 해결했다.
import java.io.*;
import java.util.*;
/** Grid coordinate used by the BFS queue: row {@code y}, column {@code x}. */
class Point {
    final int y;
    final int x;

    public Point(int row, int col) {
        y = row;
        x = col;
    }
}
/**
 * A weighted edge target: {@code dest} is the island reached and {@code cost} is the bridge
 * length. Instances are ordered by ascending cost so a PriorityQueue yields the cheapest edge first.
 */
class Pair implements Comparable<Pair> {
    final int dest;
    final int cost;

    public Pair(int dest, int cost) {
        this.dest = dest;
        this.cost = cost;
    }

    /**
     * Compares by cost only.
     *
     * <p>Uses {@link Integer#compare} rather than {@code cost - o.cost}: the subtraction overflows
     * when the costs are far apart (e.g. MIN_VALUE vs 1) and then reports the wrong order.
     */
    @Override
    public int compareTo(Pair o) {
        return Integer.compare(cost, o.cost);
    }
}
/**
 * BOJ 17472: connect every island with straight bridges of length at least 2, minimising the
 * total bridge length.
 *
 * <p>Islands are relabelled by BFS, every straight bridge between two different islands is
 * collected as a weighted edge, and a lazy Prim's algorithm builds the minimum spanning tree.
 */
public class Main {
    /** First island id: 0 (sea) and 1 (unlabelled land) are already used by the input. */
    private static final int INIT_LANDMARK = 2;
    /** (dy[d], dx[d]) is the unit step for direction d. */
    private static final int[] dx = {0, 0, 1, -1};
    private static final int[] dy = {1, -1, 0, 0};
    /** The map; findLand overwrites each land cell with its island id. */
    private static int[][] graph;
    /** line.get(id) lists the bridges leaving island id; each bridge is stored in both directions. */
    private static List<List<Pair>> line;
    private static int N;
    private static int M;

    /** Reads "N M" and then N rows of M integers from stdin, and prints the answer. */
    public static void main(String[] args) throws IOException {
        Main main = new Main();
        BufferedReader br = main.getReader();
        int[] size = main.getInt(br);
        N = size[0];
        M = size[1];
        graph = new int[N][M];
        for (int i = 0; i < N; i++) {
            // Only the first M values of each row are used.
            System.arraycopy(main.getInt(br), 0, graph[i], 0, M);
        }
        System.out.println(main.solve());
    }

    /** Returns the minimum total bridge length connecting every island, or -1 if impossible. */
    private int solve() {
        int landMark = labelIslands();
        line = new ArrayList<>(landMark);
        for (int id = 0; id < landMark; id++) {
            line.add(new ArrayList<>());
        }
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                if (graph[i][j] == 0) {
                    continue;
                }
                for (int dir = 0; dir < 4; dir++) {
                    findBridge(i, j, dir, graph[i][j], 0);
                }
            }
        }
        return minimumSpanningCost(landMark);
    }

    /**
     * Relabels every island with its own id, counting up from INIT_LANDMARK.
     *
     * @return one past the largest id assigned (INIT_LANDMARK when there is no land)
     */
    private int labelIslands() {
        int landMark = INIT_LANDMARK;
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                if (graph[i][j] == 1) {
                    findLand(i, j, landMark);
                    landMark++;
                }
            }
        }
        return landMark;
    }

    /**
     * Lazy Prim's algorithm starting from island INIT_LANDMARK.
     *
     * @param landMark one past the largest island id
     * @return the spanning-tree length, 0 when there are no islands, or -1 when some island
     *     cannot be reached by any bridge
     */
    private int minimumSpanningCost(int landMark) {
        int islandCount = landMark - INIT_LANDMARK;
        if (islandCount == 0) {
            return 0;
        }
        boolean[] visited = new boolean[landMark];
        PriorityQueue<Pair> pq = new PriorityQueue<>();
        visited[INIT_LANDMARK] = true;
        int connected = 1;
        int total = 0;
        pq.addAll(line.get(INIT_LANDMARK));
        while (connected < islandCount && !pq.isEmpty()) {
            Pair edge = pq.poll();
            if (visited[edge.dest]) {
                continue; // stale entry: that island is already in the tree
            }
            visited[edge.dest] = true;
            connected++;
            total += edge.cost;
            pq.addAll(line.get(edge.dest));
        }
        return connected == islandCount ? total : -1;
    }

    /** BFS flood fill that relabels the island containing (initY, initX) from 1 to landMark. */
    private void findLand(int initY, int initX, int landMark) {
        Queue<Point> queue = new ArrayDeque<>();
        queue.add(new Point(initY, initX));
        graph[initY][initX] = landMark;
        while (!queue.isEmpty()) {
            Point poll = queue.poll();
            for (int dir = 0; dir < 4; dir++) {
                int ny = dy[dir] + poll.y;
                int nx = dx[dir] + poll.x;
                if (isOutofRange(ny, nx) || graph[ny][nx] != 1) {
                    continue;
                }
                // Mark on enqueue so each cell enters the queue at most once.
                graph[ny][nx] = landMark;
                queue.add(new Point(ny, nx));
            }
        }
    }

    /**
     * Walks from (y, x) in direction dir across sea cells. If the walk stops on a different island
     * after crossing at least two sea cells, that bridge is recorded for both islands.
     *
     * @param landMark id of the island the walk starts from
     * @param distance sea cells already crossed before (y, x); callers pass 0
     */
    private void findBridge(int y, int x, int dir, int landMark, int distance) {
        int ny = y + dy[dir];
        int nx = x + dx[dir];
        while (!isOutofRange(ny, nx) && graph[ny][nx] == 0) {
            distance++;
            ny += dy[dir];
            nx += dx[dir];
        }
        if (isOutofRange(ny, nx)) {
            return; // ran off the map
        }
        int otherLandMark = graph[ny][nx];
        if (otherLandMark == landMark || distance < 2) {
            return; // back onto the same island, or bridge too short
        }
        line.get(landMark).add(new Pair(otherLandMark, distance));
        line.get(otherLandMark).add(new Pair(landMark, distance));
    }

    /** True when (y, x) lies outside the N-by-M map. */
    private boolean isOutofRange(int y, int x) {
        return y >= N || y < 0 || x >= M || x < 0;
    }

    /** Buffered reader over standard input. */
    private BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(System.in));
    }

    /**
     * Reads one line and parses its whitespace-separated integers.
     *
     * <p>Splitting the trimmed line on {@code \s+} tolerates repeated or trailing blanks.
     * Splitting on a single {@code " "} would yield empty tokens and a NumberFormatException.
     *
     * @throws EOFException if the input ends before the expected line
     */
    private int[] getInt(BufferedReader br) throws IOException {
        String raw = br.readLine();
        if (raw == null) {
            throw new EOFException("unexpected end of input");
        }
        String trimmed = raw.trim();
        if (trimmed.isEmpty()) {
            return new int[0];
        }
        return Arrays.stream(trimmed.split("\\s+")).mapToInt(Integer::parseInt).toArray();
    }
}
import sys
r=sys.stdin.readline
from collections import deque
def isOut(y, x):
    """Return True when cell (y, x) lies outside the n-by-m grid (module globals n, m)."""
    return not (0 <= y < n and 0 <= x < m)
# Direction steps: (dy[k], dx[k]) for k = 0..3.
dx = [0, 0, 1, -1]
dy = [1, -1, 0, 0]
# Grid size first, then n rows of integers.
n, m = map(int, r().split())
graph = [list(map(int, r().split())) for _ in range(n)]
# BFS
def findLand(a, b, c):
    """Flood-fill the island containing (a, b), relabelling its cells from 1 to c.

    Breadth-first over the four neighbours; reads and mutates the module globals
    graph, dy, dx and calls isOut.
    """
    graph[a][b] = c
    frontier = deque()
    frontier.append((a, b))
    while frontier:
        cy, cx = frontier.popleft()
        for step_y, step_x in zip(dy, dx):
            ny, nx = cy + step_y, cx + step_x
            if not isOut(ny, nx) and graph[ny][nx] == 1:
                # Relabel on enqueue so a cell is never queued twice.
                graph[ny][nx] = c
                frontier.append((ny, nx))
# Give every island its own label, starting at 2 because 0 (sea) and 1
# (unlabelled land) are taken by the input encoding.
lands = 2
for row in range(n):
    for col in range(m):
        if graph[row][col] == 1:
            findLand(row, col, lands)
            lands += 1
# Candidate bridges, stored as (island_a, island_b, length) with island label l
# mapped to index l-2.
line = []
def dfs(y, x, k, l, w):
    """Walk straight from (y, x) in direction k over sea, starting from island l.

    w counts the sea cells crossed so far. If the walk lands on a different island
    after at least two sea cells, the bridge is appended to ``line``. Leaving the
    grid or returning to island l records nothing.
    """
    while True:
        y += dy[k]
        x += dx[k]
        if isOut(y, x) or graph[y][x] == l:
            return
        if graph[y][x] != 0:
            break
        w += 1
    if w >= 2:
        line.append((l - 2, graph[y][x] - 2, w))
# Launch a straight bridge walk from every land cell in each of the four directions.
for row in range(n):
    for col in range(m):
        label = graph[row][col]
        if not label:
            continue
        for k in range(4):
            dfs(row, col, k, label, 0)
# MST (Kruskal) over a union-find structure.
# Island label c (2 .. lands-1) maps to index c-2, so there are lands-2 islands.
parent = list(range(lands - 2))
def find(x):
    """Return the root of x's set, re-pointing every node on the path directly at it."""
    root = x
    while parent[root] != root:
        root = parent[root]
    while x != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root
def merge(x, y):
    """Union the sets containing x and y; the smaller root becomes the parent."""
    rx, ry = find(x), find(y)
    parent[max(rx, ry)] = min(rx, ry)
def isUnion(x, y):
    """Return True when x and y already belong to the same set."""
    return find(x) == find(y)
# Kruskal: take bridges shortest-first and keep each one that joins two
# components that are not yet connected.
line.sort(key=lambda x: x[2])
ans = 0
for a, b, c in line:
    if isUnion(a, b):
        continue
    merge(a, b)
    ans += c
# All islands are connected exactly when they share a single union-find root.
# Collecting the roots in a set also copes with zero islands (indexing
# parent[0] would raise IndexError). It also avoids exit(), which the site
# module provides only for interactive use.
roots = {find(v) for v in range(lands - 2)}
print(ans if len(roots) <= 1 else -1)