카라츠바의 빠른 곱셈 알고리즘

지식 저장소·2021년 11월 25일
0

문제해결기법

목록 보기
9/21

카라츠바의 빠른 곱셈 알고리즘은 분할 정복 알고리즘의 한 예입니다. 정수형 타입으로 나타낼 수 없을 정도로 큰 정수를 곱하는 알고리즘입니다. 정수형 타입으로 나타낼 수 없기 때문에 정수형 배열을 이용해 곱셈을 합니다. 기본적인 곱셈 알고리즘은 산수 시간에 배운 방법을 그대로 사용하는 것입니다. 정수형 배열로 정수를 표현할 때는 작은 자릿수를 배열의 가장 처음 위치에 저장합니다. 이렇게 하는 이유는 A[i]A[i]에 주어진 자릿수의 크기를 10i10^i로 쉽게 구할 수 있습니다. 따라서 A[i]A[i]B[i]B[i]를 곱한 결과를 C[i+j]C[i+j]에 저장하는 등, 훨씬 직관적인 코드를 작성할 수 있습니다.

두 큰 수를 곱하는 O(n2)O(n^2) 시간 알고리즘

// num[]의 자릿수 올림을 처리한다.
public static void normalize(ArrayList<Integer> num) {
    num.add(0);
    // 자릿수 올림을 처리한다.
    for (int i = 0; i + 1 < num.size(); i++) {
        if (num.get(i) < 0) {
            int borrow = (Math.abs(num.get(i)) + 9) / 10;
            num.set(i + 1, num.get(i + 1) - borrow);
            num.set(i, num.get(i) + borrow * 10);
        } else {
            num.set(i + 1, num.get(i + 1) + num.get(i) / 10);
            num.set(i, num.get(i) % 10);
        }
    }
    while(num.size() > 1 && num.get(num.size() - 1) == 0) num.remove(num.size() - 1);
}
// 두 긴 자연수의 곱을 반환한다.
// 각 배열에는 각 수의 자릿수가 1의 자리에서부터 시작해 저장되어 있다.
// 예: multiply({3, 2, 1}, {6, 5, 4}) = 123 * 456 = 56088 = {8, 8, 0, 6, 5}
public static ArrayList<Integer> multiply(final ArrayList<Integer> a, final ArrayList<Integer> b) {
    ArrayList<Integer> c = new ArrayList<>();
    for (int i = 0; i < a.size() + b.size() + 1; i++) {
        c.add(0);
    }
    for (int i = 0; i < a.size(); i++) {
        for (int j = 0; j < b.size(); j++) {
            c.set(i + j, c.get(i + j) + a.get(i) * b.get(j));
        }
    }
    normalize(c);
    return c;
}

이 알고리즘의 시간 복잡도는 두 정수의 길이가 모두 nn이라고 할 때 O(n2)O(n^2)입니다. nn번 실행되는 for문이 두 번 겹쳐 있기 때문에 이 점은 자명합니다.
참고로 이 nomalize 함수에서 음수 처리는 왜 하는지 궁금할 수 있습니다. 그 이유는 카라츠바의 빠른 곱셈 알고리즘을 구현할 때 배열끼리 뺄셈을 하는 경우가 있는데 그 때 쓰기 위해 음수까지 처리했습니다.

카라츠바의 빠른 곱셈 알고리즘

