Segment tree code

Alpha, Orderly·2025년 8월 5일
0
class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)  # 입력 배열의 크기
        self.arr = arr  # 원본 배열
        self.tree = [0] * (4 * self.n)  # 세그먼트 트리 배열 (충분히 크게 할당)
        self.build(1, 0, self.n - 1)  # 세그먼트 트리 빌드 시작 (루트 노드 번호는 1)

    def build(self, node: int, start: int, end: int):
        # 리프 노드인 경우 (start == end)
        if start == end:
            self.tree[node] = self.arr[start]
        else:
            # 자식 노드로 분할하여 재귀적으로 빌드
            mid = (start + end) // 2
            self.build(2 * node, start, mid)  # 왼쪽 자식
            self.build(2 * node + 1, mid + 1, end)  # 오른쪽 자식
            # 현재 노드는 두 자식 노드의 합
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _query(self, node: int, start: int, end: int, l: int, r: int) -> int:
        # [start, end]와 [l, r]이 겹치지 않는 경우
        if r < start or end < l:
            return 0
        # [start, end]가 [l, r]에 완전히 포함되는 경우
        if l <= start and end <= r:
            return self.tree[node]
        
        # 부분적으로 겹치는 경우 왼쪽/오른쪽 자식으로 내려감
        mid = (start + end) // 2
        left_sum = self._query(2 * node, start, mid, l, r)
        right_sum = self._query(2 * node + 1, mid + 1, end, l, r)
        # 왼쪽과 오른쪽 자식의 결과를 합산
        return left_sum + right_sum
    
    def query(self, l: int, r: int) -> int:
        # 사용자에게 제공되는 쿼리 인터페이스
        # 배열 인덱스 l ~ r 구간의 합을 반환
        return self._query(1, 0, self.n - 1, l, r)
    
    def update(self, idx: int, value: int, node: int, start: int, end: int):
        # 리프 노드에 도달한 경우 (idx 위치를 찾은 경우)
        if start == end:
            self.arr[idx] = value  # 원본 배열도 갱신
            self.tree[node] = value  # 세그먼트 트리 노드 갱신
        else:
            # 자식 노드로 내려감
            mid = (start + end) // 2
            if idx <= mid:
                self.update(idx, value, 2 * node, start, mid)  # 왼쪽 자식
            else:
                self.update(idx, value, 2 * node + 1, mid + 1, end)  # 오른쪽 자식
            # 자식 노드가 변경되었으니 현재 노드도 다시 계산
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _update(self, idx: int, value: int):
        # 사용자에게 제공되는 업데이트 인터페이스
        # idx 위치의 값을 value로 변경
        self.update(idx, value, 1, 0, self.n - 1)
profile
만능 컴덕후 겸 번지 팬

0개의 댓글