[Python] 세그먼트 트리 (Segment Tree)

서녁·2022년 4월 28일
0

잡담

내가 알고리즘 처음 공부했을 시절..(불과 몇 달 안 됐지만..)
백준 알고리즘 분류보고 제일 궁금했던 그 이름 세그먼트 트리..
에 대해 오늘 써볼까한다.

세그먼트 트리(Segment Tree) : 특정 구간(Segment)에 대한 구간 값을 트리에 저장해둔 자료구조.

철저히 문과생 입장에선(트리도 몰랐던 시절)

처음 들었을 땐, 음... 이게 뭐람..?

그런 거였는데.

아무튼,

간단한 예시로 보자

nums = [3, 2, 8, 4, 7, 5, 6]

이런 숫자 배열이 있다.

각 구간의 합을 구하고싶은데

nums[start:end] 의 구간에 대해 합을 n번 구한다고 하면
(start와 end의 값은 계속 바뀌는 걸로 가정하자.)

총 연산이 (end-start)*n번 일어난다.

배열 길이가 크고 구간 길이도 크면 답이 없다.

그래서 어느 정도 구간에 대한 합은 미리 구해서 저장해두고,
'필요할 때 그 값만 찾자' 라는 의미로 세그먼트 트리를 사용하게 된다.

아 세그먼트 트리는 이진트리를 기본으로 한다.

대충 이런 느낌으로 말이다.


트리 구현

세그먼트 트리를 구현하는 방법은 재귀를 사용하는 top-down 방식과
반복을 사용하는 bottom-up 방식이 있다고하는데..

일단은 재귀를 사용하는 방식이 더 쉬워보이니까..

일단 필요한건 트리 배열을 먼저 만들어줘야한다.

필요한 크기는

2**(트리 수준 + 1) 이다.
(트리 수준에 대해서는.. 루트 노드 수준을 0으로 보느냐 1로 보느냐에 차이가 있지만 여기선 0이라고 생각하자..)

트리 수준은 밑을 2로하는 로그로 쉽게 구할 수 있으니까

tree = [0] * 2**(ceil(log(n, 2) + 1))
# n은 nums의 길이라고 하자.

이제 루트 노드부터 아래 노드 두개씩 합치는 방법으로
트리를 구현해주면 된다. (리프노드라면 nums의 값을 넣고..!)

# top-down 방식의 segment tree 구현
# 1차원 배열로 구현하는 경우 트리의 시작 인덱스가 1이기 때문에 i는 무조건 1에서 시작한다고 보는게 편할 듯
def segment(left, right, i=1):
    # i는 구하려는 인덱스, left는 구간 범위 왼쪽, right는 구간 범위 오른쪽
    if left == right:
        # 구간 길이가 1일 때, 트리에 자기 자신을 저장
        tree[i] = nums[left]
        return tree[i]
    # 범위를 절반으로 나누고
    mid = (left + right) // 2
    # segment에 대한 연산을 노드에 저장
    tree[i] = segment(left, mid, i*2) + segment(mid+1, right, i*2+1)
    return tree[i]

주석이 많이 달려서 그렇지
의외로 구현자체는 간단하다.

꽤 직관적이기도 하고..?
절반 나누고 왼쪽 오른쪽 더하고
구간길이 1이면 리프노드니까 배열에서 인덱스 값 넣어주고.


구간합 탐색

# 구간합 반환 함수
def search(start, end, left, right, i=1):
    # start = 구간 범위 왼쪽, end = 구간 범위 오른쪽
    # left = 찾는 범위 왼쪽, 찾는 범위 오른쪽

    # 찾는 범위가 구간을 벗어나면 0을 리턴
    if end < left or start > right:
        return 0

    # 찾는 범위 왼쪽이 구간 왼쪽보다 작거나 같고,
    # 찾는 범위 오른쪽이 구간 오른쪽보다 크거나 같으면 구간값 리턴
    if left <= start and end <= right:
        return tree[i]

    # 모두 아니라면 구간을 절반으로 나누고,
    mid = (start + end) // 2
    return search(start, mid, left, right, i*2) + search(mid+1, end, left, right, i*2+1)

