[백준] 17472 - 다리 만들기 2

안우진·2024년 3월 19일
0

백준

목록 보기
15/21

[문제]


https://www.acmicpc.net/problem/17472

[풀이]


MST 응용 문제로 BFS, DFS만 섞어주면 된다.
섬을 찾아 구분해주고 다리를 모두 만들어보면 된다. N, M이 작기 때문에 시간은 여유롭다.
다리를 만들 때 모든 (i, j) 에 대해 상하좌우 모든 방향으로 다리를 이어본다.
다른 섬에 도달하고 길이가 2보다 작지 않다면 저장해두었다가 MST로 만들어 답을 낸다.
모든 섬이 이어지지 않을 수 있기 때문에 visit의 size로 하면 된다.

코드에서 landMark는 각 섬의 고유번호라고 생각하면 된다.
입력에서 0과 1을 썼으므로 편의상 2부터 숫자를 매겼다.

Java는 프림 알고리즘, Python은 크루스칼 알고리즘으로 해결했다.

[코드]


import java.io.*;
import java.util.*;

class Point {
	int y;
	int x;
	public Point(int y, int x) {
		this.y = y;
		this.x = x;
	}
}

class Pair implements Comparable<Pair>{
	int dest;
	int cost;
	public Pair(int dest, int cost) {
		this.dest = dest;
		this.cost = cost;
	}
	public int compareTo(Pair o) {
		return cost - o.cost;
	}	
}

public class Main {
	
	private static final int INIT_LANDMARK = 2;
	private static final int[] dx = {0,0,1,-1};
	private static final int[] dy = {1,-1,0,0};
	private static int[][] graph;
	private static ArrayList<Pair>[] line;
	private static int N;
	private static int M;

	public static void main(String[] args) throws IOException {
		Main main = new Main();
		BufferedReader br = main.getReader();
		int[] temp = main.getInt(br);
		N = temp[0];
		M = temp[1];
		graph = new int[N][M];
		for (int i=0; i<N; i++) {
			temp = main.getInt(br);
			for(int j=0; j<M; j++) {
				graph[i][j] = temp[j];
			}
		}
		System.out.println(main.solve());
	}
	
	private int solve() {
		int landMark = INIT_LANDMARK;
		for(int i=0; i<N; i++) {
			for(int j=0; j<M; j++) {
				if (graph[i][j] != 1) continue;
				findLand(i, j, landMark);
				landMark++;
			}
		}
		
		line = new ArrayList[landMark];
		for (int i=INIT_LANDMARK; i<landMark; i++) {
			line[i] = new ArrayList<Pair>();
		}
		
		for(int i=0; i<N; i++) {
			for(int j=0; j<M; j++) {
				if(graph[i][j] == 0) continue;
				for(int dir=0; dir<4; dir++) {
					findBridge(i, j, dir, graph[i][j], 0);
				}
			}
		}
		
		PriorityQueue<Pair> pq = new PriorityQueue<Pair>();
		HashSet<Integer> visit = new HashSet<Integer>();
		int cur = INIT_LANDMARK;
		int ans = 0;
		for (int i=INIT_LANDMARK; i<landMark; i++) {
			visit.add(cur);
			line[cur].forEach(pq::add);
			while (!pq.isEmpty()) {
				Pair poll = pq.poll();
				if (!visit.contains(poll.dest)) {
					cur = poll.dest;
					ans = ans + poll.cost;
					break;
				}
			}
		}
		return visit.size() == landMark - INIT_LANDMARK ? ans : -1;
	}
	
	private void findLand(int initY, int initX, int landMark) {
		Queue<Point> queue = new LinkedList<Point>();
		queue.add(new Point(initY, initX));
		graph[initY][initX] = landMark;
		while (!queue.isEmpty()) {
			Point poll = queue.poll();
			for (int dir=0; dir<4; dir++) {
				int ny = dy[dir] + poll.y;
				int nx = dx[dir] + poll.x;
				if (isOutofRange(ny, nx)) continue;
				if (graph[ny][nx] != 1) continue;
				graph[ny][nx] = landMark;
				queue.add(new Point(ny, nx));
			}
		}
	}
	
	private void findBridge(int y, int x, int dir, int landMark, int distance) {
		int ny = dy[dir] + y;
		int nx = dx[dir] + x;
		if (isOutofRange(ny, nx)) return;
		int otherLandMark = graph[ny][nx];
		if (otherLandMark == landMark) return;
		if (otherLandMark == 0) {
			findBridge(ny, nx, dir, landMark, distance+1);
			return;
		}
		if (distance < 2) return;
		line[landMark].add(new Pair(otherLandMark, distance));
		line[otherLandMark].add(new Pair(landMark, distance));
	}
	
	private boolean isOutofRange(int y, int x) {
		return y >= N || y < 0 || x >= M || x < 0;
	}

	private BufferedReader getReader() {
		return new BufferedReader(new InputStreamReader(System.in));
	}
	
	private int[] getInt(BufferedReader br) throws IOException {
		return Arrays.stream(br.readLine().split(" ")).mapToInt(Integer::parseInt).toArray();
	}
}

import sys
r=sys.stdin.readline
from collections import deque

def isOut(y,x):
  return y>=n or y<0 or x<0 or x>=m

dx=[0,0,1,-1]
dy=[1,-1,0,0]
n,m=map(int,r().split())
graph=[]
for i in range(n):
  graph.append(list(map(int,r().split())))

# BFS
def findLand(a,b,c):
  q=deque([(a,b)])
  graph[a][b]=c
  while q:
    y,x=q.popleft()
    for i in range(4):
      ny = dy[i] + y
      nx = dx[i] + x
      if isOut(ny,nx): continue
      if graph[ny][nx] != 1: continue
      graph[ny][nx] = c
      q.append((ny,nx))

lands=2
for i in range(n):
  for j in range(m):
    if graph[i][j] != 1: continue
    findLand(i,j,lands)
    lands+=1

# DFS
line=[]
def dfs(y,x,k,l,w):
  ny = y + dy[k]
  nx = x + dx[k]
  if isOut(ny,nx): return
  if graph[ny][nx] == l: return
  if graph[ny][nx] == 0:
    dfs(ny,nx,k,l,w+1)
    return
  if w<2: return
  line.append((l-2,graph[ny][nx]-2,w))

for i in range(n):
  for j in range(m):
    if not graph[i][j]: continue
    for k in range(4):
      dfs(i,j,k,graph[i][j],0)

# MST
# 섬 c-2 개
parent=[i for i in range(lands-2)]

def find(x):
  if parent[x]==x:
    return x
  parent[x]=find(parent[x])
  return parent[x]

def merge(x,y):
  x=find(x)
  y=find(y)
  if x>y: parent[x]=y
  else: parent[y]=x

def isUnion(x,y):
  x=find(x)
  y=find(y)
  if x==y: return True
  return False

line.sort(key=lambda x : x[2])

ans=0
for a,b,c in line:
  if isUnion(a,b): continue
  merge(a,b)
  ans+=c

parent[0]=find(parent[0])
for i in range(1,lands-2):
  if parent[0]!=find(parent[i]):
    print(-1)
    exit()
print(ans)

0개의 댓글