카라츠바의 빠른 곱셈 알고리즘은 두 수를 각각 절반으로 쪼갭니다. aabb가 각각 256자리 수라면 a1a_1b1b_1은 첫 128자리, a0a_0b0b_0은 그 다음 128자리를 저장하도록 하는 것이죠. 그러면 aabb를 다음과 같이 쓸 수 있습니다.
a=a1×10128+a0a=a_1\times 10^{128}+a_0
b=b1×10128+b0b=b_1\times 10^{128}+b_0
카라츠바는 이 때 a×ba\times b를 네 개의 조각을 이용해 표현하는 방법을 살펴보았습니다. 예를 들면 다음과 같이 표현할 수 있지요.
a×b=(a1×10128+a0)×(b1×10128+b0)=a1×b1×10256+(a1×b0+a0×b1)×10128+a0×b0a\times b=(a_1\times 10^{128}+a_0)\times (b_1\times 10^{128}+b_0)=a_1\times b_1\times 10^{256}+(a_1\times b_0+a_0\times b_1)\times 10^{128}+a_0\times b_0
이 방법에서 우리는 큰 정수 두 개를 한 번 곱하는 대신, 절반 크기로 나눈 작은 조각을 네 번 곱합니다(10의 거듭제곱과 곱하는 것은 그냥 뒤에 0을 붙이는 시프트 연산으로 구현하면 되니 곱셈으로 치지 않습니다). 이 방법의 시간 복잡도는 O(n2)O(n^2)입니다. 왜냐하면 길이 nn인 두 정수를 곱하는 데 드는 시간은 덧셈과 시프트 연산에 걸리는 시간 O(n)O(n)과, n2n\over2 길이 조각들의 곱셈 네 번으로 나눌 수 있기 때문입니다.(길이 nn인 두 정수를 곱하는 데 드는 시간을 T(n)T(n)이라고 하면 T(n)=O(n)+4T(n2)T(n)=O(n)+4\sdot T({n\over2})이 됩니다). 분할 정복을 구현해도 시간 복잡도는 같은데 왜 분할 정복을 구현하느냐? 카라츠바는 중요한 중요한 사실을 발견했기 때문입니다.
카라츠바가 발견한 것은 다음과 같이 a×ba\times b를 표현했을 때 네 번 대신 세 번의 곱셈으로만 이 값을 계산할 수 있다는 것 입니다.
a×b=a1×b1z2×10256+(a1×b0+a0×b1)z1×10128+a0×b0z0a\times b=\underbrace{a_1\times b_1}_{z_2}\times 10^{256}+\underbrace{(a_1\times b_0+a_0\times b_1)}_{z_1}\times 10^{128}+\underbrace{a_0\times b_0}_{z_0}

조각들의 곱을 각각 위와 같이 z2,z1,z0z_2, z_1, z_0이라고 씁시다. 우선 z0z_0z2z_2를 각각 한 번의 곱셈으로 구합니다. 그리고 다음 식을 이용하지요.
(a0+a1)×(b0+b1)=a0×b0z0+a1×b0+a0×b1z1+a1×b1z2=z0+z1+z2(a_0+a_1)\times(b_0+b_1)=\underbrace{a_0\times b_0}_{z_0}+\underbrace{a_1\times b_0+a_0\times b_1}_{z_1}+\underbrace{a_1\times b_1}_{z_2}=z_0+z_1+z_2

따라서 위 식의 결과에서 z0z_0z2z_2를 빼서 z1z_1을 구할 수 있습니다. 다음과 같은 코드 조각을 이용하면 되지요.

z2 = a1 * b1;
z0 = a0 * b0;
z1 = (a0 + a1) * (b0 + b1) - z0 - z2;

이 과정은 곱셈을 세 번밖에 쓰지 않습니다. 이 세 결과를 적절히 조합해 원래 두 수의 답을 구해낼 수 있습니다.

카라츠바의 빠른 정수 곱셈 알고리즘

// a += b * (10^k);를 구현합니다.
public static void addTo(ArrayList<Integer> a, final ArrayList<Integer> b, int k) {
    int size = Math.max(a.size(), b.size() + k);
    while (a.size() != size) a.add(0);
    for (int i = 0; i < b.size(); i++) {
        a.set(i + k, a.get(i + k) + b.get(i));
    }
}
// a -= b;를 구현합니다. a >= b를 가정합니다.
public static void subFrom(ArrayList<Integer> a, final ArrayList<Integer> b) {
    for (int i = 0; i < b.size(); i++) {
        a.set(i, a.get(i) - b.get(i));
    }
    normalize(a);
}
// 배열의 일부분만 따로 떼어냅니다.
public static ArrayList<Integer> subList(ArrayList<Integer> a, int fromIndex, int toIndex) {
    ArrayList<Integer> list = new ArrayList<>();
    for (int i = fromIndex; i < toIndex; i++) {
        list.add(a.get(i));
    }
    return list;
}
// 두 긴 정수의 곱을 반환합니다.
public static ArrayList<Integer> karatsuba(final ArrayList<Integer> a, final ArrayList<Integer> b) {
    int an = a.size(); int bn = b.size();
    // a가 b보다 짧을 경우 둘을 바꾼다.
    if (an < bn) return karatsuba(b, a);
    // 기저 사례: a나 b가 비어 있는 경우
    if (an == 0 || bn == 0) return new ArrayList<>();
    // 기저 사례: a가 비교적 짧은 경우 O(n^2) 곱셈으로 변경한다.
    if (an <= 50) return multiply(a, b);
    int half = an / 2;
    // a와 b를 밑에서 half 자리와 나머지로 분리한다.
    ArrayList<Integer> a0 = subList(a,0, half);
    ArrayList<Integer> a1 = subList(a, half, a.size());
    ArrayList<Integer> b0 = subList(b, 0, Math.min(b.size(), half));
    ArrayList<Integer> b1 = subList(b, Math.min(b.size(), half), b.size());
    // z2 = a1 * b1
    ArrayList<Integer> z2 = karatsuba(a1, b1);
    // z0 = a0 * b0
    ArrayList<Integer> z0 = karatsuba(a0, b0);
    // a0 = a0 + a1; b0 = b0 + b1
    addTo(a0, a1, 0); addTo(b0, b1, 0);
    // z1 = (a0 * b0) - z0 - z2;
    ArrayList<Integer> z1 = karatsuba(a0, b0);
    subFrom(z1, z0);
    subFrom(z1, z2);
    // result = z0 + z1 * 10^half + z2 * 10^(half*2)
    ArrayList<Integer> result = new ArrayList<>();
    addTo(result, z0, 0);
    addTo(result, z1, half);
    addTo(result, z2, half + half);
    return result;
}

