익숙함에 속기 쉬운 문제다. 문제에서 주어진 입력 형태나 문제 내용을 봤을 때 bfs를 쉽게 떠올릴 수 있다. bfs로 이 문제를 구현하려면 생각보다 코드가 길어지는 것을 볼 수 있다.
#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>
#define MAX 1000000000
std::pair<int, int> direction[4] = { {0, 1}, {1, 0}, {-1, 0}, {0, -1} };
int N, M;
int map[52][52];
int distMap[52][52];
int dist[100][26] = { 0, }; // [HouseNum][ChickenNum]
std::vector<std::pair<int, int>> housePos;
std::vector<std::pair<int, int>> chickenPos;
std::queue<std::pair<int, int>> q;
void init_map() {
for (int row = 0; row <= N + 1; row++) {
for (int col = 0; col <= N + 1; col++) {
map[row][col] = -1;
}
}
}
void reset_dist_map() {
for (int row = 1; row <= N; row++) {
for (int col = 1; col <= N; col++) {
distMap[row][col] = 0;
}
}
}
int get_chicken_num(int row, int col) {
for (int index = 0; index < chickenPos.size(); index++) {
if (chickenPos[index].first == row && chickenPos[index].second == col) {
return index;
}
}
return -1;
}
void bfs(int count) {
while (!q.empty()) {
int curRow = q.front().first;
int curCol = q.front().second;
q.pop();
for (auto dir : direction) {
int nextRow = curRow + dir.first;
int nextCol = curCol + dir.second;
if (map[nextRow][nextCol] == -1 || distMap[nextRow][nextCol]) {
continue;
}
if (map[nextRow][nextCol] == 2) {
dist[count][get_chicken_num(nextRow, nextCol)] = distMap[curRow][curCol];
}
distMap[nextRow][nextCol] = distMap[curRow][curCol] + 1;
q.push({ nextRow, nextCol });
}
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(NULL);
std::cout.tie(NULL);
std::cin >> N >> M;
init_map();
for (int row = 1; row <= N; row++) {
for (int col = 1; col <= N; col++) {
std::cin >> map[row][col];
if (map[row][col] == 1) {
housePos.push_back({ row, col });
}
else if (map[row][col] == 2) {
chickenPos.push_back({ row, col });
}
}
}
for (int count = 0; count < housePos.size(); count++) {
q.push({ housePos[count]});
distMap[housePos[count].first][housePos[count].second] = 1;
bfs(count);
reset_dist_map();
}
int totalMin = MAX;
std::vector<int> combiTemp(chickenPos.size(), 0);
for (int index = chickenPos.size() - 1; M > 0; index--, M--) {
combiTemp[index] = 1;
}
do {
int totalDist = 0;
for (int houseCnt = 0; houseCnt < housePos.size(); houseCnt++) {
int min = MAX;
for (int chickenCnt = 0; chickenCnt < chickenPos.size(); chickenCnt++) {
if (combiTemp[chickenCnt] == 1 && min > dist[houseCnt][chickenCnt]) {
min = dist[houseCnt][chickenCnt];
}
}
totalDist += min;
}
if (totalMin > totalDist) {
totalMin = totalDist;
}
} while (std::next_permutation(combiTemp.begin(), combiTemp.end()));
std::cout << totalMin << std::endl;
}
하지만 bfs를 활용하지 않아도 된다. 어차피 입력에서 집과 가게를 구별할 수 있고, bfs를 돌리게 된다면 각 집의 위치에 따라 bfs를 새롭게 돌려야 하기 때문에 모든 집, 치킨 가게의 좌표하는 저장할 배열과 각 지점까지의 거리를 저장할 배열이 별도로 필요하다. 물론 초기화하는 비용은 덤이다.
그럴 필요 없이 집과 치킨 가게의 좌표를 저장한 배열만 있어도 각 원소의 좌표 값만으로 거리를 쉽게 구할 수 있다. 이렇게 간단한 풀이가 있음에도 익숙한 유형의 문제라 생각하여 무작정 풀다가 낭패를 보는 타입의 문제이다.
아래는 bfs 없이 구현한 코드의 전문이다.
#include <iostream>
#include <algorithm>
#include <vector>
#include <math.h>
#define MAX 1000000000
int N, M;
std::pair<int, int> housePos[100];
std::pair<int, int> chickenPos[13];
int dist[100][13];
int houseNum = 0;
int chickenNum = 0;
void get_dist() {
for (int houseCnt = 0; houseCnt < houseNum; houseCnt++) {
for (int chickenCnt = 0; chickenCnt < chickenNum; chickenCnt++) {
dist[houseCnt][chickenCnt] = std::abs(housePos[houseCnt].first - chickenPos[chickenCnt].first) + std::abs(housePos[houseCnt].second - chickenPos[chickenCnt].second);
}
}
}
int combination() {
std::vector<int> combiTemp(chickenNum, 0);
for (int index = chickenNum - 1; M > 0; index--, M--) {
combiTemp[index] = 1;
}
int distSumMin = MAX;
do {
int distSum = 0;
for (int houseCnt = 0; houseCnt < houseNum; houseCnt++) {
int min = MAX;
for (int chickenCnt = 0; chickenCnt < chickenNum; chickenCnt++) {
if (combiTemp[chickenCnt] && min > dist[houseCnt][chickenCnt]) {
min = dist[houseCnt][chickenCnt];
}
}
distSum += min;
}
if (distSumMin > distSum) {
distSumMin = distSum;
}
} while (std::next_permutation(combiTemp.begin(), combiTemp.end()));
return distSumMin;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(NULL);
std::cout.tie(NULL);
std::cin >> N >> M;
int node;
for (int row = 0; row < N; row++) {
for (int col = 0; col < N; col++) {
std::cin >> node;
if (node == 1) {
housePos[houseNum++] = {row, col};
}
else if (node == 2) {
chickenPos[chickenNum++] = {row, col};
}
}
}
get_dist();
std::cout << combination() << std::endl;
}