우선 문제를 봤을 때, 치킨집(2)과 집(1)사이의 거리를 구해야 답을 도출할 수 있으니, 각각의 좌표를 기록해둘 필요가 있다고 생각이 들었다.
house_coor = [(i,j) for i in range(n) for j in range(n) if city[i][j] == 1 ]
chicken_coor = [(i,j) for i in range(n) for j in range(n) if city[i][j] == 2 ]
그 다음은 뭘 할 수 있을까... 하다가 문득 생각이 들었던건,
치킨집 중에서 m만큼을 고른다는 것. 그리고 이 중에서 각 집까지의 거리(치킨거리)가 최소가 되는 거리를 출력해 내야한다는 것에서 조합을 이용하기로 마음먹었다.
근데 이렇게하면 모든 경우의 수를 다 따져보는 완전탐색이 될텐데 이대로 사용할 수 있을까? 하고 브레이크가 걸렸다.
일단 치킨집의 개수가 M ≤ 치킨집 수 ≤ 13이라고 한다.
그렇다면 최대로 많은 경우의 수는 13C𝚖 인데 이러면 최대 1000단위 정도인 듯하다.
조합을 이용해보자!
## input
import sys
put = sys.stdin.readline
n,m = map(int,put().split())
city = [list(map(int, put().split())) for _ in range(n)]
만약 m=2이고, 선택된 치킨집의 좌표가 (0,1),(4,1)이라고 해보자.(이게 여러 경우 중 한 케이스가 되겠다!)
구현한 내용에서는 고른 치킨집을 기준으로 치킨집에서 각 집까지의 거리로 업데이트한다.
위에서 말한 내용을 아래처럼 구현해보았다.
answer = 0
chicken_cases = list(itertools.combinations(chicken_coor,m))
for case in chicken_cases:
shortest_dist = [[0]*n for _ in range(n)] # 경우의 수마다 초기화
for cx, cy in case:
for hx,hy in house_coor:
distance = abs(cx-hx) + abs(cy-hy)
if shortest_dist[hx][hy] == 0:
shortest_dist[hx][hy] = distance
if shortest_dist[hx][hy] > distance:
shortest_dist[hx][hy] = distance
if answer == 0 :
answer = sum([ sum(s) for s in shortest_dist])
else:
answer = min (answer,sum([ sum(s) for s in shortest_dist]))
예시대로면 (0,1)의 치킨집을 기준으로 모든 집까지의 거리를 해당 집의 좌표에 업데이트한다. 이후 (4,1)의 치킨집을 기준으로 모든 집까지의 거리를 고려할 때 나온 거리가 기록된 거리보다 작다면 새로운 거리로 업데이트하는 방식이다.
전체 코드는 다음과 같다.
## method
def sol(n, m, city):
import itertools
house_coor = [(i,j) for i in range(n) for j in range(n) if city[i][j] == 1 ]
chicken_coor = [(i,j) for i in range(n) for j in range(n) if city[i][j] == 2 ]
answer = 0
chicken_cases = list(itertools.combinations(chicken_coor,m))
for case in chicken_cases:
shortest_dist = [[0]*n for _ in range(n)] # 경우의 수마다 초기화
for cx, cy in case:
for hx,hy in house_coor:
distance = abs(cx-hx) + abs(cy-hy)
if shortest_dist[hx][hy] == 0:
shortest_dist[hx][hy] = distance
if shortest_dist[hx][hy] > distance:
shortest_dist[hx][hy] = distance
if answer == 0 :
answer = sum([ sum(s) for s in shortest_dist])
else:
answer = min (answer,sum([ sum(s) for s in shortest_dist]))
return answer
## input
import sys
put = sys.stdin.readline
n,m = map(int,put().split())
city = [list(map(int, put().split())) for _ in range(n)]
## output
print(sol(n,m,city))
많은 도움이 되었습니다, 감사합니다.