문제 : BOJ 18290 NM과 K (1)
N M 크기의 격자(2차원 배열)을 입력 받아
K개의 칸을 선택해 각 칸에 있는 숫자의 최대 합을 구하는 문제( 칸에 있는 숫자 )
격자의 각 칸을 K개 조합하여 나올 수 있는 최대 합을 갱신해 주어야 하므로,
최대합을 갱신받을 변수와 호출간 최대합을 전달할 파라미터가 필요하다.
먼저 탈출 조건부터 작성해 주었다.
import sys
N, M, K = map(int, sys.stdin.readline().split())
grid = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
max_num = -987654321000
def solution(depth, sum):
global max_num
if depth == K:
max_num = max(max_num, sum)
return
격자는 행과 열을 가지므로, 각 행과 열을 순회하고 선택해주어야 한다고 생각했다.
또한 각 칸의 사용 여부를 확인해야 하기 때문에,
입력 받은 격자와 동일한 크기의 N M 배열을 선언하고,
사용된 칸이라면 1로, 사용되지 않은 칸이라면 0으로 갱신해준다.
import sys
N, M, K = map(int, sys.stdin.readline().split())
grid = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
visited = [[0] * M for _ in range(N)]
max_num = -987654321000
def solution(depth):
global max_num
if depth == K:
max_num = max(max_num, sum)
return
for i in range(N):
for j in range(M):
if visited[i][j] == 0:
visited[i][j] = 1
solution(depth+1)
visited[i][j] = 0
그런데 재귀 호출 전에 인접한 칸을 방문했는지 확인해주어야 하기 때문에,
인접한 칸의 확인을 위한 배열move
을 선언하고
import sys
N, M, K = map(int, sys.stdin.readline().split())
grid = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
visited = [[0] * M for _ in range(N)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]
max_num = -987654321000
def solution(depth, result):
global max_num
if depth == K:
max_num = max(max_num, result)
return
for i in range(N):
for j in range(M):
if visited[i][j] == 0:
flag = True
for x, y in move:
if 0 <= x+i < N and 0 <= y+j < M:
if visited[x+i][y+j] == 1:
flag = False
break
if flag is True:
visited[i][j] = 1
solution(depth+1, result+grid[i][j])
visited[i][j] = 0
solution(0, 0)
print(f'{max_num}')
결과 : Python3 시간초과
불필요한 탐색을 어떻게 줄여야 할까.
우선 이 문제는 조합 문제이다.
[1, 2, 3] 이나 [1, 3, 2] 나 합은 6으로 동일 하기 때문에
같은 칸에 대해서 다른 순서로 탐색해줄 필요가 없다.
따라서, 1차원 배열의 탐색에서처럼
재귀 호출시 이전 탐색 위치를 파라미터로 넘겨주어
그 이후부터 탐색 하도록 코드를 변경해주어야 한다.
n, m = map(int, input().split())
used = [0] * (n+1)
# solution()의 파라미터 start에는 이전에 탐색한 index가 저장된다.
def solution(start, depth, result):
if depth == m:
print(result)
return
# for문의 초기값을 start로 지정하여
# start 이후의 값만 탐색하도록 구현했다.
for i in range(start+1, n+1):
if used[i] == 0:
used[i] = 1
solution(i, depth+1, result + str(i)+' ')
used[i] = 0
solution(0, 0, '')
2차원 배열의 이전 탐색 위치를 특정하기 위해 사용한 방법은 다음과 같다.
N M 크기의 2차원 배열을
N M 길이의 1차원 배열로 만들어 준다.
아래와 같은 N=4, M=5인 2차원 배열을 예로 들겠다.
각 배열안의 숫자는 0행 0열부터 3행 4열까지 순서대로 매긴것이다.
\ | 0 | 1 | 2 | 3 | 4 |
---|---|---|---|---|---|
0 | 0 | 1 | 2 | 3 | 4 |
1 | 5 | 6 | 7 | 8 | 9 |
2 | 10 | 11 | 12 | 13 | 14 |
3 | 15 | 16 | 17 | 18 | 19 |
배열 안의 숫자가 13
인 칸의 행과 열은 (2, 3)이다.
이 13
은 길이가 5인 행을 2번 지난뒤 열을 3칸 이동한 것과 같으므로
이렇게 표현할 수 있다.
다른 칸도 동일하게 표현할 수 있다.
따라서 행이 R 열이 C인 칸의 순서를 표현하는 방법은 아래와 같다.
순서 = R 전체 열의 길이 C
R = 순서 // 전체 열의 길이
C = 순서 % 전체 열의 길이
위 식에 따라
2차원 배열인 visited
도 1차원 배열로 표현 가능하며
"이전 탐색 위치(r, c)" 를 1차원 배열 하나의 Index로 특정할 수 있다.
import sys
n, m, k = map(int, sys.stdin.readline().split())
nums = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]
visited = [0] * (n*m)
max_num = -9876543210000
def solution(depth, result, pre):
global max_num
if depth == k:
max_num = max(max_num, result)
return
for i in range(pre, m * n):
if visited[i] == 0:
row = i // m
col = i % m
flag = True
for x, y in move:
if 0 <= (x+row) < n and 0 <= (y+col) < m:
if visited[(x+row)*m+y+col] == 1:
flag = False
break
if flag is True:
visited[i] = 1
solution(depth+1, result+nums[row][col], i+1)
visited[i] = 0
solution(0, 0, 0)
sys.stdout.write(f'{max_num}')
결과 : Python3 성공
30840kb 4748ms
2차원 배열을 1차원 배열로 완벽하게 1대1 대응하는 방법을 생각해 낸 기쁨도 잠시,
여러번 시도했던 내 채점 직전에 올라온 옆사람의 채점 결과를 보게된다.
?
문제에 이런 조건이 있다.
( 칸에 있는 숫자 )
칸에 존재할 수 있는 최대값은 10000 이다.
우리는 K개 칸의 값의 최대합을 찾기위해 재귀를 반복한다.
그런데 현재 까지 선택된 칸의 합으로
더 이상 재귀를 진행하지 않아도 되는것을 판단할 수 있다면?
위에서 작성한 solution()
함수의 파라미터를 다시 가져와 보자
def solution(depth, result, pre):
depth
는 재귀 호출의 깊이로, 이 문제에서는 현재까지 선택한 칸의 갯수이다.result
는 이제까지 선택한 칸들의 합을 저장하고 전달해주는 변수이다.10000
이고,K
개 합의 최대값을 갱신하고 있는 변수는 max_num
이다.그래서, max_num
> 10000
( K
depth
) result
이면
더이상의 탐색은 max_num
의 값을 새롭게 갱신할 수 없다. (무의미 하다)
[ 전체 코드 ]
import sys
n, m, k = map(int, sys.stdin.readline().split())
nums = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]
visited = [0] * (n*m)
max_num = -9876543210000
def solution(depth, result, pre):
global max_num
if depth == k:
max_num = max(max_num, result)
return
# 마법의 탈출
if result + 10000 * (k-depth) < max_num:
return
for i in range(pre, m * n):
if visited[i] == 0:
row = i // m
col = i % m
flag = True
for x, y in move:
if 0 <= (x+row) < n and 0 <= (y+col) < m:
if visited[(x+row)*m+y+col] == 1:
flag = False
break
if flag is True:
visited[i] = 1
solution(depth+1, result+nums[row][col], i+1)
visited[i] = 0
solution(0, 0, 0)
sys.stdout.write(f'{max_num}')
결과 : Python3 성공
30840kb 80ms
Special thanks to jnl1128