[백준] 16637: 괄호 추가하기 (Python)

박성욱·2023년 3월 3일
0

알고리즘

목록 보기
2/13

문제

https://www.acmicpc.net/problem/16637

접근 방법

  1. 중첩되는 괄호가 없으므로, 각 연산자마다 괄호 가질수 있는 경우의 수를 구하기 위해 비트 마스킹 사용
  2. 비트마스킹을 통해 괄호가 연속되는 경우의 수 백트래킹
  3. 스택을 통해 수식이 저장된 리스트의 결과 구하기
  4. 수식 계산시에 우선순위 없이 왼쪽 먼저 계산하는 조건을 지켜야 됨

풀이 코드

def cul(a,center,b): # 각 연산자마다 계산해주는 함수
    if center == '+':
        return a + b
    elif center == '-':
        return a - b
    elif center == '*':
        return a * b

def cal(string): # 스택을 통해 수식의 계산순서를 지정
    stack = []
    
    for ch in string:
        # print(stack)
        if ch.isdecimal(): # 숫자가 들어왔을때
            if stack and not stack[-1].isdecimal() and stack[-1] != '(': # 이전에 저장된게 여는괄호도 아니고, 연산자라면
                center = stack.pop()
                a = stack.pop()
                stack.append(cul(a,center,int(ch))) # 계산결과 저장
                
            else:
                stack.append(int(ch)) # 숫자 저장
            
        elif ch == ')': # 괄호가 닫혔을 때

            a = stack.pop()
            stack.pop() # '('
            stack.append(a)

            if len(stack) > 2: # 괄호 안 결과 이후에 앞에 수식이 있을때
                b = stack.pop()
                c = stack.pop()
                a = stack.pop()
                stack.append(cul(a,c,b))

        
        else:
            stack.append(ch)

    if len(stack) > 2:
        b = stack.pop()
        c = stack.pop()
        a = stack.pop()
        stack.append(cul(a,c,b))

    # print(*string, stack,sep=' ')

    return stack[0]


N = int(input()) # 수식의 길이 항상 홀수
M = N // 2 # 연산자 개수
text = list(input())

mx = cal(text) # 초기 최대치는 괄호 없는 상태로 지정

for i in range(1<<M):
    for j in range(M-1):
        if i & (1<<j) and i & (1<<(j+1)): # 괄호 연속 제거
            break
    else: # 연속괄호 없을 때
        temp = text[:]
        
        for j in range(M-1,-1,-1): # 비트마스킹으로 괄호 씌워주기
            if i & (1<<j):
                temp[j*2:j*2+3] = ['('] + text[j*2:j*2+3] + [')']
            else:
                # temp= temp[j*2:j*2+3]
                pass
                    
        if mx < cal(temp):
            mx = cal(temp)
print(mx)

0개의 댓글