[백준] 16638번 괄호 추가하기 2 - Python / 알고리즘 중급 2/3 - 브루트 포스 - 문제 (연습)

ByungJik_Oh·2025년 8월 7일
0

[Baekjoon Online Judge]

목록 보기
220/244
post-thumbnail



💡 문제

길이가 N인 수식이 있다. 수식은 0보다 크거나 같고, 9보다 작거나 같은 정수와 연산자(+, -, ×)로 이루어져 있다. 곱하기의 연산자 우선순위가 더하기와 빼기보다 높기 때문에, 곱하기를 먼저 계산 해야 한다. 수식을 계산할 때는 왼쪽에서부터 순서대로 계산해야 한다. 예를 들어, 3+8×7-9×2의 결과는 41이다.

수식에 괄호를 추가하면, 괄호 안에 들어있는 식은 먼저 계산해야 한다. 단, 괄호 안에는 연산자가 하나만 들어 있어야 한다. 예를 들어, 3+8×7-9×2에 괄호를 (3+8)×7-(9×2)와 같이 추가했으면, 식의 결과는 59가 된다. 하지만, 중첩된 괄호는 사용할 수 없다. 즉, 3+((8×7)-9)×2, 3+((8×7)-(9×2))은 모두 괄호 안에 괄호가 있기 때문에, 올바른 식이 아니다.

수식이 주어졌을 때, 괄호를 적절히 추가해 만들 수 있는 식의 결과의 최댓값을 구하는 프로그램을 작성하시오. 추가하는 괄호 개수의 제한은 없으며, 추가하지 않아도 된다.

입력

첫째 줄에 수식의 길이 N(1 ≤ N ≤ 19)가 주어진다. 둘째 줄에는 수식이 주어진다. 수식에 포함된 정수는 모두 0보다 크거나 같고, 9보다 작거나 같다. 문자열은 정수로 시작하고, 연산자와 정수가 번갈아가면서 나온다. 연산자는 +, -, * 중 하나이다. 여기서 *는 곱하기 연산을 나타내는 × 연산이다. 항상 올바른 수식만 주어지기 때문에, N은 홀수이다.

출력

첫째 줄에 괄호를 적절히 추가해서 얻을 수 있는 결과의 최댓값을 출력한다. 정답은 231^{31}보다 작고, -231^{31}보다 크다.


💭 접근

이 문제는 16637번 괄호 추가하기 문제와 비슷하지만 괄호에 들어있는 연산자만이 우선순위를 가지는 이전 문제와 달리, 이번 문제는 괄호에 들어있는 연산자가 최우선으로, 그리고 곱셈 연산이 그 다음 우선순위를 가지므로 연산할 때 괄호로 묶지 않은 곱셈 연산을 먼저 처리해주는 로직이 추가로 필요하다.

우선, 괄호로 묶을 연산자를 선택한다.

def dfs(start, depth):
    if depth == cnt:
        calculate()
        return
    
    for i in range(start, len(oper_idx)):
        tmp.append(oper_idx[i])
        dfs(i + 2, depth + 1)
        tmp.pop()

이때, 괄호가 중첩되어 그려지지 않도록하기 위해 선택한 연산자 바로 다음 연산자는 선택하지 않도록 다음 반복문의 start를 +1이 아닌 +2로하여 바로 다음 연산자를 선택하지 못하도록 해야한다.

이후, 괄호로 묶을 연산자를 선택했다면 문제에서 주어진 조건에 따라 결과를 구할 차례이다.
먼저 괄호로 묶인 연산자들을 먼저 처리한다.

tmp_s = s[:]
tmp_oper = sorted(tmp, reverse=True)
for idx in tmp_oper:
    a = int(tmp_s.pop(idx - 1))
    oper = tmp_s.pop(idx - 1)
    b = int(tmp_s.pop(idx - 1))

    if oper == '+':
        tmp_s.insert(idx - 1, a + b)
    elif oper == '-':
        tmp_s.insert(idx - 1, a - b)
    elif oper == '*':
        tmp_s.insert(idx - 1, a * b)

이후, 남아있는 곱셈 연산을 처리한다.

times_idx = [idx for idx in range(len(tmp_s)) if tmp_s[idx] == '*']
tmp_times = sorted(times_idx, reverse=True)
for idx in tmp_times:
    a = int(tmp_s.pop(idx - 1))
    oper = tmp_s.pop(idx - 1)
    b = int(tmp_s.pop(idx - 1))

    if oper == '+':
        tmp_s.insert(idx - 1, a + b)
    elif oper == '-':
        tmp_s.insert(idx - 1, a - b)
    elif oper == '*':
        tmp_s.insert(idx - 1, a * b)

마지막으로 남아있는 덧셈과 뺄셈을 순서대로 계산한 뒤, 정답을 갱신해주면 된다.

while len(tmp_s) >= 3:
    a = int(tmp_s.pop(0))
    oper = tmp_s.pop(0)
    b = int(tmp_s.pop(0))

    if oper == '+':
        tmp_s.insert(0, a + b)
    elif oper == '-':
        tmp_s.insert(0, a - b)
    elif oper == '*':
        tmp_s.insert(0, a * b)

ans = max(ans, int(tmp_s[-1]))

📒 코드

def dfs(start, depth):
    if depth == cnt:
        calculate()
        return
    
    for i in range(start, len(oper_idx)):
        tmp.append(oper_idx[i])
        dfs(i + 2, depth + 1)
        tmp.pop()
    
def calculate():
    global ans

    tmp_s = s[:]
    tmp_oper = sorted(tmp, reverse=True)
    for idx in tmp_oper:
        a = int(tmp_s.pop(idx - 1))
        oper = tmp_s.pop(idx - 1)
        b = int(tmp_s.pop(idx - 1))

        if oper == '+':
            tmp_s.insert(idx - 1, a + b)
        elif oper == '-':
            tmp_s.insert(idx - 1, a - b)
        elif oper == '*':
            tmp_s.insert(idx - 1, a * b)

    times_idx = [idx for idx in range(len(tmp_s)) if tmp_s[idx] == '*']
    tmp_times = sorted(times_idx, reverse=True)
    for idx in tmp_times:
        a = int(tmp_s.pop(idx - 1))
        oper = tmp_s.pop(idx - 1)
        b = int(tmp_s.pop(idx - 1))

        if oper == '+':
            tmp_s.insert(idx - 1, a + b)
        elif oper == '-':
            tmp_s.insert(idx - 1, a - b)
        elif oper == '*':
            tmp_s.insert(idx - 1, a * b)

    while len(tmp_s) >= 3:
        a = int(tmp_s.pop(0))
        oper = tmp_s.pop(0)
        b = int(tmp_s.pop(0))

        if oper == '+':
            tmp_s.insert(0, a + b)
        elif oper == '-':
            tmp_s.insert(0, a - b)
        elif oper == '*':
            tmp_s.insert(0, a * b)

    ans = max(ans, int(tmp_s[-1]))

n = int(input())
s = list(input())

oper_idx = [i for i in range(1, n - 1, 2)]
times_idx = [idx for idx in oper_idx if s[idx] == '*']

ans = -1e10
tmp = []
for i in range(len(oper_idx)//2 + 1):
    cnt = i
    dfs(0, 0)
print(ans)

💭 후기

16637번 괄호 추가하기 문제와 거의 똑같은 문제라서 쉽게 해결할 수 있었다.


🔗 문제 출처

https://www.acmicpc.net/problem/16638


profile
精進 "정성을 기울여 노력하고 매진한다"

0개의 댓글