
일반적인 dfs문제로 1인 노드를 탐색해서 연결된 영역의 개수를 카운트하는 문제이다. 일반적으로 이와 같은 문제는 예제로 전체 맵에 대한 입력으로 하는 반면 해당 문제는 1인 노드의 좌표 값을 입력으로 받는다. 그래서 입력받은 맵의 크기만큼 모든 값이 0인 2차원 배열을 선언한 후 해당 입력 받은 좌표 값을 1로 바꿔준 후에 로직을 수행하도록 접근했다.
다음과 같이 코드를 작성했는데 예제 입력에 대한 출력이 제대로 되지 않는 것을 확인했다. 디버그 모드로 실행해서 코드에서 문제가 되는 부분을 찾았는데 다음 부분에서 문제가 발생했다.
코드에서 모든 값인 0인 2차원 배열 선언을 한 후에 배열에서 입력받은 좌표 값을 1로 바꿔주었다.
graph = [[0] * M] * N
for j in range(K):
x, y = map(int, input().split())
graph[y][x] = 1
해당 로직을 수행 후에 graph를 출력해봤는데,, 다음과 같이 출력이 되는 것을 확인했다.
[
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 0, 1, 1, 1]
]
graph[0][0] = 1 로 바꿔서 확인해보니 문제를 파악할 수 있었다. 나는 첫번째 배열의 첫번째 요소의 값만 바꿧지만 모든 배열의 첫번째 요소가 1로 바뀐 것을 볼 수 있다.
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]
결론은 위와 같이 2차원 배열을 선언할 경우 list의 얕은 복사가 발생한다. 즉 [0] * M] 길이의 list가 얕은 복사로 단순히 N개 생기는 것. 얕은 복사는 객체를 참조하기 때문에 하나의 값만 바뀌어도 나머지가 전부 바뀌게 된다.(for문 쓰는 것을 피하려다 아주 기본적인 부분에서 놓쳤다.. )
# 유기농 배추
T = int(input())
dx = [-1, 1, 0, 0]
dy = [0, 0, -1, 1]
for i in range(T):
M, N, K = map(int, input().split())
graph = [[0] * M] * N
for j in range(K):
x, y = map(int, input().split())
graph[y][x] = 1
def dfs(x, y):
graph[x][y] = 0
for k in range(4):
nx = dx[k] + x
ny = dy[k] + y
if 0 <= nx < N and 0 <= ny < M and graph[nx][ny] == 1:
dfs(nx, ny)
result = 0
for y in range(N):
for m in range(M):
if graph[y][m] == 1:
dfs(y, m)
result += 1
print(result)
2차원 배열을 다음과 같이 선언해서 얕은 복사가 일어나지 않게 했고 추가적으로 [런타임 에러 (RecursionError)] 가 발생하는 것을 확인해서 코드를 추가해줬다.
import sys
sys.setrecursionlimit(10**6)
...
# graph = [[0] * M] * N
graph = [[0 for col in range(M)] for row in range(N)]
for j in range(K):
x, y = map(int, input().split())
graph[y][x] = 1
# 유기농 배추
import sys
sys.setrecursionlimit(10**6)
T = int(input())
dx = [-1, 1, 0, 0]
dy = [0, 0, -1, 1]
for i in range(T):
M, N, K = map(int, input().split())
graph = [[0 for col in range(M)] for row in range(N)]
for j in range(K):
x, y = map(int, input().split())
graph[y][x] = 1
def dfs(x, y):
graph[x][y] = 0
for k in range(4):
nx = dx[k] + x
ny = dy[k] + y
if 0 <= nx < N and 0 <= ny < M and graph[nx][ny] == 1:
dfs(nx, ny)
result = 0
for y in range(N):
for m in range(M):
if graph[y][m] == 1:
dfs(y, m)
result += 1
print(result)