[BOJ] 14888 - 연산자 끼워넣기

gmelan·2022년 2월 5일
0

알고리즘 트레이닝

목록 보기
3/14

풀어보기

접근

  1. 연산자 우선 순위는 무시하고, 첫 번째 요소부터 연산을 진행한다.
  2. 음수를 양수로 나눌 경우에는 양수로 바꾼 뒤 연산을 진행한 후 그 몫을 음수로 바꾼다.
  3. 연산 결과값 result 및 중간값 part_result는 모두
    (-(10^6) <= result, part_result <= 10^6) 범위 내에 존재한다.

이 문제에서 중간 연산 결과만으로 최종 연산 결과의 최대, 최소 여부를 판단할 수 없기 때문에, 별도의 가지치기는 불가능하며 따라서 모든 경우에 대하여 탐색을 진행한다.

코드 1 - 날 것 그대로의 코드

아이디어대로 구현해 본 코드이다. 계산을 하는 함수를 별도로 두어 연산 작업을 조금 더 간편하게 하고자 하였으나, 후술하듯이 오히려 구조만 복잡하게 만드는 꼴이 되어버렸다.

from sys import stdin

LIMIT = 10_0000_0000
min_ = LIMIT
max_ = -1 * LIMIT

N = int(stdin.readline())
A = tuple(int(i) for i in stdin.readline().strip().split())
operators = [int(i) for i in stdin.readline().strip().split()]
ERROR = -(LIMIT + 1)

def operate(operator_idx, a, b):
    if operator_idx == 0:
        return a + b
    elif operator_idx == 1:
        return a - b
    elif operator_idx == 2:
        return a * b
    else:
        if b == 0:
            return ERROR
        return -1 * ((-1 * a) // b) if a < 0 else a // b

def search(depth, part_result):
    global min_, max_, operators, ERROR, N, A

    if depth == N - 1:
        if part_result > max_:
            max_ = part_result
        if part_result < min_:
            min_ = part_result
        return

    for i in range(len(operators)):
        if operators[i] > 0:
            operators[i] -= 1
            res = operate(i, part_result, A[depth + 1])
            if res != ERROR:
                search(depth + 1, res)
            operators[i] += 1

search(0, A[0])
print(str(max_) + '\n' + str(min_))

코드 2 - 조금 더 개선된 코드

search 함수 내 반복문의 반복 횟수가 항상 4이고, 이를 매개변수화하여 불필요한 연산(operators 리스트의 각 요소 연산)을 제거할 수 있다는 사실을 발견하였다. 또한 별도로 연산만을 담당하는 함수 operate를 제거하면 별도의 처리(0으로 나누는 경우 오류코드 리턴 등)를 하지 않고도 그냥 조건문과 재귀를 통해 답을 구할 수 있었다.

from sys import stdin

LIMIT = 10_0000_0000
N = int(stdin.readline())
A = tuple(int(i) for i in stdin.readline().strip().split())
OP = tuple(int(i) for i in stdin.readline().strip().split())

min_ = LIMIT
max_ = -1 * LIMIT

def search(depth, part_result, add, sub, mul, div):
    global min_, max_, OP, N, A

    if depth == N - 1:
        if part_result > max_:
            max_ = part_result
        if part_result < min_:
            min_ = part_result
        return

    if add > 0:
        search(depth + 1, part_result + A[depth + 1], add - 1, sub, mul, div)
    
    if sub > 0:
        search(depth + 1, part_result - A[depth + 1], add, sub - 1, mul, div)

    if mul > 0:
        search(depth + 1, part_result * A[depth + 1], add, sub, mul - 1, div)

    if div > 0 and A[depth + 1] != 0:
        search(
            depth + 1,
            part_result // A[depth + 1] if part_result > 0 else - ((- part_result) // A[depth + 1]),
            add,
            sub,
            mul,
            div - 1
            )

search(0, A[0], OP[0], OP[1], OP[2], OP[3])
print(str(max_) + '\n' + str(min_))

코드 1보다 훨씬 간결해졌을뿐더러 처리 시간도 약간 단축되는 효과가 있었다. (124ms -> 88ms)

백트래킹 문제이지만 가지치기 과정이 없어 생각할 것이 그리 많지는 않던 문제였다. 나와 비슷한 생각을 하였다면 조금 더 어려운 문제를 풀어보아도 될 듯 하다.

0개의 댓글