예를 들어, nums의 2번인덱스부터 4번인덱스까지 합을 구한다고 생각해보자.

nums[2:5] = [8, 4, 7]

그러면 8과 4, 7을 모두 더하는게 아닌 12와 7만 찾아서 더하면 된다.

나는 이 search함수가 처음에 이해하기 제일 어려웠는데..
사실 그 이유는 함수 인자로 start, end, left, right.. 너무 헷갈려..
좀 바꿔야겠어..

아무튼 지금은 트리 크기 자체가 그렇게 크지 않아서 조금 비효율적으로 보일지 모르겠으나
트리 크기가 커질수록 더 효율적으로 바뀌는 모습이 된다.

print(search(0, 6, 2, 4)) # 19

트리 업데이트

중간에 nums 배열의 값이 바뀐다면
그에 맞춰 트리의 값들도 모두 바꿔줘야 한다.

여기서는 함수 인자로 원래 값과의 차이만 넣어서
루트노드부터 아래로 내려오면서 값을 업데이트 해주면 된다.

# 구간합 업데이트
def update(start, end, idx, diff, i=1):
    if start > idx or idx > end:
        return
    tree[i] += diff
    if start != end:
        mid = (start + end) // 2
        update(start, mid, idx, diff, i*2)
        update(mid+1, end, idx, diff, i*2+1)

사실 이해하고 보면 위에 두 함수와 논리상 별 차이는 없다.

update(0, 6, 3, 1 - nums[3])
print(tree)
# [0, 32, 14, 18, 5, 9, 12, 6, 3, 2, 8, 1, 7, 5, 0, 0]

끄적임

구간곱구하기
사실 구간합이야 그렇다 치지만, 뭔가 구간곱부터도 저렇게만 구현하면 뭔가 막히는 부분이 있달까..?

update 부분만 봐도 기존 값이 0이라면 변경값/0을 곱해야하는데
ZeroDivision 에러 나오니..

기존 값이 0일때를 분기 처리해도 괜히 코드만 복잡해진다.

루트노드부터 변화율로 곱해주는건 힘들다.
(그렇다고 그 부분부터 segment를 재생성 하는 건 비효율적인 거 같으니..)

조금 더 직관적인 update를 위해 리프노드 위치에서 값을 변경해주고
위로 올라가면서 갱신해주자고 생각해보자.

이러기 위해선 처음에 segment tree 생성 과정에서
리프노드 인덱스를 미리 저장해둬야 한다.
(nums의 길이만큼 메모리를 더 써야하지만.. 그래도 아직 효율 먼저 생각할 수준은 아니니까..)

location = [0] * n

def segment(left, right, i=1):
    if left == right:
        tree[i] = nums[left]
        location[left] = i
        return tree[i]
    mid = (left+right) // 2
    tree[i] = segment(left, mid, i*2) * segment(mid+1, right, i*2 + 1)
    return tree[i]

location을 확인해보면

print(tree)
# [0, 40320, 192, 210, 6, 32, 35, 6, 3, 2, 8, 4, 7, 5, 0]
print(location)
# [8, 9, 10, 11, 12, 13, 7]

리프노드의 위치를 알 수 있다.

아무튼 여기서도 3번 인덱스의 값을 1로 바꿔보면,

idx = location[3]
tree[location[3]] = 1
while idx > 1:
    idx //= 2
    tree[idx] = tree[idx*2] * tree[idx*2 + 1]
    
print(tree)
# [0, 10080, 48, 210, 6, 8, 35, 6, 3, 2, 8, 1, 7, 5, 0]

이렇게 된다.

기존 update보다 간단하기도 하고,
재귀를 사용한 방식이 아니라 (내 기준) 좀 더 직관적이기도 하고..

아무튼 update는 bottom-up 방식으로 하지만
생성, 탐색은 top-down 방식으로 하는 아직은 애매한..


문제

세그먼트 트리하면 가장 대표적인 문제 두 개..

boj. 2042 구간 합 구하기
boj. 11505 구간 곱 구하기

다른 문제들은 더 풀어봐야 알 거 같다

끝!

profile
코딩배우는 문도리

0개의 댓글