분할 정복을 이용한 거듭제곱

Noah·2024년 7월 31일

알고리즘

목록 보기
3/20

분할 정복을 이용한 거듭제곱이란?

기본적으로 거듭제곱을 한다고 하면 다음과 같이

# 2^n 승을 구한다고 치면
res = 1
for i in range(n):
	res *= 2

식으로 구하기 때문에 시간 복잡도는 O(n)O(n)입니다.
하지만 만약 100만제곱이라면 당연히 시간초과가 나게 됩니다. 분할 정복은 작은 부분으로 쪼개서 생각하는 알고리즘인데, 이를 거듭제곱에 적용해서 시간 복잡도를 O(logn)O(\log n)으로 줄일 수 있습니다.

a17a^{17}을 계산한다고 해봅시다. 지수 법칙에 따라
a17=(a8)2aa^{17} = (a^8)^2*a이고, 이는
a17=((a4)2)2a=(((a2)2)2)2a=(((aa)2)2)2a^{17} = ((a^4)^2)^2*a = (((a^2)^2)^2)^2*a = (((a*a)^2)^2)^2 입니다.

따라서 연산 횟수는 5번이고, 이는 ceil(logn)ceil(\log n)입니다.
이때 우리는 지수가 짝수인 경우, 그대로 출력하고, 홀수인 경우 2로 나눈 값에 한번 더 곱해주는 알고리즘을 생각 할 수 있습니다.

이때 모듈러 연산을 하면 어떻게 될까요?
예를 들어 27mod72^{7} \mod7를 계산한다고 해봅시다. 간단하게 직접 계산하면 1282(mod7)128 \equiv 2\pmod 7 입니다. 이걸 2의 승수를 계산하면서 계산한다면

22(mod7)44(mod7)81(mod7)162(mod7)324(mod7)641(mod7)1282(mod7)2 \equiv 2\pmod 7\\4 \equiv 4\pmod 7\\8 \equiv 1\pmod 7\\16 \equiv 2\pmod 7 \\32 \equiv 4\pmod 7\\64 \equiv 1\pmod 7\\128 \equiv 2\pmod 7

이때 나머지끼리 계속 2를 곱해주고, mod7\mod7을 해주고 있지만 최종 결과는 앞에서의 결과와 같습니다. 즉, 나머지를 구하고 싶다면 나머지끼리만 계산해도 된다는 것입니다.

코드(BOJ 1629)

# 재귀함수를 이용한 코드
import sys 

def power(base, exponent, mod):
	if exponent == 0:
    	return 1
    elif exponent % 2 == 1:
    	temp = power(base, exponent // 2, mod)
        return (temp*temp*base) % mod
    else:
    	temp = power(base, exponent // 2, mod)
        return (temp*temp) % mod
 
 input = sys.stdin.readline
 base, exponent, mod = map(int, input().split())
 print(power(base, exponent, mod))
# 반복문을 이용한 코드
import sys 

def power(base, exponent, mod):
	res = 1
   	base %= mod
    while exponent > 0:
        if exponent % 2 == 1:
            res = res * base % mod
        exponent = exponent // 2
        base = (base * base) % mod # 가장 가까운 2의 승수와 홀수부분을 곱하는 코드
 
 input = sys.stdin.readline
 base, exponent, mod = map(int, input().split())
 print(power(base, exponent, mod))

행렬에서의 활용(BOJ 10830)

행렬 제곱은 정사각형 행렬에서만 할 수 있는데, 이를 위한 함수를 따로 만들어줘야 합니다. 또, 행렬은 ×1\times1 연산이 안되므로 exponent=1exponent = 1 일때 basebase를 리턴해줘야합니다. 정사각형 행렬이기 때문에 시간복잡도가 O(n2)O(n^2)이지만, 전에 만들었던 코드를 활용해서 O(n3)O(n^3)으로 풀어봤습니다.

def m_m(arr, tp): # 행렬의 곱
    res = []
    for i in range(len(arr)):
        temp = []
        for j in range(len(arr)):
            sum = 0
            for k in range(len(arr[i])):
                sum += tp[i][k] * arr[k][j]
            temp.append(sum%1000)
        res.append(temp)
    return res

def power(base, exponent):
    if exponent == 1:
        return base
    elif (exponent % 2) == 1:
        temp = power(base, (exponent - 1) // 2)
        return m_m(m_m(temp, temp), base) # 홀수 제곱일 경우 
    else:
        temp = power(base, exponent // 2)
        return m_m(temp, temp) 

if __name__ == "__main__":
    arr = []
    n, m = map(int, input().split())
    for i in range(n):
        temp = list(map(int, input().split()))
        arr.append(temp)
    res = power(arr, m)
    for i in res:
        for j in i:
            print(j%1000, end=" ")
        print()

행렬 제곱으로 피보나치 수를 표현하는 법(BOJ 2749)

[1110]\begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\\
위와 같은 행렬이 있다고 하면 행렬의 곱의 특성에 의해 오른쪽 위 원소가 위 두 원소의 합이 됩니다. 그렇기 때문에 이것의 거듭제곱을 통해 피보나치 수를 도출할 수 있습니다.

이런 방법으로 원래 피보나치 수를 구하는 다이나믹 프로그래밍은 시간 복잡도가 O(n)O(n)이지만, 이 방법은 분할 정복을 사용한 거듭제곱을 사용하기 때문에 시간 복잡도가 O(logn)O(\log n)으로 획기적으로 줄어들게 됩니다.

import sys

def m_m(arr, tp):
    res = []
    for i in range(len(arr)):
        temp = []
        for j in range(len(arr)):
            sum = 0
            for k in range(len(arr[i])):
                sum += tp[i][k] * arr[k][j]
            temp.append(sum%1000000)
        res.append(temp)
    return res

def power(base, exponent):
    if exponent == 1:
        return base
    elif (exponent % 2) == 1:
        temp = power(base, (exponent - 1) // 2)
        return m_m(m_m(temp, temp), base)
    else:
        temp = power(base, exponent // 2)
        return m_m(temp, temp)

if __name__ == "__main__":
    arr = [[1, 1], [1, 0]]
    m = int(sys.stdin.readline())
    res = power(arr, m)
    print(res[0][1]%1000000)
profile
부산소프트웨어마이스터고 4기 | 자세한 내용은 홈페이지(노션)의 테크 블로그에서 확인할 수 있습니다.

0개의 댓글