[파이썬]백준 14500 테트로미노

Byeonghyeon Kim·2021년 5월 13일
2

알고리즘문제

목록 보기
78/93
post-thumbnail

링크

백준 14500 테트로미노


처음엔 좌표마다 가능한 모양을 전부다 좌표로 계산해서 때려넣으려 했는데.. 물론 그렇게 풀 수도 있지만 정말 그런식으로 푸는건 별로 안좋아한다.
머리를 쓰는게 아니라 손가락만 겁나게 아프기 때문

결국 아이디어를 얻지 못해 살짝 검색을 했다.(코드는 안봤다)
특정 좌표로부터 dfs를 거리 3까지만 돌게되면 ㅗ모양을 제외한 모든 테트로미노를 만들 수 있다.
(왜그런지는 직접 그려보면 금새 알 수 있다.)

그 후 ㅗ 모양으로 탐색하도록 함수를 만들어 주었다.
중앙의 블록을 기준으로 삼아서 4방향을 전체 다 만들어 十모양으로 탐색 한 후 사방향중 한방향씩 빼가며 모양을 만들어 주었다.
만약 범위을 넘어서 3방향 밖에 탐색을 못했을 경우엔 바로 답과 비교를 해줬다.


이렇게 하니 풀긴풀었는데.. 무려 7700ms... ㅋㅋㅋㅋㅋㅋㅋㅋ
결국 다른사람의 코드를 참고하고나서야 최적화된 코드를 짤 수 있었다.

먼저 기존코드에서 ㅗ모양을 따로 탐색했던 것을 dfs에 포함시켰다.

idx = 1 일 때(즉, 두개의 블럭을 선택했을 때) 새로운 블럭에서 다음 블럭을 탐색하는 것이 아니라 다시 기존블럭에서 탐색하게 만들면 ㅗ모양을 만들 수 있다.
(이것도 조금만 생각해보면 알 수 있다. 절대 그리기 귀찮아서 글로 때우는거 아님)

그 후 가지치기 하는 코드를 추가했다.
여태 백트래킹을 풀 때 최소값을 구하는 경우는 쉽게 가지치기를 했는데 최대값을 구하는 경우엔 가지치기를 해본적이 없어서 생각도 못했다.
종이(2차원배열)에서 최대값을 찾아서 max(map(max, arr))
선택할 수 있는 남은 블럭의 갯수만큼 (3 - idx) 곱해주고
현재 누적합 total 에 더해서 ans와 비교해주는 방식으로 가지치기를 했다.

이렇게 가지치기를 하고 중복되는 코드를 합쳐주니 시간을 288ms로 극적으로 줄일 수 있었다.


정답 코드 / 7700ms

import sys; input = sys.stdin.readline

def dfs(r, c, idx, total):
    global ans
    if idx == 3:
        if total > ans:
            ans = total
    else:
        for i in range(4):
            nr = r + dr[i]
            nc = c + dc[i]
            if 0 <= nr < N and 0 <= nc < M:
                if visit[nr][nc] == 0:
                    visit[nr][nc] = 1
                    dfs(nr, nc, idx + 1, total + arr[nr][nc])
                    visit[nr][nc] = 0

def block(r, c, total):
    global ans
    make_block = 0
    for i in range(4):
        nr = r + dr[i]
        nc = c + dc[i]
        if 0 <= nr < N and 0 <= nc < M:
            make_block += 1
            total += arr[nr][nc]

    if make_block == 3:
        if total > ans:
            ans = total

    if make_block == 4:
        for i in range(4):
            nr = r + dr[i]
            nc = c + dc[i]
            total -= arr[nr][nc]
            if total > ans:
                ans = total
            total += arr[nr][nc]


N, M = map(int, input().split())
arr = [list(map(int, input().split())) for _ in range(N)]
visit = [([0] * M) for _ in range(N)]
dr = [-1, 0, 1, 0]
dc = [0, 1, 0, -1]
ans = 0

for r in range(N):
    for c in range(M):
        visit[r][c] = 1
        dfs(r, c, 0, arr[r][c])
        block(r, c, arr[r][c])
        visit[r][c] = 0

print(ans)

정답 코드 / 288ms

import sys; input = sys.stdin.readline

def dfs(r, c, idx, total):
    global ans
    if ans >= total + max_val * (3 - idx):
        return
    if idx == 3:
        ans = max(ans, total)
        return
    else:
        for i in range(4):
            nr = r + dr[i]
            nc = c + dc[i]
            if 0 <= nr < N and 0 <= nc < M and visit[nr][nc] == 0:
                if idx == 1:
                    visit[nr][nc] = 1
                    dfs(r, c, idx + 1, total + arr[nr][nc])
                    visit[nr][nc] = 0
                visit[nr][nc] = 1
                dfs(nr, nc, idx + 1, total + arr[nr][nc])
                visit[nr][nc] = 0


N, M = map(int, input().split())
arr = [list(map(int, input().split())) for _ in range(N)]
visit = [([0] * M) for _ in range(N)]
dr = [-1, 0, 1, 0]
dc = [0, 1, 0, -1]
ans = 0
max_val = max(map(max, arr))

for r in range(N):
    for c in range(M):
        visit[r][c] = 1
        dfs(r, c, 0, arr[r][c])
        visit[r][c] = 0

print(ans)

알게된 것👨‍💻

  • 현재 상태에서 만들 수 있는 최대값을 이용해서 최대값을 구할 때도 가지치기를 할 수 있다.
  • dfs를 응용하면 여러가지 모양으로 탐색하도록 만들 수 있다.
profile
자기 주도 개발전 (개발, 발전)

1개의 댓글

comment-user-thumbnail
2022년 3월 9일

아주좋소!

답글 달기