카라츠바 알고리즘은 분할한 부분 문제의 답에서 원래 문제의 답을 병합해내는 부분을 개선함으로써 알고리즘의 성능을 향상시킨 좋은 예입니다.

시간 복잡도 분석

카라츠바 알고리즘은 두 개의 입력을 절반씩으로 쪼갠 뒤, 세 번 재귀 호출을 합니다.
우선 카라츠바 알고리즘의 수행 시간을 병합 단계와 기저 사례의 두 부분으로 나눕시다. 위 코드에서 병합 단계의 수행 시간은 addTo()와 subFrom()의 수행 시간에 지배되고, 기저 사례의 처리 시간은 multiply()의 수행 시간에 지배되는 것을 볼 수 있습니다.
먼저 기저 사례를 처리하는 데 드는 총 시간을 알아봅니다. 여기서는 편의를 위해 한 자리 숫자에 도달해야만 multiply()를 사용한다고 가정합니다. 자릿수 nn이 2의 거듭제곱 2k2^k라고 했을 때 재귀 호출의 깊이는 kk가 됩니다. 한 번 쪼갤 때마다 해야 할 곱셈의 수가 세 배씩 늘어나기 때문에 마지막 단계에는 3k3^k개의 부분 문제가 있는데, 마지막 단계에서는 두 수 모두 한자리니까 곱셈 한 번이면 충분합니다. 따라서 곱셈의 수는 O(3k)O(3^k)가 됩니다. n=2kn=2^k라고 가정했으니 k=lognk=\log n이고, 이때 곱셈의 수를 nn에 대해 표현하면 다음과 같은 식이 됩니다.
O(3k)=(3logn)=O(nlog3)O(3^k)=(3^{\log n})=O(n^{\log 3})
log31.585\log 3\approx1.585이기 때문에 카라츠바 알고리즘이 O(n2)O(n^2)보다 훨씬 적은 곱셈을 필요로 한다는 것을 알 수 있습니다.
다음으로 병합 단계에 드는 시간의 총 합을 구해 봅시다. addTo()와 subFrom()은 숫자의 길이에 비례하는 시간이 걸리므로 각 단계에 해당하는 숫자의 길이를 모두 더하면 병합 단게에 드는 시간을 계산할 수 있습니다. 단계가 내려갈 때마다 숫자의 길이는 절반으로 줄고 부분 문제의 개수는 세 배 늘기 때문에, ii번째 단계에서 필요한 연산 수는 (32)i×n({3\over 2})^i\times n이 됩니다. 따라서 모든 단계에서 필요한 전체 연산의 수는 다음 식에 비례합니다.
n×i=0logn(32)in\times \sum_{\mathclap{i=0}}^{\log n}({3\over 2})^i
이 함수는 nlog3n^{\log 3}과 같은 속도로 증가합니다. 따라서 카라츠바 알고리즘의 시간 복잡도는 곱셈이 지배하며, 최종 시간복잡도는 O(nlog3)O(n^{\log 3})이 됩니다(마스터 정리를 이용하면 알 수 있습니다).
단, 카라츠바 알고리즘의 구현은 복잡하기 때문에, 입력의 크기가 작을 경우 O(n2)O(n^2) 알고리즘보다 느릴 수 있습니다.

참고문헌: 구종만, 프로그래밍 대회에서 배우는 알고리즘 문제해결전략, 인사이트, (2012)

profile
그리디하게 살자.

0개의 댓글