[BOJ] 다리 만들기2- union-find, 크루스칼

김가영·2021년 3월 4일
0

Algorithm

목록 보기
69/78
post-thumbnail

17472 다리 만들기 2 첫 골드 3

import sys
from collections import deque
input = sys.stdin.readline

height, width = list(map(int, input().split()))
bridges = []
for _ in range(height):
    bridges.append(list(map(int, input().split())))

def findConnected(y,x): # y,x 에 연결된 모든 땅을 찾기
    answer = set()
    answer.add((y,x))
    v = deque([(y,x)])
    bridges[y][x] = 0
    while v:
        a,b = v.popleft()
        for ny, nx in [(a+1,b),(a,b+1),(a-1,b),(a,b-1)]:
            if 0 <= ny < height and 0 <= nx < width and bridges[ny][nx] == 1:
                answer.add((ny,nx))
                v.append((ny,nx))
                bridges[ny][nx] = 0
    return answer


land = {}
all_land = set()
nland = 0
for ny in range(height):
    for nx in range(width):
        if bridges[ny][nx] == 1:
            land[nland] = findConnected(ny,nx)
            all_land = all_land | land[nland]
            nland += 1

def findBridge(a,b): # a번째 섬과 b번째 섬을 연결하는 bridge 길이 찾기
    land_a = land[a]
    land_b = land[b]
    short = 20

    ay = {j for j,i in land_a}
    by = {j for j,i in land_b}
    for dy in ay & by:
        ax = {i for j,i in land_a if j == dy}
        bx = {i for j,i in land_b if j == dy}

        for x in ax:
            right = x
            while right + 1 < width and (dy, right + 1) not in all_land:
                right += 1
            if right + 1 in bx:
                s = right - x
                short = min(s, short) if s > 1 else short
            left = x
            while left - 1 >= 0 and (dy, left - 1) not in all_land:
                left -= 1
            if left - 1 in bx:
                s = x - left
                short = min(s, short) if s > 1 else short

    ax = {i for j,i in land_a}
    bx = {i for j,i in land_b}
    for dx in ax & bx:
        ay = {j for j,i in land_a if i == dx}
        by = {j for j,i in land_b if i == dx}

        for y in ay:
            right = y
            while right + 1 < height and (right + 1, dx) not in all_land:
                right += 1
            if right + 1 in by:
                s = right - y
                short = min(s, short) if s > 1 else short
            left = y
            while left - 1 >= 0 and (left - 1, dx) not in all_land:
                left -= 1
            if left - 1 in by:
                s = y - left
                short = min(s, short) if s > 1 else short

    return short if short != 20 else -1
brid = []

for i in range(len(land)):
    for j in range(i + 1, len(land)):
        x = findBridge(i,j)
        if x != -1:
            brid.append((i,j,x))

brid.sort(key = lambda x : x[2])
# union -find
root = [i for i in range(nland)]

def union(a,b):
    a = find(a)
    b = find(b)
    if a < b:
        root[b] = a
    else:
        root[a] = b

def find(a):
    if root[a] == a:
        return a
    r = find(root[a])
    root[a] = r
    return root[a]

answer = 0
for b in brid:
    start, end, d = b
    if find(start) != find(end):
        answer += d
        union(start, end)
# 모두 연결됐는 지 확인하기
connected = True
for r in root:
    if find(r) != 0:
        connected = False
        break
# print(brid)
print(answer) if connected else print(-1)

1. 섬 찾기

land = {}
all_land = set()
nland = 0
for ny in range(height):
    for nx in range(width):
        if bridges[ny][nx] == 1:
            land[nland] = findConnected(ny,nx)
            all_land = all_land | land[nland]
            nland += 1

findConnected(y,x) : y,x 와 연결된 1 들을 모두 찾는 함수
land : 각 섬들의 위치 배열
all_land : 모든 섬들의 위치 set
nland : 섬 개수

2. 다리 찾기

brid = []

for i in range(len(land)):
    for j in range(i + 1, len(land)):
        x = findBridge(i,j)
        if x != -1:
            brid.append((i,j,x))

brid.sort(key = lambda x : x[2])

findBridge : 각 섬들을 연결하는 다리를 찾는다. 겹치는 부분은 생각하지 않으므로 길이만 return 한다.
모든 bridge 는 길이로 sort 한다.

3. 연결하기

크루스칼 알고리즘을 이용했다. 모든 간선을 비용(다리의 길이)을 기준으로 정렬하고, 비용이 작은 간선부터 양 끝의 두 정점을 비교한다.
두 정점의 최상위 정점을 확인하고, 서로 다를 경우 두 정점을 연결한다(중복 연결은 비용 손실) 사이클이 발생하는 지의 여부는 Union-Find 알고리즘을 이용


answer = 0
for b in brid:
    start, end, d = b
    if find(start) != find(end):
        answer += d
        union(start, end)
# 모두 연결됐는 지 확인하기
connected = True
for r in root:
    if find(r) != 0:
        connected = False
        break
# print(brid)
print(answer) if connected else print(-1)

모든 섬들의 최상위 정점이 0인지 확인, 만약 0이 아니라면 모든 섬이 연결된 것이 아니므로 -1 을 출력한다.

profile
개발블로그

0개의 댓글