[알고리즘 중급 1/3] 브루트 포스

이미리·2023년 9월 15일
0

boj_Algorithm

목록 보기
23/25

14888번: 연산자 끼워넣기

구현보다 나누기 연산을 조건에 맞게 맞춰주는데 시간이 더 걸렸다.

import itertools

n = int(input())
numbers = list(map(int, input().split()))

arr = []
cc = '+-*/'
n_iters = list(map(int, input().split()))
for i in range(4) :
    num = n_iters[i]
    for _ in range(num) :
        arr.append(cc[i])

nPr = set(itertools.permutations(arr, len(arr)))

minValue = 1e10 + 1
maxValue = -1e10 - 1

for iter in nPr :
    tmp = numbers[0]
    for num_idx in range(1, n) :
        iter_idx = num_idx - 1
        if iter[iter_idx] == '+' :
            tmp += numbers[num_idx]
        elif iter[iter_idx] == '-' :
            tmp -= numbers[num_idx]
        elif iter[iter_idx] == '*' :
            tmp *= numbers[num_idx]
        elif iter[iter_idx] == '/' :
            if tmp < 0 :
                tmp = abs(tmp) // abs(numbers[num_idx])
                tmp *= -1
            else :
                tmp //= numbers[num_idx]
    minValue = min(minValue, tmp)
    maxValue = max(maxValue, tmp)

print(maxValue)
print(minValue)

15658번: 연산자 끼워넣기(2)

시간초과를 3번이나 당한... 일단 시간초과 당한 코드의 시간 복잡도를 계산해보겠다.

[시간복잡도]
DFS 하나당 N번의 loop를 돌게 되므로 O(n)의 시간복잡도를 가진다. 그런데 N개의 정점을 모두 방문 해야하므로 n*O(n)이므로 O(n^2)의 시간복잡도를 가지게 된다.

합이 N-1보다 크거나 같고, 4N보다 작거나 같은 4개의 정수가 연산자 개수의 범위이니, 최대 16n^2... 정도라고 생각하면 된다.

n = int(input())
numbers = list(map(int, input().split()))
cc = '+-*/'
arr = []
n_iters = list(map(int, input().split()))
for i in range(4) :
    num = n_iters[i]
    for _ in range(num) :
        arr.append(cc[i])

visited = [0] * len(arr)

minValue = 1e10 + 1
maxValue = -1e10 - 1

tmp_result = numbers[0]

def permutation(depth) :
    global tmp_result, minValue, maxValue
    if depth == (n - 1):
        minValue = min(minValue, tmp_result)
        maxValue = max(maxValue, tmp_result)
    else :
        for i in range(0, len(arr)) :
            if visited[i] == 0 :
                visited[i] = 1
                save_tmp = tmp_result
                if arr[i] == '+' :
                    tmp_result += numbers[depth + 1]
                elif arr[i] == '-' :
                    tmp_result -= numbers[depth + 1]
                elif arr[i] == '*' :
                    tmp_result *= numbers[depth + 1]
                elif arr[i] == '/' :
                    if tmp_result < 0 :
                        tmp_result = abs(tmp_result) // abs(numbers[depth + 1])
                        tmp_result *= -1
                    else :
                        tmp_result //= numbers[depth + 1]
                permutation(depth + 1)
                visited[i] = 0
                tmp_result = save_tmp

permutation(0)

print(maxValue)
print(minValue)

아래는 시간초과가 뜨지 않은 코드이다.
따로 배열을 만들어 연산자를 넣어주는게 아니고, 그냥 숫자만을 사용해서 계산해주었다. 그렇게 되면 계산해야 되는 시간복잡도를 쓸데없이 늘리기 않을 수 있다. (연산자 개수의 합 -> 각각의 개수로...)

n = int(input())
numbers = list(map(int, input().split()))
cc = '+-*/'
arr = []
add, minus, mul, div = list(map(int, input().split()))

minValue = 1e10 + 1
maxValue = -1e10 - 1

tmp_result = numbers[0]

