오늘은 이번 문제에 대해 생각했던 과정에 대해 말씀드리겠습니다.
서론
처음 저는 이 문제를 보고 다익스트라 문제라고 생각하지 못했습니다.
다익스트라 알고리즘으로 구현한다고 했을 때 모든 정점에 대해 최단 시간을 가져오고, 그 배열 내에서 시작점과 도착지에 대한 시간만 가져와 오고 가는 데 걸리는 시간을 구하고 그 중 최솟값을 구한다.
어휴 복잡해!!
하면서 저는 간단하게 플로이드 워셜을 썼습니다.
플로이드 워셜 알고리즘을 활용해 구현하다.
플로이드 워셜 알고리즘을 활용해 구현하는 과정은 쉬웠습니다.
private fun floydWarshall(): Array<IntArray> {
val time = Array(graph.size) { IntArray(graph.size) { INF } }
for(i in 1 until graph.size) {
time[i][i] = 0
}
for(i in 1 until graph.size) {
for (neighbor in graph[i]) {
time[i][neighbor.node] = neighbor.time
}
}
for(k in 1 until graph.size) {
for(i in 1 until graph.size) {
for(j in 1 until graph.size) {
time[i][j] = min(time[i][j], time[i][k] + time[k][j])
}
}
}
return time
}
위와 같이 각 정점에 대해 최소 시간을 구해주고
var result = Int.MIN_VALUE
for(i in 1 until graph.size) {
result = max(result, time[i][meet] + time[meet][i])
}
return result
서로 간의 거리 중 가장 오래 걸리는 시간에 대해 최댓값을 구해줬습니다.
그리곤 맞았죠.
하지만 여기서 저는 의문이 들었습니다.
진짜 저게 풀이가 맞을까 하고요.
그래서 질문 게시판에 가봤더니 다음과 같은 글이 있었습니다.
그래서 저는 아 운이 좋아서 맞은 거구나 하고 다시 다익스트라 알고리즘으로 구현해야겠다고 생각하게 됩니다.
다익스트라 알고리즘으로 구현하다.
다익스트라 알고리즘으로 구현하는 건 생각보다 쉬웠습니다.
왜일까요?
질문 게시판 보면서 답을 봐버렸기 때문이죠..ㅎ
로직은 다음과 같았습니다.
graph = Array(size = n + 1) { arrayListOf() }
reverseGraph = Array(size = n + 1) { arrayListOf() }
repeat(m) {
val (start, end, time) = br.readLine().split(' ').map { it.toInt() }
graph[start].add(Node(end, time))
reverseGraph[end].add(Node(start, time))
}
val time = dijkstra(graph, meet)
val reverseTime = dijkstra(reverseGraph, meet)
var result = Int.MIN_VALUE
for(i in 1 until graph.size) {
result = max(result, time[i] + reverseTime[i])
}
return result
최종 코드
import java.util.PriorityQueue
import kotlin.math.max
import kotlin.math.min
/*
* 백준 1238번. 파티
* https://www.acmicpc.net/problem/1238
*/
data class Node(
val node: Int,
val time: Int
): Comparable<Node> {
override fun compareTo(other: Node): Int {
return time - other.time
}
}
private lateinit var graph: Array<ArrayList<Node>>
private lateinit var reverseGraph: Array<ArrayList<Node>>
private var meet = -1
private const val INF = 1000 * 100 + 1
private fun main() {
initVariable()
val result = getResult()
printResult(result)
}
private fun initVariable() {
val br = System.`in`.bufferedReader()
val (n, m, x) = br.readLine().split(' ').map { it.toInt() }
graph = Array(size = n + 1) { arrayListOf() }
reverseGraph = Array(size = n + 1) { arrayListOf() }
repeat(m) {
val (start, end, time) = br.readLine().split(' ').map { it.toInt() }
graph[start].add(Node(end, time))
reverseGraph[end].add(Node(start, time))
}
meet = x
br.close()
}
private fun getResult(): Int {
val time = dijkstra(graph, meet)
val reverseTime = dijkstra(reverseGraph, meet)
var result = Int.MIN_VALUE
for(i in 1 until graph.size) {
result = max(result, time[i] + reverseTime[i])
}
return result
}
private fun dijkstra(graph: Array<ArrayList<Node>>,start: Int): IntArray {
val time = IntArray(graph.size) { INF }
time[start] = 0
val queue = PriorityQueue<Node>()
queue.add(Node(start, 0))
while (queue.isNotEmpty()) {
val cur = queue.remove()
for(neighbor in graph[cur.node]) {
if(time[neighbor.node] > time[cur.node] + neighbor.time) {
time[neighbor.node] = time[cur.node] + neighbor.time
queue.add(Node(neighbor.node, cur.time + neighbor.time))
}
}
}
return time
}
private fun printResult(result: Int) {
val bw = System.out.bufferedWriter()
bw.write("$result\n")
bw.flush()
bw.close()
}
마치며
시간 차이가 보이시나요? ( 2MlogN vs N^3 )
확실히 어떤 알고리즘을 선택하느냐에 따라 시간 차이가 확실히 난다는 것을 깨닫게 된 문제였던 것 같습니다.
그뿐만 아니라 처음 생각했던 방법으로 구현했더라면 O(NMlogN)이었던 반면 플로이드 워셜 알고리즘은 O(N^3)로 더 컸습니다.
시간 복잡도가 더 큰 알고리즘인 것을 알고 있었음에도 귀찮다는 이유로 사용했던 저 자신에게 반성하게 되었던 문제였던 것 같습니다.