문제

https://www.acmicpc.net/problem/2042

풀이

구간 합을 구해야 하는 문제이다. 그런데 O(logN)O(logN)으로 구해야하는 문제이다!! 일반적으로 구간 합을 구할 때 가장 먼저 떠올리는게 누적 합(Prefix Sum)일 텐데 이는 누적 합을 계산 할 때는 O(N)O(N)의 시간 복잡도를 가지고 활용을 할 때는 O(1)O(1)이어서 한 번 계산을 끝내놓고 여러번 재활용을 할 때 유용한 방법이다.

하지만 이 문제는 1M10,0001 \leq M \leq 10,000 만큼 수정이 일어날 수 있으므로 누적 합을 사용한다면 수정이 일어날때마다 누적합을 갱신해줘야하고 그러면 O(NM)O(NM)의 시간복잡도를 가지게 될 것이다. 그런데 N이 1N1,000,0001 \leq N \leq 1,000,000이므로 누적 합으로는 풀지 못하는 문제이다.

배열에 수정이 일어나는 경우에도 O(logN)O(logN)으로 구간 합을 구할 수 있는 알고리즘으로는 세그먼트 트리(Segment Tree)가 있다. 세그먼트 트리에 대한 개념과 구현 코드에 대해서는 다음 링크에 자세하게 설명이 되어 있다. 이 글로 내용이 부족한 것 같다면 다음 링크를 참고하길 바란다!

세그먼트 트리란 무엇인가?


해당 문제는 아무런 응용이 없이 기본 세그먼트 트리 코드로 풀 수 있다. 그리고 구간 합을 구하기 때문에 수정 함수 역시 `Top-Down` 방식으로 구현하면 된다.
arr = [int(input()) for _ in range(N)]
tree = [0] * (N * 4)

우선 배열을 입력을 받고 N * 4의 크기로 트리를 할당해준다. 트리의 크기를 배열의 4배만큼 할당하는 이유는 세그먼트 트리를 구성하는데는 배열의 크기보다 큰 가장 작은 제곱수의 2배만큼만 필요한데, 제곱수를 구하는 과정이 귀찮기도 하고 그냥 4배로 해도 얼추 그 근사치에 낭비되는 공간이 크지 않기에 그냥 4배로 설정을 하면 된다! 자세한 이유는 위 세그먼트 트리 링크에서 확인 할 수 있다!

def init(start, end, idx):
    if start == end:
        tree[idx] = arr[start]
        return tree[idx]
    mid = (start + end) // 2
    tree[idx] = init(start, mid, idx * 2) + init(mid + 1, end, idx * 2 + 1)
    return tree[idx]

세그먼트 트리를 초기화 하는 함수이다. 리프 노드를 만날 때 까지(start == end) 중간 값(mid)를 구해가며 왼쪽 구간, 오른쪽 구간으로 나눠가며 분할 정복으로 트리를 초기화 해준다! 이 때 부모의 자식 노드는 idx * 2의 인덱스를 오른쪽 자식 노드는 idx * 2 + 1의 인덱스를 가지는 규칙을 통해 각 노드의 자식 노드에 쉽게 접근 할 수 있게 함과 동시에 인덱스가 겹치는 경우를 피해준다!

def find(start, end, idx, left, right):
    if left > end or right < start:
        return 0
    if left <= start and right >= end:
        return tree[idx]
    mid = (start + end) // 2
    return find(start, mid, idx * 2, left, right) + find(mid + 1, end, idx * 2 + 1, left, right)

구간 합을 구하는 함수이다. startend는 배열의 시작, 끝 인덱스이고 leftright는 구하려는 구간의 시작, 끝 인덱스이다. leftend보다 크거나 rightstart보다 작다는 것은 배열의 구간이 구하려는 구간의 범위를 벗어났다는 뜻이고 이 경우에는 구간 합에 포함하면 안되기에 0을 리턴한다.

반대로 leftstart 이하이면서 동시에 rightend 이상이라는 것은 배열의 구간의 구하려는 구간에 포함되었다는 뜻이므로 해당 구간의 구간 합 값을 리턴해준다.

마찬 가지로 왼쪽 구간, 오른쪽 구간 나눠가며 분할 정복으로 구간 합을 구해준다!

def update(start, end, idx, target, diff):
    if target < start or target > end:
        return
    tree[idx] += diff
    if start == end:
        return
    mid = (start + end) // 2
    update(start, mid, idx * 2, target, diff)
    update(mid + 1, end, idx * 2 + 1, target, diff)

target은 수정한 배열의 인덱스, diff수정하려는 값 - 기존 값이다. 구간 합의 경우는 그 차이만큼을 수정한 배열의 인덱스가 포함되는 모든 구간의 노드마다 더해주면 된다. 그 차이가 양수든 음수든 상관 없으니 말이다!

targetstart보다 작거나 targetend보다 크다는 것은 수정한 배열의 인덱스가 해당 구간에 포함이 되지 않는다는 뜻이므로 바로 리턴을 해준다. 그게 아니라면 모두 diff 만큼을 해당 노드에 더해준다. start == end라는 것은 리프 노드에 닿았다는 뜻이므로 바로 리턴을 해주고 아니라면 역시나 분할 정복으로 트리를 탐색하면 된다.

init(0, N - 1, 1)
for _ in range(M + K):
    a, b, c = map(int, input().split())
    if a == 1:
        tmp = c - arr[b - 1]
        arr[b - 1] = c
        update(0, N - 1, 1, b - 1, tmp)
    else:
        print(find(0, N - 1, 1, b - 1, c - 1))

init()을 호출해 트리를 초기화를 해주고 a가 1인 경우에는 값의 차이를 임시 변수에 담아놓았다가 배열의 값을 먼저 수정해주고 그 차이를 update()에게 넘겨준다. a가 2인 경우에는 find()를 통해 구한 구간합을 바로 출력해주면 된다!

💡 idx가 1부터 시작하는 이유?

앞서 왼쪽 자식 노드는 idx * 2, 오른쪽 자식 노드는 idx * 2 + 1의 규칙을 따른다고 했다. 만약에 트리의 인덱스가 0부터 시작을 한다면 왼쪽 자식 노드는 영원히 인덱스가 0일 것이다. 그렇기 때문에 트리의 인덱스를 1부터 시작하는 것이다!

전체 코드

import sys

input = sys.stdin.readline


def init(start, end, idx):
    if start == end:
        tree[idx] = arr[start]
        return tree[idx]
    mid = (start + end) // 2
    tree[idx] = init(start, mid, idx * 2) + init(mid + 1, end, idx * 2 + 1)
    return tree[idx]


def find(start, end, idx, left, right):
    if left > end or right < start:
        return 0
    if left <= start and right >= end:
        return tree[idx]
    mid = (start + end) // 2
    return find(start, mid, idx * 2, left, right) + find(mid + 1, end, idx * 2 + 1, left, right)


def update(start, end, idx, target, diff):
    if target < start or target > end:
        return
    tree[idx] += diff
    if start == end:
        return
    mid = (start + end) // 2
    update(start, mid, idx * 2, target, diff)
    update(mid + 1, end, idx * 2 + 1, target, diff)


N, M, K = map(int, input().split())
arr = [int(input()) for _ in range(N)]
tree = [0] * (N * 4)

init(0, N - 1, 1)
for _ in range(M + K):
    a, b, c = map(int, input().split())
    if a == 1:
        tmp = c - arr[b - 1]
        arr[b - 1] = c
        update(0, N - 1, 1, b - 1, tmp)
    else:
        print(find(0, N - 1, 1, b - 1, c - 1))
profile
응애 개발자입니다.

0개의 댓글