백준 16637번 괄호 추가하기 (Python, 구현, 브루트포스, DFS, Gold3)

전승재·2023년 9월 2일
0

알고리즘

목록 보기
35/88

백준 16637번 괄호 추가하기 문제 바로가기

문제 이해

문제

길이가 N인 수식이 있다. 수식은 0보다 크거나 같고, 9보다 작거나 같은 정수와 연산자(+, -, ×)로 이루어져 있다. 연산자 우선순위는 모두 동일하기 때문에, 수식을 계산할 때는 왼쪽에서부터 순서대로 계산해야 한다. 예를 들어, 3+8×7-9×2의 결과는 136이다.

수식에 괄호를 추가하면, 괄호 안에 들어있는 식은 먼저 계산해야 한다. 단, 괄호 안에는 연산자가 하나만 들어 있어야 한다. 예를 들어, 3+8×7-9×2에 괄호를 3+(8×7)-(9×2)와 같이 추가했으면, 식의 결과는 41이 된다. 하지만, 중첩된 괄호는 사용할 수 없다. 즉, 3+((8×7)-9)×2, 3+((8×7)-(9×2))은 모두 괄호 안에 괄호가 있기 때문에, 올바른 식이 아니다.

수식이 주어졌을 때, 괄호를 적절히 추가해 만들 수 있는 식의 결과의 최댓값을 구하는 프로그램을 작성하시오. 추가하는 괄호 개수의 제한은 없으며, 추가하지 않아도 된다.

입력

첫째 줄에 수식의 길이 N(1 ≤ N ≤ 19)가 주어진다. 둘째 줄에는 수식이 주어진다. 수식에 포함된 정수는 모두 0보다 크거나 같고, 9보다 작거나 같다. 문자열은 정수로 시작하고, 연산자와 정수가 번갈아가면서 나온다. 연산자는 +, -, 중 하나이다. 여기서 는 곱하기 연산을 나타내는 × 연산이다. 항상 올바른 수식만 주어지기 때문에, N은 홀수이다.

출력

첫째 줄에 괄호를 적절히 추가해서 얻을 수 있는 결과의 최댓값을 출력한다. 정답은 231보다 작고, -231보다 크다.

문제 접근

