n x n의 크기를 가지는 배열이 주어진다. 각 인덱스는 3개의 값을 가진다.
0: 빈 칸
1: 집
2: 치킨 집
한 집에서 여러개의 치킨집까지의 거리 중 가장짧은거리를 치킨거리라고 한다.
각 집의 치킨거리를 합산했을 때 도시치킨거리라 한다.
m개의 치킨집만 남기고 나머지 치킨집을 폐업시켰을 때 도시치킨거리가 최솟값을 가질 때 그 값을 구해라.
ex)
# input
5 3
0 0 1 0 0
0 0 2 0 1
0 1 2 0 0
0 0 1 0 0
0 0 0 0 2
처음에는 itertool 라이브러리를 사용해서 m개의 치킨집만 선택한 모든 경우에 대해 구했다.
import sys
import itertools
input = sys.stdin.readline
n, m = map(int, input().split())
cityMap = [list(map(int, input().split())) for _ in range(n)]
homePosList = []
chickenPosList = []
answer = 1e9
for i in range(n):
for j in range(n):
# 집 좌표 기록
if cityMap[i][j] == 1:
homePosList.append((i, j))
# 치킨집 좌표 기록
elif cityMap[i][j] == 2:
chickenPosList.append((i, j))
# m개의 치킨집 선택
for selectedChickenPosGroup in itertools.combinations(chickenPosList, m):
chickenDist = 0
# 각 집마다 치킨거리 기록
homeChickenDistDict = {str(i): 1e9 for i in range(len(homePosList))}
for chickenPos in selectedChickenPosGroup:
# 한 집에서 여러개의 치킨집까지 거리 중 가장 짧은것만 기록
for i in range(len(homePosList)):
homeChickenDistDict[str(i)] = min(homeChickenDistDict[str(i)], abs(chickenPos[0] - homePosList[i][0]) + abs(chickenPos[1] - homePosList[i][1]))
# 도시의 치킨거리 계산
for i in range(len(homePosList)):
chickenDist += homeChickenDistDict[str(i)]
if chickenDist >= answer:
break
# 도시의 최소 치킨거리
answer = min(answer, chickenDist)
print(answer)
통과는 하였지만 1064ms로 꽤 오래걸렸고, 다른 사람들의 코드와 비교하여 오래걸렸다.
모든 경우에 대해 매번 거리를 계산해서 사용하기 때문에 오래걸린다고 생각했다.
방법을 바꾸어 집에서 모든 치킨집까지의 거리를 한 번만 미리 계산해두고, 선택된 집들의 인덱스를 기억했다가 인덱스 접근을 통해 거리계산없이 바로 비교만 할 수 있도록 코드를 바꿔보았다.
import sys
input = sys.stdin.readline
def selectChickenPos(index, cnt):
global answer
# m개의 치킨집 선택이 끝났을 때
if cnt == m:
total = 0
# 선택된 치킨집 index 번호를 통해 바로 거리비교 후 가장짧은거리 기록
for homeChickenDistGroup in homeChickenDistGroupList:
minDist = 1e9
for i in selectedChickenIndexList:
minDist = min(minDist, homeChickenDistGroup[i])
# 도시치킨거리
total += minDist
if total >= answer:
break
# 최소도시치킨거리
answer = min(answer, total)
# 몇 번 index 치킨집이 선택되었는지 기록
# 각 케이스가 끝나면 다음 케이스를 위한 pop
for i in range(index, len(chickenPosList)):
if i not in selectedChickenIndexList:
selectedChickenIndexList.append(i)
selectChickenPos(i + 1, cnt + 1)
selectedChickenIndexList.pop()
n, m = map(int, input().split())
cityMap = [list(map(int, input().split())) for _ in range(n)]
chickenPosList = []
homeChickenDistGroupList = []
answer = 1e9
for i in range(n):
for j in range(n):
# 치킨집 좌표 기록
if cityMap[i][j] == 2:
chickenPosList.append((i, j))
selectedChickenIndexList = []
# 한 집에서 모든 치킨집까지의 거리를 순차적으로 기록
for i in range(n):
for j in range(n):
if cityMap[i][j] == 1:
homeChickenDistGroupList.append([])
for chickenPos in chickenPosList:
homeChickenDist = abs(i - chickenPos[0]) + abs(j - chickenPos[1])
homeChickenDistGroupList[-1].append(homeChickenDist)
selectChickenPos(0, 0)
print(answer)
시간이 224ms로 단축되었다.