[백준] 1525. 퍼즐 (Python)

개미·2023년 2월 28일
0

알고리즘

목록 보기
9/12
post-custom-banner

📌 1525. 퍼즐

https://www.acmicpc.net/problem/1525

풀이과정

처음에는 백준 1327. 소트게임과 마찬가지로 2차원 그래프 그대로 이용하여 BFS를 이용하여 풀고자 했다. 하지만 graph의 모양이 계속 바뀌는 문제가 생겼다.

DFS로 풀려고 시도해보았는데, 저번 소트게임에서와 마찬가지로, 어디서 dfs를 멈춰야 하는지 감이 서지 않았고, 결과도 계속 recursion error, segmentation fault 오류가 났다.

코드는 다음과 같다.

from collections import deque
import copy
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**7)

graph = []
array = []

for _ in range(3):
  a, b, c = map(int, input().split())
  graph.append([a, b, c])
  array.append(a)
  array.append(b)
  array.append(c)
array.sort()
target = [[0]*3 for _ in range(3)]

for i in range(8):
  target[i//3][i%3] = array[i+1]

q = deque()

visited = set()
print("".join(map(str, graph)))
visited.add("".join(map(str, graph)))

dx = [1,-1,0,0]
dy = [0,0,1,-1]
'''
while q:
  x, y = q.popleft()
  for i in range(4):
    nx = x + dx[i]
    ny = y + dy[i]
    if nx>=0 and nx<3 and ny>=0 and ny<3:
      graph[x][y] = graph[nx][ny]
      graph[nx][ny] = 0
      if graph not in visited:
        visited.add(graph)
        q.append(nx, ny)
'''


answer = int(1e9)
def dfs(x, y, dist, graph):
  global answer, visited
  if graph == target:
    answer = min(answer, dist)
    return
  if dist > 100:
    return
  for i in range(4):
    nx = x + dx[i]
    ny = y + dy[i]
    if nx>=0 and nx<3 and ny>=0 and ny<3:
      graph2 = copy.deepcopy(graph)
      graph2[x][y] = graph[nx][ny]
      graph2[nx][ny] = 0
      vgraph = "".join(map(str, graph2))
      if vgraph not in visited:
        visited.add(vgraph)
        dfs(nx, ny, dist+1, graph2)

for i in range(3):
  for j in range(3):
    if graph[i][j] == 0:
      dfs(i, j, 0, graph)

print(answer)

풀이방법을 살펴보니, 2차원을 그대로 사용하는 것이 코드의 복잡도를 야기하는 듯 하다. 하지만 나는 코드를 최대한 유지하면서 수정하려고 q에 들어가는 그래프는 2차원을 유지하고 visited에서만 string으로 바꾸어서 넣어 주었다.
그리고, bfs에서 그래프 모양이 계속 바뀌는 문제점을 q에 아예 그래프 자체를 넣어주는 방식으로 해결하였다. 그리고 visited를 딕셔너리 형태로 받아서 최소 횟수를 저장할 수 있도록 하였다.

from collections import deque
import copy
import sys
input = sys.stdin.readline

graph = []
array = []

for _ in range(3):
  a, b, c = map(int, input().split())
  graph.append([a, b, c])
  array.append(a)
  array.append(b)
  array.append(c)
array.sort()
target = [[0]*3 for _ in range(3)]

for i in range(8):
  target[i//3][i%3] = array[i+1]

q = deque()
q.append(graph)

visited = { "".join(map(str, graph)): 0}
#visited.add("".join(map(str, graph)))

dx = [1,-1,0,0]
dy = [0,0,1,-1]

def bfs():
  while q:
    ngraph = q.popleft()
    cnt = visited["".join(map(str, ngraph))]
    if ngraph == target:
      return cnt
    x = 0
    y = 0
    for i in range(3):
      for j in range(3):
        if ngraph[i][j] == 0:
          x = i
          y = j
          break
    for i in range(4):
      nx = x + dx[i]
      ny = y + dy[i]
      if nx>=0 and nx<3 and ny>=0 and ny<3:
        #ngraph[x][y], ngraph[nx][ny] = ngraph[nx][ny], ngraph[x][y]     
        graph2 = copy.deepcopy(ngraph)
        graph2[x][y] = ngraph[nx][ny]
        graph2[nx][ny] = 0

        vgraph = "".join(map(str, graph2))
        if visited.get(vgraph, 0) == 0:
          visited[vgraph] = cnt + 1
          q.append(graph2)
          #print(q)
        #ngraph[x][y], ngraph[nx][ny] = ngraph[nx][ny], ngraph[x][y]
  return -1
print(bfs())

하지만 2차원을 유지하였더니, 메모리 초과가 났다. string이 리스트보다 메모리를 덜 차지하는 듯하다. 최종 답안은 q와 visited 모두에서 2차원 리스트를 차례로 늘여놓은 string 방식으로 넣어주었다.

💯 정답

from collections import deque
import sys
input = sys.stdin.readline

graph = ""

for _ in range(3):
  a, b, c = map(int, input().split())
  graph += str(a)
  graph += str(b)
  graph += str(c)

q = deque()
q.append(graph)

visited = {graph: 0} # 딕셔너리 형태 (그래프 모양: 최소 횟수)

target = "123456780"

dx = [1,-1,0,0]
dy = [0,0,1,-1]

def bfs():
  while q:
    ngraph = q.popleft()
    cnt = visited[ngraph]
    if ngraph == target:
      return cnt
    zero = ngraph.index('0') # 0이 위치해 있는 인덱스 값
    x = zero//3
    y = zero%3
    
    for i in range(4):
      nx = x + dx[i]
      ny = y + dy[i]
      if nx>=0 and nx<3 and ny>=0 and ny<3:
        nzero = nx*3 + ny
        ngraph_list = list(ngraph)
        ngraph_list[zero], ngraph_list[nzero] = ngraph_list[nzero], ngraph_list[zero]
        str_ng_list = "".join(ngraph_list)
        if visited.get(str_ng_list, 0) == 0:
          visited[str_ng_list] = cnt + 1
          #print(visited)
          q.append(str_ng_list)
  return -1
print(bfs())
profile
개발자
post-custom-banner

0개의 댓글