[백준 14888] 연산자 끼워넣기

델리만쥬 디퓨저·2025년 3월 21일
0

알고리즘

목록 보기
18/18


https://www.acmicpc.net/problem/14888

분석

  • 시간 제한이 2초이므로, O(N^2)까지는 넉넉할 것으로 보인다.
  • 첫째 줄에 숫자의 입력할 개수가 입력되고, 둘째 줄에 숫자들이 입력된다
  • 셋째 줄에 덧셈, 뺄셈, 곱셈, 나눗셈의 갯수가 각각 입력된다. 이 때, 총 합은 N-1이 된다.

설계

  • 파이썬의 permutations를 사용해 셋째 줄의 모든 경우의 수를 구한다.
  • 둘째 줄에 입력받은 수에 대하여 모든 경우의 수를 계산한다.
  • 최솟값과 최댓값을 구한다.

구조는 다음과 같이 될 것이다.

  1. 입력 값 처리
  2. itertools.permutations로 연산자의 모든 경우의 수를 구하는 배열 생성
  3. 만들어진 배열을 통해 계산한 결과를 배열에 저장
  4. 최솟값과 최댓값을 출력

import itertools  
  
  
def permutations(input, n):  
    result = []  
    for i in range(4):  
        if input[i] > 0:  
            for j in range(input[i]):  
                result.append(i)  
    return list(itertools.permutations(result, n - 1))  
  
  
n = int(input())  
arr = list(map(int, input().split()))  
op_input = list(map(int, input().split()))  
  
MIN = 1000000000  
MAX = -1000000000  
  
op_permutation_list = permutations(op_input, n)  
for op_list in op_permutation_list:  
    num = arr[0]  
    for i in range(1, n):  
        if op_list[i - 1] == 0:  
            num += arr[i]  
        elif op_list[i - 1] == 1:  
            num -= arr[i]  
        elif op_list[i - 1] == 2:  
            num *= arr[i]  
        else:  
            num = int(num / arr[i])  
    if num > MAX:  
        MAX = num  
    if num < MIN:  
        MIN = num  
  
print(MAX)  
print(MIN)

PyPy3에서는 통과하지만 Python3에서는 시간 초과인 것을 확인할 수 있다.


해당 코드의 문제

  • 순열 생성시 동일한 연산 기호가 경우 중복되는 순열이 만들어짐
  • 모든 순열을 리스트에 저장하므로 메모리 측면에서 부담이 증가
  • 순열 생성 및 중복된 순열에 대한 반복문 처리로 오버헤드 증가

DFS

이제 DFS를 사용하여 코드를 개선한다. 목표는 다음과 같다.

  • depth를 N으로 고정
  • 연산자의 남은 개수를 기준으로 계산하여 중복 제거
  • 불필요한 순열 생성 없이 바로 계산을 진행해 오버헤드 최소화

def find_min_max(depth, total, plus, minus, multiply, divide):  
    global MAX, MIN  
    if depth == n:  
        MAX = max(total, MAX)  
        MIN = min(total, MIN)  
        return  
  
    if plus:  
        find_min_max(depth + 1, total + nums[depth], plus - 1, minus, multiply, divide)  
    if minus:  
        find_min_max(depth + 1, total - nums[depth], plus, minus - 1, multiply, divide)  
    if multiply:  
        find_min_max(depth + 1, total * nums[depth], plus, minus, multiply - 1, divide)  
    if divide:  
        find_min_max(depth + 1, int(total / nums[depth]), plus, minus, multiply, divide - 1)  
  
  
n = int(input())  
nums = list(map(int, input().split()))  
op_input = list(map(int, input().split()))  
  
MAX = -1e9  
MIN = 1e9  
  
find_min_max(1, nums[0], op_input[0], op_input[1], op_input[2], op_input[3])  
print(MAX)  
print(MIN)

결과

메모리와 실행 시간이 크게 단축된 것을 확인할 수 있다.

profile
< 너만의 듀얼을 해!!! )

0개의 댓글

관련 채용 정보