이 문제를 보고 처음에 든 생각은 구현문제라고 생각했다.
그래서 3단계로 나누어서 진행하면 되겠다고 생각하고 아래와 같이 3단계로 나눴다.

  • 괄호 개수 0부터 (N//2)//2개 까지 반복하기
  • 식 실행하기
  • 최대값 구하기

괄호의 최대 개수가 (N//2)//2개이기 때문에 반복문을 사용해서 모든 경우를 고려하여 괄호를 넣어주고 이를 계산하고 최대값을 구하면 되겠다고 생각하고 시작했다.

따라서 아래와 같이 식을 왼쪽에서부터 차례대로 계산할 수 있는 함수를 생성해주었다.

def calculate(a, op, b):
    if op =='+':
        num = a + b
    elif op =='*':
        num = a * b
    elif op == '-':
        num = a - b
    return num   
def do(list):
    result = 0
    operator = '+'
    while list:
        val = list.popleft()
        if val == '(':
            num1 = list.popleft() # 괄호 계산
            op = list.popleft()
            num2 = list.popleft()
            num = calculate(num1,op,num2)
            list.popleft() # ) pop
            result = calculate(result, operator, num)# 괄호 결과값과 현재까지의 결과값 계산
            if len(list)==0: # 계산 모두 끝났으면 종료
                return result
            operator = list.popleft()
        else:
            result = calculate(result, operator, int(val))
            if len(list)==0: # 계산 모두 끝났으면 종료
                return result
            operator = list.popleft() # 계산 아직 남았으면 operator 저장
    return result

사실 이 다음이 문제였다.
괄호를 모든 부분에 추가해야한다는 것을 for문으로 구현하고 이를 deepcopy하여 do 함수에 넣어야하는데 이렇게 되면 시간제한 0.5초에 걸려 시간초과가 될 것 같았다.

따라서 방법이 잘못되었다고 생각하고 다른 방법으로 진행해야겠다고 생각하여 1시간 넘게 고민했던 것 같다.

그러다가 결국 구글링을 했는데, 다른 분들의 코드를 보니 DFS로 진행하는 사람들이 많이 있었던 것 같다.

그중에 가장 직관적인 방법으로 풀어봤다.

DFS를 진행하는데 하나의 함수에서 2번의 재귀를 진행한다. 재귀는 어떨 때 진행하냐면 현재까지의 결과를 파라미터에 담고, 그 다음 계산을 진행하려고 할 때 다음번에 오는 수가 괄호에 속해있거나, 속해있지 않는다고 가정하고 두 가지의 방법을 모두 수행한다.

예를 들어 1+2+3-4-5가 입력이고 현재 1+2까지 수행했다면 현재까지의 값은 3일 것이다. 따라서 다음 연산을 진행하려 할 때 아래와 같이 두 가지로 나뉜다.

  • 3+3-4-5 => -3
  • 3+(3-4)-5 => -3

이 두 식은 다시 나뉘게 될 것이다. 하지만 이 때 괄호는 2개의 숫자와 1개의 연산자를 포함해야 하고 괄호안에 괄호는 있을 수 없기 때문에

  • 3+(3-4)-5인 식은 더이상 괄호를 추가할 수 없다.

이렇게 현재의 인덱스와 현재까지의 결과를 파라미터로 넘겨주면서 진행해야 할 것이다.

문제 풀이

calculate

입력된 a와 b를 operator값이 +, -, * 임에 따라 적절히 계산하여 return 해주는 함수이다.

def calculate(a, op, b):
    if op=='+':
        num = a + b
    elif op=='*':
        num = a * b
    elif op == '-':
        num = a - b
    return num   

DFS

index의 위치에 따라 다음 수가 괄호에 갇혀있을 수도, 없을 수도 있기 때문에 그 조건에 맞춰 재귀를 진행한다.
index가 만약에 마지막에서 1번째라면 식의 마지막까지 계산했다는 의미이기 때문에 최댓값을 result에 저장한다.

index가 마지막에서 1번째에 오기전까지는 다음번의 수와 더하고

index가 마지막에서 3번째에 오기전까지는 다음번의 수와 다다음번의 수가 괄호에 들어갈 수 있기 때문에 dfs를 실행해준다.

이렇게 여러 재귀함수가 실행되면서 결과값을 result에 넣을 것이다.
재귀가 모두 끝나게 되면 result엔 최종적으로 가장 컸던 결과가 저장된다.
따라서 이를 print해주면 정답이 출력된다.

def dfs(index, value):
    global result

    if index == N - 1:
        result = max(result, value)
        return

    if index + 2 < N:
        # 다음번에 나오는 수와 계산하는데, 다음번에 나오는 수가 괄호가 쳐져있지 않을 때
        next_value = calculate(value, cal[index + 1], int(cal[index + 2])) 
        dfs(index + 2, next_value) # 다음번에 나오는 수와 계산했기 때문에 index+2, 다음번에 나오는 수까지 계산한 결과를 파라미터로 넣어줌

    if index + 4 < N:
        # 다음번에 나오는 수와 계산하는데, 다음번에 나오는 수가 다다음번에 나오는 수와 괄호가 쳐져있을 때
        next_next_value = calculate(int(cal[index+2]), cal[index+3], int(cal[index+4])) # 괄호 처리(다음번 수 (+, *, -) 다다음번 수)
        next_value = calculate(value, cal[index + 1], next_next_value) # 다음번에 나오는 수
        dfs(index + 4, next_value) # index + 4까지 계산 끝났기 때문에 index + 4, 현재까지의 결과를 파라미터로 넣어줌

제출 코드

import sys
N = int(sys.stdin.readline())
cal = list(sys.stdin.readline().strip())
result = -1*2**31 # 최솟값

def calculate(a, op, b):
    if op=='+':
        num = a + b
    elif op=='*':
        num = a * b
    elif op == '-':
        num = a - b
    return num   

def dfs(index, value):
    global result

    if index == N - 1:
        result = max(result, value)
        return

    if index + 2 < N:
        # 다음번에 나오는 수와 계산하는데, 다음번에 나오는 수가 괄호가 쳐져있지 않을 때
        next_value = calculate(value, cal[index + 1], int(cal[index + 2])) 
        dfs(index + 2, next_value) # 다음번에 나오는 수와 계산했기 때문에 index+2, 다음번에 나오는 수까지 계산한 결과를 파라미터로 넣어줌

    if index + 4 < N:
        # 다음번에 나오는 수와 계산하는데, 다음번에 나오는 수가 다다음번에 나오는 수와 괄호가 쳐져있을 때
        next_next_value = calculate(int(cal[index+2]), cal[index+3], int(cal[index+4])) # 괄호 처리(다음번 수 (+, *, -) 다다음번 수)
        next_value = calculate(value, cal[index + 1], next_next_value) # 다음번에 나오는 수
        dfs(index + 4, next_value) # index + 4까지 계산 끝났기 때문에 index + 4, 현재까지의 결과를 파라미터로 넣어줌

dfs(0, int(cal[0]))
print(result)

0개의 댓글