재귀 함수 익히기 - 별 찍기

Hyeon·2022년 9월 28일
0

코딩테스트

목록 보기
26/44
post-thumbnail
  1. BOJ 2447 별 찍기 - 10
  1. BOJ 10994 별 찍기 - 19
  1. BOJ 18290 NM과 K (1)

1. BOJ 2447 별 찍기 - 10

문제 : BOJ 2447 별 찍기 - 10

재귀 함수 부분이 아직 부족한 것 같아서
BOJ 알고리즘 분류 중 재귀 카테고리의 문제를 풀고 있다.

시행 착오

'*'' ' 이 나타나는 row, col값을
재귀를 반복해서 0 또는 나머지가 1이 나오는 숫자로 만들었다.

각 행과 열을 모두 대입해야 하기 때문에
2중 for문으로 row, col값을 전달해주었고,
테스트 결과 정답은 출력되나, 시간 초과가 나왔다.

import sys

n = int(sys.stdin.readline().rstrip())


def sol(r, c):
    if r % 3 == 1 and c % 3 == 1:
        return ' '
    if r == 0 or c == 0:
        return '*'
    return sol(r // 3, c // 3)


for i in range(n):
    for j in range(n):
        sys.stdout.write(f'{sol(i, j)}')
    sys.stdout.write('\n')

코드

결국 풀지 못해서 다른 분들의 코드를 참고했다.

n=1 이라면, 기본 단위인

***
* *
***

을 return하고

그렇지 않으면,
분할 후 이전 단계의 별 패턴을 이용해서 새로운 별을 그린 뒤 return 해준다.

N = int(input())

def draw_star(n):
    if n == 3:
        return ['***', '* *', '***']
    
    sub_star = draw_star(n//3)
    new_star = []
    for star in sub_star:
        new_star.append(star * 3)
    for star in sub_star:
        new_star.append(star + ' '*(n//3) +star)
    for star in sub_star:
        new_star.append(star * 3)

    return new_star

star = draw_star(N)
for i in star:
    print(i)

익숙한 느낌

그런데 코드를 보다보니 어디서 본 구조다.

def merge_sort(arr, start, end):
    if end == start:
        return [arr[start]]

    mid = (start + end) // 2
    left_arr = merge_sort(arr, start, mid)
    right_arr = merge_sort(arr, mid + 1, end)

    l_idx = 0
    r_idx = 0
    m_idx = 0
    merged_arr = [0] * (len(left_arr) + len(right_arr))

    while l_idx < len(left_arr) and r_idx < len(right_arr):
        if left_arr[l_idx] < right_arr[r_idx]:
            merged_arr[m_idx] = left_arr[l_idx]
            l_idx += 1
        else:
            merged_arr[m_idx] = right_arr[r_idx]
            r_idx += 1
        m_idx += 1

    right_arr = right_arr if r_idx < len(right_arr) else left_arr
    r_idx = r_idx if r_idx < len(right_arr) else l_idx

    while r_idx < len(right_arr):
        merged_arr[m_idx] = right_arr[r_idx]
        r_idx += 1
        m_idx += 1

    return merged_arr

Merge Sort를 구현한 함수의 형태와 유사했다.😮

Merge Sort의 경우,
길이가 1인 최소 단위의 배열이 병합될 때와
길이가 1보다 큰 정렬된 배열이 병합될 때 모두 정렬을 유지할 수 있기 때문에
최소 단위로 분할 후 병합하는 방법을 이용할 수 있었다.

이번 문제 별찍기 - 10의 경우에는
애초에 재귀적인 패턴으로 별을 찍는데,
이전 단계(n//3)의 패턴을 이용해서 현재(n)의 별을 그려주어야 하기 때문에
이러한 방법을 이용할 수 있다.

공통점을 찾아보자면 다음과 같다.

  1. n이 최소 값에 도달하면 기본 단위를 return
  1. 최소 값이 아니라면 재귀 호출
  1. 부분 단위를 전달 받아 조건에 맞게 처리
  1. 처리된 데이터를 return

아직 재귀 문제를 능숙하게 다룰 수 없음을 이번 별 찍기 문제를 풀며 느꼈다.
문제를 바라보는 직관을 키우기 위한 노력이 좀 더 필요하다.


2. BOJ 10994 별 찍기 - 19

문제 : BOJ 10994 별 찍기 - 19

위 문제에서 알게 된 재귀 함수의 패턴을 응용하여 풀었다.
난이도는 조금 낮은 편이나, 유사한 방법으로 풀 수 있는 문제이다.

규칙 찾기

예제를 보고 규칙을 유추하여 풀어야 한다.
먼저, 각 입력에 대한 예제 출력 부터 확인하자.

입력값을 n이라고 할 때,
높이(n) = 높이(n-1) + 4
넓이(n) = 넓이(n-1) + 4
의 규칙을 갖고, 이전 사각형을 안쪽에 포함하는 그림이 그려지고 있다.

구현 하기

별찍기 10번과 마찬가지로,
최소 단위로 분할하여 별을 그리는 방법으로 구현해보자

먼저, n == 1 일 때가 가장 작은 값 '*' 이므로,
n이 1이되면 ['*'] 을 return 해준다.

def draw_star(n):
    if n == 1:
        return ['*']

그리고 마찬가지로 최소값(n==1) 이 아닐 경우
n을 1감소시켜 재귀 호출을 통해 (n-1)의 별을 전달받는다.
새로 그려줄 별을 저장할 배열도 선언해주자.

def draw_star(n):
    if n == 1:
        return ['*']
    sub_star = draw_star(n-1)
   	new_star = []

그 다음, return 받은 (n-1)별 그림을 이용해서 (n)별 그림을 그려주자

위와 아래의 길이가 커지는 것은 2개의 문자열을 직접 그려 append 할 수 있고
넓이의 경우 예제에서 충분히 일반항을 도출 할 수 있다.

[ 그림의 넓이 ]

n1234...
넓이15913...

n에 대한 그림의 넓이는 초항이 1, 공차가 4인 등차 수열 이므로
넓이의 일반항은
4×(n1)+14 \times (n-1) + 1
=4n3= 4n-3
이다.

첫번째 줄과 마지막 줄의 경우에는 넓이 만큼의 별을 그려주고
두번째 줄과 마지막-1번째 줄은 각각 양 옆에 별을 그린 뒤 나머지 넓이만큼 공백으로 채워준다.

def draw_star(n):
    if n == 1:
        return ['*']
    sub_star = draw_star(n-1)
   	new_star = []
    
    new_star.append('*'*(4*n-3))		# 첫째 줄
    new_star.append('*' + ' ' * (4*n-5) + '*')	# 둘째 줄
    
    
    new_star.append('*' + ' ' * (4*n-5) + '*')	# 마지막-1번째 줄	
    new_star.append('*'*(4*n-3))	# 마지막 줄

그리고 이 가운데 들어가는 그림은
재귀 호출로 전달받은 sub_star를 이용해 그려주어야 한다.

그림과 동일하게 출력 하기 위해서는
양 옆을 '* ' ' *'으로 둘러 쌓아야 한다.

그리고 다 그려준 그림을 return 해주며
재귀 호출을 통해 (n-1)의 그림을 받을 수 있게 해주자

def draw_star(n):
    if n == 1:
        return ['*']
    sub_star = draw_star(n-1)
   	new_star = []
    
    new_star.append('*'*(4*n-3))				# 첫째 줄
    new_star.append('*' + ' ' * (4*n-5) + '*')	# 둘째 줄
    
    # 가운데
    for i in sub_star:
        new_star.append('* ' + i + ' *')
        
    new_star.append('*' + ' ' * (4*n-5) + '*')	# 마지막-1번째 줄	
    new_star.append('*'*(4*n-3))				# 마지막 줄
    
    return new_star

코드

[ 전체 코드 ]

n = int(input())

def draw_star(n):
    if n == 1:
        return ['*']
    sub_star = draw_star(n-1)
    new_star = []
    
    new_star.append('*'*(4*n-3))
    new_star.append('*' + ' ' * (4*n-5) + '*')
    
    for i in sub_star:
        new_star.append('* ' + i + ' *')
        
    new_star.append('*' + ' ' * (4*n-5) + '*')
    new_star.append('*'*(4*n-3))
    
    return new_star

star = draw_star(n)
for i in star:
    print(i)

3. BOJ 18290 NM과 K (1)

문제 : BOJ 18290 NM과 K (1)

N ×\times M 크기의 격자(2차원 배열)을 입력 받아
K개의 칸을 선택해 각 칸에 있는 숫자의 최대 합을 구하는 문제

( 10000-10000 \le 칸에 있는 숫자 10000\le 10000 )

시행 착오

격자의 각 칸을 K개 조합하여 나올 수 있는 최대 합을 갱신해 주어야 하므로,
최대합을 갱신받을 변수와 호출간 최대합을 전달할 파라미터가 필요하다.
먼저 탈출 조건부터 작성해 주었다.

import sys

N, M, K = map(int, sys.stdin.readline().split())

grid = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]

max_num = -987654321000

def solution(depth, sum):
	global max_num
	if depth == K:
    	max_num = max(max_num, sum)
    	return

격자는 행과 열을 가지므로, 각 행과 열을 순회하고 선택해주어야 한다고 생각했다.

또한 각 칸의 사용 여부를 확인해야 하기 때문에,
입력 받은 격자와 동일한 크기의 N ×\times M 배열을 선언하고,
사용된 칸이라면 1로, 사용되지 않은 칸이라면 0으로 갱신해준다.

import sys

N, M, K = map(int, sys.stdin.readline().split())

grid = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
visited = [[0] * M for _ in range(N)]

max_num = -987654321000

def solution(depth):
	global max_num
	if depth == K:
    	max_num = max(max_num, sum)
    	return
    for i in range(N):
    	for j in range(M):
        	if visited[i][j] == 0:
            	visited[i][j] = 1
                solution(depth+1)
                visited[i][j] = 0
        	

그런데 재귀 호출 전에 인접한 칸을 방문했는지 확인해주어야 하기 때문에,
인접한 칸의 확인을 위한 배열move을 선언하고

  1. 만약 인접한 칸을 방문한 적이 있다면 방문을 취소하고 반복문을 다시 진행한다.
  1. 인접한 칸을 방문한 적이 없다면, 방문처리 후 재귀 호출을 진행한다.
import sys

N, M, K = map(int, sys.stdin.readline().split())

grid = [list(map(int, sys.stdin.readline().split())) for _ in range(N)]
visited = [[0] * M for _ in range(N)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]

max_num = -987654321000

def solution(depth, result):
	global max_num
	if depth == K:
    	max_num = max(max_num, result)
    	return
    
    for i in range(N):
        for j in range(M):
            if visited[i][j] == 0:
                flag = True
                for x, y in move:
                    if 0 <= x+i < N and 0 <= y+j < M:
                        if visited[x+i][y+j] == 1:
                            flag = False
                            break
                if flag is True:
                    visited[i][j] = 1
                    solution(depth+1, result+grid[i][j])
                    visited[i][j] = 0

solution(0, 0)
print(f'{max_num}')

결과 : Python3 시간초과

불필요한 탐색을 줄이자

불필요한 탐색을 어떻게 줄여야 할까.

우선 이 문제는 조합 문제이다.
[1, 2, 3] 이나 [1, 3, 2] 나 합은 6으로 동일 하기 때문에
같은 칸에 대해서 다른 순서로 탐색해줄 필요가 없다.

따라서, 1차원 배열의 탐색에서처럼
재귀 호출시 이전 탐색 위치를 파라미터로 넘겨주어
그 이후부터 탐색 하도록 코드를 변경해주어야 한다.

[ 1차원 배열의 조합 예시 ]

n, m = map(int, input().split())
used = [0] * (n+1)

# solution()의 파라미터 start에는 이전에 탐색한 index가 저장된다.
def solution(start, depth, result):
    if depth == m:
        print(result)
        return 
    # for문의 초기값을 start로 지정하여
    # start 이후의 값만 탐색하도록 구현했다.
    for i in range(start+1, n+1):
        if used[i] == 0:
            used[i] = 1
            solution(i, depth+1, result + str(i)+' ')
            used[i] = 0

solution(0, 0, '')

2차원 배열 -> 1차원 배열

2차원 배열의 이전 탐색 위치를 특정하기 위해 사용한 방법은 다음과 같다.

N ×\times M 크기의 2차원 배열을
N ×\times M 길이의 1차원 배열로 만들어 준다.

아래와 같은 N=4, M=5인 2차원 배열을 예로 들겠다.
각 배열안의 숫자는 0행 0열부터 3행 4열까지 순서대로 매긴것이다.

\01234
001234
156789
21011121314
31516171819

배열 안의 숫자가 13인 칸의 행과 열은 (2, 3)이다.

13은 길이가 5인 행을 2번 지난뒤 열을 3칸 이동한 것과 같으므로
이렇게 표현할 수 있다.
13=2×5+313=2 \times5 + 3

다른 칸도 동일하게 표현할 수 있다.
따라서 행이 R 열이 C인 칸의 순서를 표현하는 방법은 아래와 같다.

순서 = R ×\times 전체 열의 길이 ++ C

R = 순서 // 전체 열의 길이

C = 순서 % 전체 열의 길이

위 식에 따라
2차원 배열인 visited도 1차원 배열로 표현 가능하며
"이전 탐색 위치(r, c)" 를 1차원 배열 하나의 Index로 특정할 수 있다.

import sys
n, m, k = map(int, sys.stdin.readline().split())

nums = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]
visited = [0] * (n*m)

max_num = -9876543210000

def solution(depth, result, pre):
    global max_num

    if depth == k:
        max_num = max(max_num, result)
        return
    
    for i in range(pre, m * n):
        if visited[i] == 0:
            row = i // m
            col = i % m
            flag = True
            for x, y in move:
                if 0 <= (x+row) < n and 0 <= (y+col) < m:
                    if visited[(x+row)*m+y+col] == 1:
                        flag = False
                        break
            if flag is True:
                visited[i] = 1
                solution(depth+1, result+nums[row][col], i+1)
                visited[i] = 0


solution(0, 0, 0)
sys.stdout.write(f'{max_num}')

결과 : Python3 성공
30840kb 4748ms

그래도 재귀는 돈다

2차원 배열을 1차원 배열로 완벽하게 1대1 대응하는 방법을 생각해 낸 기쁨도 잠시,
여러번 시도했던 내 채점 직전에 올라온 옆사람의 채점 결과를 보게된다.

?


문제의 조건을 잘 읽어야지

문제에 이런 조건이 있다.

( 10000-10000 \le 칸에 있는 숫자 10000\le 10000 )

칸에 존재할 수 있는 최대값은 10000 이다.

우리는 K개 칸의 값의 최대합을 찾기위해 재귀를 반복한다.
그런데 현재 까지 선택된 칸의 합으로
더 이상 재귀를 진행하지 않아도 되는것을 판단할 수 있다면?

시간 더 줄이기

위에서 작성한 solution() 함수의 파라미터를 다시 가져와 보자

def solution(depth, result, pre):
  • depth는 재귀 호출의 깊이로, 이 문제에서는 현재까지 선택한 칸의 갯수이다.
  • result 는 이제까지 선택한 칸들의 합을 저장하고 전달해주는 변수이다.
  • 각 칸의 최대값은 10000이고,
  • K개 합의 최대값을 갱신하고 있는 변수는 max_num 이다.

그래서, max_num > 10000 ×\times ( K-depth ) ++ result 이면

더이상의 탐색은 max_num의 값을 새롭게 갱신할 수 없다. (무의미 하다)


[ 전체 코드 ]

import sys

n, m, k = map(int, sys.stdin.readline().split())

nums = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
move = [[0, 1], [1, 0], [0, -1], [-1, 0]]
visited = [0] * (n*m)

max_num = -9876543210000

def solution(depth, result, pre):
    global max_num

    if depth == k:
        max_num = max(max_num, result)
        return
    
    # 마법의 탈출
    if result + 10000 * (k-depth) < max_num:
        return
    
    for i in range(pre, m * n):
        if visited[i] == 0:
            row = i // m
            col = i % m
            flag = True
            for x, y in move:
                if 0 <= (x+row) < n and 0 <= (y+col) < m:
                    if visited[(x+row)*m+y+col] == 1:
                        flag = False
                        break
            if flag is True:
                visited[i] = 1
                solution(depth+1, result+nums[row][col], i+1)
                visited[i] = 0


solution(0, 0, 0)
sys.stdout.write(f'{max_num}')

결과 : Python3 성공
30840kb 80ms

Special thanks to jnl1128

profile
그럼에도 불구하고

0개의 댓글