최근 알고리즘 관련 오픈채팅방에 들어갔다.
반 년 넘게 알고리즘을 독학하면서 흥미도 점점 떨어지고 문제도 점점 어려워져서 환기를 주려는 것이 목적이었는데 정말 많은 고수분들이 톡방에 있더라...
톡방에 있는 류호석이라는 고수 분께서 최근에 벽 부수고 이동하기라는 문제를 어떻게 접근해야 하고 어떤 방식으로 풀 수 있는 지 설명하시는 라이브 방송을 해주셔서 한 번 영상을 봐야겠다 싶었다.
옛날에 정말 어렵게 풀기도 했고 그 때 당시는 이해하고 풀었다 쳐도 지금 당장 다시 풀라하면 풀 수 있을 지도 확신하지 못했기 때문이다.
어제 새벽에 잠깐 보다 잘라했는데 홀린 듯이 다 봐버렸다....
정말 많은 것을 배울 수 있어서 리뷰해본다.
오늘 적는 내용은 류호석님의 풀이대로 풀어본 후기(?) 느낌이다.
문제링크
https://www.acmicpc.net/problem/2206
설명
이 문제는 벽을 최대 1개 부수고 목적지까지 갈 수 있는 최단경로의 길이를 구하는 문제이다.
행렬 형식의 맵과 간선의 가중치가 없는 거로 보아 bfs를 활용하는 거 같다.
맵에 있는 벽을 하나씩 제거하고 bfs를 실행하는 방식으로 완전탐색을 실행해볼 때, 각 방식의 시간복잡도는 다음과 같다.
벽을 하나씩 제거 : 벽의 개수가 최대 N x M개이기 때문에 O(N x M)
bfs의 시간복잡도(인접리스트일 때) : O(N x M)
-> O(정점 + 간선) = O(N x M + 4 x N x M)
이 둘을 곱한 O(N^2 x M^2)의 시간복잡도로는 TLE가 발생하기 때문에 완전탐색으로는 풀 수가 없다는 걸 알 수 있다.
그렇다면 다른 방식으로 풀어야하는데 보통 bfs를 실행할 때, 시작점에서 x까지 도달하는 최단 거리를 저장하는 방식으로 푼다.
ex. dist[x] = 시작점에서 x까지 가는 최단거리
이 문제는 여기서 벽을 부수는 개수라는 또 다른 조건(?)이 추가되는데, 여기서 dist[x]는 항상 최단거리라는 하나의 값만을 가지는 것이 포인트라고 강조하셨다.
dist[x] = {최단거리, 부수고 온 벽의 개수} 같은 두 개의 값을 가지면 안된다는 뜻이다.
실제로 bfs를 돌릴 때, 좌표 말고도 부수고 온 벽의 개수 같은 부가적인 정보를 같이 담으면서 실행한 풀이가 많다. 이와 같이 풀이하면 어떻게든 반례를 나타낼 수 있다고 하는데 대부분의 풀이에서는 발생하는 여러 가지 반례들을 처리하기 위해 여러 조건을 생각하며 추가 코드를 작성해주면서 풀이가 조금 복잡해진다.
그렇다면 dist[x]가 최단거리라는 하나의 정보만을 담으면서 부순 벽의 개수까지 포함하려면 어떻게 해야할까?
답은 하나의 차원을 추가하는 것이다.
예를 들어 dist[x][0]는 부수고 온 벽의 개수가 0인 상태로 x까지 도달하는 최단거리를 뜻한다. 이 같은 방식으로는 최단거리라는 하나의 값만을 가지면서 부수고 온 벽의 개수까지 표현이 가능하다.
이 문제뿐만 아니라 다른 문제에서도 이 개념이 정말 중요하다고 강조하셨다.
그렇다면 우리는 인접 행렬로 구현된 맵의 하나의 정점마다 두 가지 정보를 담은 dist[y][x][k] 배열을 통해 이 문제를 해결할 수 있음을 알 수 있다.
그 다음 또 한 가지를 생각해야 하는데 간선을 연결할 때의 과정을 생각해봐야한다.
a좌표에서 b좌표로 갈 때, 만약 b좌표가 0(벽x)이라면, a정점과 b정점은 k = 0과 k = 0, 그리고 k = 1과 k = 1이 연결 될 수 있다.
반면에, b좌표가 1(벽)이라면, a정점과 b정점은 k = 0과 k = 1일 때 연결된다. b좌표가 벽이기 때문에 반드시 벽을 부숴야 도달할 수 있기 때문이다.
이어진 간선을 토대로 bfs를 실행한다면 dist[n][m][0]과 dist[n][m][1] 중 작은 값이 정답이 될 것이다.
문제의 풀이 방식을 이해하고 혼자 코드를 짜보려고 했는데 쉽지 않았다. 아마 dist[x]에 다른 차원을 추가해서 풀어본 적이 처음이고, 1학년 때 배운 구조체를 정말 오랜만에 활용한 풀이라서 그런 거 같다. 사실 굳이 구조체를 활용할 필요는 없는 거 같지만 앞으로 정말 많이 쓰일 거 같아서 일부로 구조체를 활용한 호석님 풀이대로 풀어보려고 했다.
호석님의 코드를 보면 주석이 많았고 코드가 정말 깔끔했던 거 같다.
만약 구조체를 사용하지 않고 tuple같은 걸 사용했다면 코드가 조금 복잡해졌을 거 같다.
그리고 나는 늘 int[]보다는 vector를 자주 활용하면서 문제를 풀었었고 오늘도 그랬는데 같은 풀이방식임에도 불구하고 메모리초과를 받았다. 톡방에 문의하니 vector의 경우 내가 메모리를 a로 잡아도 a 이상의 메모리를 할당하기 때문에 메모리 초과가 발생할 수 있다고 하셨다. 더 깊게 알려면 vector의 내부 구현을 알아야 한다고 하셨는데 이 부분도 찾아봐야 할 거 같다.
또 주석의 중요성도 깨닫게 되었는데 지금까지 난 내 코드에 주석을 거의 달지 않았다. 지금 와서 생각해보면 나중에 내 코드를 다시 봐도 한 번에 이해하지 못하는 상황이 자주 발생할텐데, 이 때 주석조차 없으면 정말 이해하기 힘들 수도 있겠다 싶었다.
강의를 봐서 정말 다행이라고 생각한다. 앞으로는 알고리즘 문제에 접근할 때 지금까지와는 조금이라도 다른 방식으로 접근하여 더 효율적으로 문제를 풀 수 있을 거 같다는 느낌이 든다! 조만간 벽 부수고 이동하기 2와 3도 풀어보고 4도 다시 한 번 풀어봐야겠다.
코드
#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>
using namespace std;
int n, m;
struct Node
{
int y, x, cnt;
};
vector<vector<char>> map(1001, vector<char>(1001));
vector<vector<int>> dir = { {-1, 0},{0,1},{1,0},{0,-1} };
vector<Node> graph[1001][1001][2];
int dist[1001][1001][2];
void input()
{
cin >> n >> m;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
cin >> map[i][j];
}
}
void connect() // 간선 연결 작업
{
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
{
for (int d = 0; d < 4; d++)
{
int ni = i + dir[d][0];
int nj = j + dir[d][1];
if (ni >= 1 && ni <= n && nj >= 1 && nj <= m) // map 범위 안이라면
{
if (map[ni][nj] == '0') // 만약 벽이 아니라면 0과 1노드 각각 연결
{
graph[i][j][0].push_back({ ni,nj,0 });
graph[i][j][1].push_back({ ni,nj,1 });
}
else // 만약 벽이라면 부수고 이동해야하므로
{
graph[i][j][0].push_back({ ni,nj,1 });
}
}
}
}
}
}
void solution()
{
connect();
// dist 초기화
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
{
for (int k = 0; k < 2; k++)
dist[i][j][k] = -1;
}
}
dist[1][1][0] = 1;
queue<Node> q;
q.push({ 1,1,0 }); // 시작점은 항상 0이므로, k = 0 삽입
while (!q.empty())
{
Node cur = q.front();
q.pop();
for (int i = 0; i < graph[cur.y][cur.x][cur.cnt].size(); i++) // 연결된 모든 간선으로 가기
{
Node ncur = graph[cur.y][cur.x][cur.cnt][i];
if (dist[ncur.y][ncur.x][ncur.cnt] == -1) // 아직 방문하지 않았다면
{
dist[ncur.y][ncur.x][ncur.cnt] = dist[cur.y][cur.x][cur.cnt] + 1;
q.push(ncur);
}
}
}
// 정답 출력
if (dist[n][m][0] == -1)
dist[n][m][0] = 1e9;
if (dist[n][m][1] == -1)
dist[n][m][1] = 1e9;
int ans = min(dist[n][m][0], dist[n][m][1]);
ans != 1e9 ? cout << ans : cout << -1;
}
void solve()
{
input();
solution();
}
int main(void)
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cout.tie(NULL);
solve();
return 0;
}