BOJ 1626 두 번째로 작은 스패닝 트리

LONGNEW·2022년 1월 28일
0

BOJ

목록 보기
309/333

https://www.acmicpc.net/problem/1626
시간 2초, 메모리 128MB

input :

  • V E(1 ≤ V ≤ 50,000, 1 ≤ E ≤ 200,000)
  • u v w(0 <= w < 100,000)

output :

  • 두 번째로 작은 스패닝 트리의 값을 출력

  • 스패닝 트리나 두 번째로 작은 스패닝 트리가 존재하지 않는다면 -1을 출력


기본적인 아이디어는 BOJ 15481 그래프와 MST
의 간선들 중 최댓값을 가져오는 것을 응용해야 한다.

가장 큰 차이점으로는 트리가 존재하지 않을 수도 있다는 것이다.

제한

  1. 스패닝 트리의 존재여부
  2. 두 번째로 작은 스패닝 트리의 존재여부

1번의 경우에는 MST의 간선의 개수로 체크할 수 있다.
중요한 것은 2번이다. 이루고 있는 간선들 중 가장 큰 값을 찾고, 새로운 간선을 추가하는데 값이 동일한 경우가 있을 수 있다.
그런 경우에는 2번쨰로 작은 간선을 찾아두고 이 경우도 체크를 해야 한다.

다음 풀이

  1. MST 제작
  2. 연결되어 있는 간선 제외 하고 MST를 찾기

이전과 동일하게 크루스칼 + union-find를 통해 MST 제작.
이거도 결국에는 특정 노드 두개가 이루는 사이클 중, 가장 값이 큰 놈을 찾아야 함.
그렇기에 LCA를 사용해서 더 빠르게 체킹을 함.
그러나, 이 문제에서는 2개의 간선을 저장해야 해서 lca 노드, lca 간선을 저장 하는 두 개의 다른 배열을 가지도록 하는 방법도 좋음.

longest라는 함수를 만들어서, lca를 돌면서 체킹할 때도 이를 사용할 수 있다.
lca 돌 때 마지막에는 해당하는 노드가 찾을 게 2개라서 이걸 2번 체킹하면 된다.

추가

DFS보다 BFS를 사용하는 방안이 메모리를 조금 덜 사용하는 것 같음. 다른 부분들도 코드를 고쳤지만 이 탐색 방법을 바꾼 것이 가장 유효한듯함.
쓰레드가 덜 생겨서 그렇지 않을까 싶음.

import sys
from math import log2
from collections import deque

def find(node):
    if parent[node] != node:
        parent[node] = find(parent[node])
    return parent[node]

def union(a, b):
    parent_a = find(a)
    parent_b = find(b)

    if parent_a > parent_b:
        parent[parent_a] = parent_b
    else:
        parent[parent_b] = parent_a

def bfs():
    depth[0] = 0
    q = deque([0])

    while q:
        node = q.popleft()

        for next_node, cost in graph[node]:
            if depth[next_node] != -1:
                continue

            depth[next_node] = depth[node] + 1
            parent[next_node][0] = node
            two_weight[next_node][0] = [cost, -1]
            q.append(next_node)

def longest(arr1, arr2):
    temp = list(set(arr1 + arr2))
    temp.sort(reverse=True)

    while len(temp) < 2:
        temp.append(-1)

    temp = temp[:2]
    return temp

def set_parent():
    bfs()

    for log in range(1, k):
        for node in range(1, v):
            next_node = parent[node][log - 1]
            parent[node][log] = parent[next_node][log - 1]

            weight1, weight2 = two_weight[node][log - 1], two_weight[next_node][log - 1]
            two_weight[node][log] = longest(weight1, weight2)

def lca(a, b):
    ret = [-1, -1]
    if depth[a] > depth[b]:
        a, b = b, a

    for log in range(k - 1, -1, -1):
        if depth[b] - depth[a] >= (1 << log):
            ret = longest(ret, two_weight[b][log])
            b = parent[b][log]

    if a == b:
        return ret

    for log in range(k - 1, -1, -1):
        if parent[b][log] != parent[a][log]:
            ret = longest(ret, two_weight[a][log])
            ret = longest(ret, two_weight[b][log])
            b = parent[b][log]
            a = parent[a][log]

    ret = longest(ret, two_weight[a][0])
    ret = longest(ret, two_weight[b][0])

    return ret


v, e = map(int, sys.stdin.readline().split())
edge, graph, used = [], [[] for i in range(v)], [0] * e
parent = [i for i in range(v)]

for i in range(e):
    a, b, c = map(int, sys.stdin.readline().split())
    a -= 1
    b -= 1
    edge.append((c, a, b))

edge.sort()
mst, cnt = 0, 0

for i in range(e):
    c, a, b = edge[i]

    if find(a) != find(b):
        union(a, b)
        graph[a].append((b, c))
        graph[b].append((a, c))

        used[i] = 1
        mst += c
        cnt += 1

if cnt != v - 1:
    print(-1)
    exit(0)

ans, k = float("inf"), int(log2(v)) + 1
depth, two_weight = [-1] * v, [[[-1, -1] for _ in range(k)] for _ in range(v)]
parent = [[-1] * k for i in range(v)]

set_parent()
for i in range(e):
    if used[i]:
        continue

    w, u, v = edge[i]
    weight = lca(u, v)

    if weight[0] != w:
        ans = min(ans, mst - weight[0] + w)
    elif weight[1] != w and weight[1] != -1:
        ans = min(ans, mst - weight[1] + w)

if ans == float("inf"):
    print(-1)
    exit(0)

print(ans)

시간을 많이 사용하긴 했지만, 풀어서 기분은 좋다.
드디어 2학년 2하기 알고리즘의 굴레에서 벗어난 느낌이다.

0개의 댓글