[Python 알고리즘] Segment Tree (세그먼트 트리) (feat. boj 2357 최솟값과 최댓값)

이예서·2023년 6월 11일
0
post-thumbnail

boj2357

시간 제한 2초를 통해 알 수 있는 것

  1. 일단 주어진 정수의 개수 n이 100,000이므로, O(nlog(n))의 시간복잡도 필요
  2. 최댓값 세그먼트 트리, 최솟값 세그먼트 트리 2개 만들어 각각 접근
    트리 접근 log(100,000) Test Case 100,000 2 (최댓값, 최솟값)
    => 2nlog(n) 이므로 2 초

입력 처리

import sys
input = sys.stdin.readline

# input
N, M = map(int, input().split())
arr = [int(input()) for _ in range(N)]
pair = [map(int, input().split()) for _ in range(M)]

구간합 세그먼트 트리 구현

세그먼트 트리의 가장 일반적인 형태는 구간합을 구하는 형태일 것이다.
다음은 구간합을 구하는 세그먼트 트리의 코드 예시이다.

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # 트리의 크기는 원래 배열 크기의 4배

        self.build(arr, 0, self.n - 1, 1)  # 세그먼트 트리 구축

    def build(self, arr, left, right, node):
        if left == right:  # 리프 노드에 도달한 경우
            self.tree[node] = arr[left]
            return

        mid = (left + right) // 2
        self.build(arr, left, mid, node * 2)  # 왼쪽 자식 노드 구축
        self.build(arr, mid + 1, right, node * 2 + 1)  # 오른쪽 자식 노드 구축
        # 왼쪽 자식 노드와 오른쪽 자식 노드의 값을 요약하여 현재 노드에 저장
        self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def query_sum(self, left, right, node, node_left, node_right):
        if right < node_left or left > node_right:  # 구간이 완전히 벗어난 경우
            return 0
        if left <= node_left and node_right <= right:  # 구간이 완전히 포함되는 경우
            return self.tree[node]

        mid = (node_left + node_right) // 2
        # 왼쪽 자식 노드와 오른쪽 자식 노드로 분할하여 구간 합 계산
        return self.query_sum(left, right, node * 2, node_left, mid) + \
               self.query_sum(left, right, node * 2 + 1, mid + 1, node_right)

    def get_sum(self, left, right):
        return self.query_sum(left, right, 1, 0, self.n - 1)

또는 아래와 같이 나타낼 수도 있다.

class SegmentTree:
    def __init__(self, arr):
        self.arr = arr
        self.tree = [0] * (4 * len(arr))  # 세그먼트 트리를 저장할 배열

    def build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
        else:
            mid = (start + end) // 2
            self.build(2 * node, start, mid)  # 왼쪽 자식 노드를 빌드
            self.build(2 * node + 1, mid + 1, end)  # 오른쪽 자식 노드를 빌드
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]  # 요약 정보 업데이트

    def query(self, node, start, end, left, right):
        if left > end or right < start:  # 구간이 완전히 벗어난 경우
            return 0
        if left <= start and right >= end:  # 구간이 완전히 포함되는 경우
            return self.tree[node]
        
        mid = (start + end) // 2
        left_sum = self.query(2 * node, start, mid, left, right)  # 왼쪽 자식 노드로 재귀 호출
        right_sum = self.query(2 * node + 1, mid + 1, end, left, right)  # 오른쪽 자식 노드로 재귀 호출

        return left_sum + right_sum

사용 예시

arr = [1, 3, 5, 7, 9, 11]
tree = SegmentTree(arr)
tree.build(1, 0, len(arr) - 1)  # 세그먼트 트리 빌드

print(tree.query(1, 0, len(arr) - 1, 1, 4))  # 구간 [1, 4]의 합 출력

+) 위 문제와 상관없으나 update 함수는 다음과 같다.

def update(self, node, start, end, index, diff):
    if index < start or index > end:  # 인덱스가 구간에 속하지 않는 경우
	    return
    self.tree[node] += diff
    if start != end:  # 리프 노드가 아닌 경우
        mid = (start + end) // 2
        self.update(2 * node, start, mid, index, diff)  # 왼쪽 자식 노드로 재귀 호출
        self.update(2 * node + 1, mid + 1, end, index, diff)  # 오른쪽 자식 노드로 재귀 호출
            
tree.update(1, 0, len(arr) - 1, 2, 2)  # 인덱스 2의 값을 2만큼 증가

구간의 최댓값과 최솟값 세그먼트 트리 구현

위 문제에서는 각 구간에 대해 최솟값과 최댓값을 구해야 하므로 총 두개의 트리가 필요하다. build함수와 query함수 에서 자식 트리의 값을 합하는 부분을 적절히 변경하면 된다. 다음은 구현 코드이다.

class SegmentTree():
    def __init__(self, arr):
        self.n = len(arr)
        self.mintree = [0]*(4*self.n)
        self.maxtree = [0]*(4*self.n)

        self.build_min(arr, 0, self.n-1, 1)
        self.build_max(arr, 0, self.n-1, 1)

    def build_min(self, arr, left, right, node):
        if left == right:
            self.mintree[node] = arr[left]
            return

        mid = (left + right) // 2
        self.build_min(arr, left, mid, node*2)
        self.build_min(arr, mid+1, right, node*2+1)
        self.mintree[node] = min(self.mintree[node * 2], self.mintree[node*2+1])

    def build_max(self, arr, left, right, node):
        if left == right:
            self.maxtree[node] = arr[left]
            return

        mid = (left + right) // 2
        self.build_max(arr, left, mid, node*2)
        self.build_max(arr, mid+1, right, node*2+1)
        self.maxtree[node] = max(self.maxtree[node*2], self.maxtree[node*2+1])

    def query_min(self, left, right, node, node_left, node_right):
        if node_right < left or right < node_left: # 구간을 벗어남
            return float('inf')
        if left <= node_left and node_right <= right: # 구간 내 포함
            return self.mintree[node]

        node_mid = (node_left + node_right) // 2
        return min(self.query_min(left, right, node*2, node_left, node_mid),
                   self.query_min(left, right, node*2+1, node_mid+1, node_right))

    def query_max(self, left, right, node, node_left, node_right):
        if node_right < left or right < node_left: # 구간을 벗어남
            return 0
        if left <= node_left and node_right <= right: # 구간 내 포함
            return self.maxtree[node]

        node_mid = (node_left + node_right) // 2
        return max(self.query_max(left, right, node*2, node_left, node_mid),
                   self.query_max(left, right, node*2+1, node_mid+1, node_right))

    def get_min(self, left, right):
        return self.query_min(left, right, 1, 0, self.n-1)

    def get_max(self, left, right):
        return self.query_max(left, right, 1, 0, self.n-1)

출력 처리

문제에서 입력받은 a,b 값은 실제 인덱스가 아닌 입력 순서이므로, -1 씩 연산한다.

segment_tree = SegmentTree(arr)

for a, b in pair:
    print(segment_tree.get_min(a-1, b-1), segment_tree.get_max(a-1, b-1))

profile
https://ohge.tistory.com/

0개의 댓글