def permutation(result, depth, add, minus, mul, div) :
    global maxValue, minValue
    if depth == n - 1 :
        minValue = min(minValue, result)
        maxValue = max(maxValue, result)
        return
    
    if add > 0 :
        permutation(result + numbers[depth + 1], depth + 1, add - 1, minus, mul, div)
    if minus > 0 :
        permutation(result - numbers[depth + 1], depth + 1, add, minus - 1, mul, div)
    if mul > 0 :
        permutation(result * numbers[depth + 1], depth + 1, add, minus, mul - 1, div)
    if div > 0 :
        tmp = result
        if tmp < 0 :
            tmp = abs(tmp) // abs(numbers[depth + 1])
            tmp *= -1
        else :
            tmp //= numbers[depth + 1]
        permutation(tmp, depth + 1, add, minus, mul, div - 1)

permutation(numbers[0], 0, add, minus, mul, div)
print(maxValue)
print(minValue)

2580번: 스도쿠

참고여부: O!

풀면서도.. 안될 것 같다는 느낌이 들었다. 시간초과 때문에..~
당연히 시간초과가 떴고, 코드를 고쳐야만 했다. 밑의 코드는 while문이 스도쿠가 완성될 때까지 돌아가기 때문이다.

반면, 백트랙킹을 사용한 코드를 보니 0의 위치를 따로 넣어두고, 그 0에 대해서만 처리를 해주는 방법이었다.
그런데........!!!!!!!!! 코드를 참고해도 계속 틀리는 이유는 무엇?
대체 뭐가 틀렸나 하고 코드를 복붙해 넣어봤는데 이게 웬걸..
시간초과가 뜬다. 알고보니 pypy3로 해야 더 빠른..

파이썬에 대한 기준이 너무 엄격하다!!!

arr = []
for _ in range(9) :
    arr.append(list(map(int, input().split())))

def make_no_zero() :
    global arr
    for i in range(9) :
        for j in range(9) :
            if arr[i][j] == 0:
                if (arr[i].count(0) == 1) :
                    arr[i][j] = 45 - sum(arr[i])
                    i += 1
                    continue;

                col_list = []
                for t in range(9) :
                    col_list.append(arr[t][j])
                if (col_list.count(0) == 1) :
                    arr[i][j] = 45 - sum(col_list)
                    continue;

                nine_list = []
                for n in range((i // 3) * 3, (i // 3) * 3 + 3) :
                    for m in range((j // 3) * 3, (j // 3) * 3 + 3) :
                        nine_list.append(arr[n][m])
                if (nine_list.count(0) == 1) :
                    arr[i][j] = 45 - sum(nine_list)
                    continue;


is_zero = True
while is_zero :
    is_zero = False
    for row in arr :
        if (row.count(0) > 0) :
            is_zero = True
            break
    if is_zero :
        make_no_zero()

for row in arr :
    print(*row)


# 0 3 5 4 6 9 2 7 8
# 7 0 0 1 0 5 6 0 9
# 0 6 0 2 7 8 1 3 5
# 3 2 1 0 4 6 8 9 7
# 8 0 4 9 1 0 5 0 6
# 5 9 6 8 0 0 4 1 3
# 9 1 7 6 5 2 0 8 0
# 6 0 3 7 0 1 9 0 0
# 2 5 8 3 9 4 7 6 0

1987번: 알파벳

dict, set의 in 연산: O(1)
list의 in 연산: O(n)

연산의 차이 때문에

하하! 엄청 헤맸다 ㅠㅠ
못 잊을 것 같다..

r, c = list(map(int, input().split()))
arr = []
for _ in range(r) :
    arr.append(list(input()))

road = set()

road.add(arr[0][0])

# up, down, right, left
rr = [-1, 1, 0, 0]
cc = [0, 0, 1, -1]
maxValue = 0
def bfs(row, col, cnt) :
    global maxValue
    maxValue = max(cnt, maxValue)
    for i in range(4) :
        newr = rr[i] + row
        newc = cc[i] + col
        if (newr >= 0 and newr < r and newc >= 0 and newc < c and arr[newr][newc] not in road) :
            road.add(arr[newr][newc])
            bfs(newr, newc, cnt + 1)
            road.remove(arr[newr][newc])
            
bfs(0, 0, 1)

print(maxValue)

0개의 댓글

관련 채용 정보