백준 2042 - 구간 합 구하기 (python)

평범한 대학생·2023년 2월 9일
1

baekjoon

목록 보기
3/12
post-thumbnail

구간 합 문제 보러가기


세그먼트 트리(Segment Tree)란?

  • 여러 개의 데이터가 존재할 때 특정 구간(중간)의 합(최솟값, 최댓값, 곱 등)을 구하는 데 사용하는 자료구조
  • 이진 트리 형태
  • 특정구간의 합을 가장 빠르게 구할 수 있음
  • 시간복잡도 : O(logN)O(logN)

세그먼트 트리 관련 문제 구현 패턴 및 기능

  • 세그먼트 트리 공간 할당
    • Python 에서는 리스트로 생성
  • 세그먼트 트리 생성 & 초기화
    • 세그먼트 트리의 인덱스는 무조건 1부터 시작
    • 이 과정에서 구간의 합을 구할것인지 아니면 최솟값, 최댓값, 곱을 구할 것인지 결정해서 트리를 생성
  • 원하는(특정) 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 함수 생성
    • 범위 안에 있는 경우에 한해서만 더해주면 됨
    • 중간에(구간의) 어떤 부분의 합(최솟값, 최댓값, 곱 등) 구하는 함수 생성
  • 특정 원소의 값을 수정하는 함수 생성
    • 중간에 수의 변경이 일어남
    • 해당 원소를 포함하고 있는 모든 구간 합 노드들을 갱신
    • 해당 원소를 포함하고 있는 부분적인 노드들만 바꿔주면 댐
  • 항상 트리는 루트부터 시작
  • 트리와 관련된 거의 모든 구현은 재귀적으로 구현된다.

세그먼트 트리(Segment Tree) 자세한 설명 보러가기 (Python 기준)

문제


어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.


입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.


출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.


핵심 포인트 요약

전형적인 세그먼트 트리를 활용한 문제

  • 중간에 수의 변경 & 중간에 어떤 부분의 합 & 구간의 합을 출력
    👉 여러 개의 데이터가 존재할 때 특정 구간의 또는 중간의 합은 세그먼트 트리를 사용해 구할 수 있다.

코드 & 설명 주석 포함

import sys
sys.setrecursionlimit(10**9)
input = sys.stdin.readline


# 세그먼트 트리를 배열의 각 구간 합으로 채워주기 (세그먼트 트리 생성 & 초기화)
def init(start, end, index):
    # 리프노드에 도달했으면 
    if start == end:
        tree[index] = data[start]
        return tree[index]
    # 두개의 서브트리로 쪼갬
    mid = (start + end)//2
    # 후위 순회(LRV) 방식으로 값을 채워 나간다. 
    tree[index] = init(start, mid, index*2) + init(mid+1, end, index*2+1)
    return tree[index]
    
    
# 원하는(특정) 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 함수
def interval_sum(start, end, index, left, right):
	# 범위를 완전히 벗어난 경우 (내가 원하는 구간이 아닌 경우)
    if left > end or right < start:
        return 0
    # 범위 안에 있는 경우 (내가 원하는 구간안에 속해 있는 경우)
    if left <= start and end <= right:
        return tree[index]
    # 그렇지 않다면 두 서브트리로 쪼개서 합을 구함
    mid = (start + end)//2
    return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid+1, end, index*2+1, left, right)
    

# 특정 원소의 값을 수정하는 함수
def update(start, end, index, up_inx, diff):
    # 구간안에 수정할 구간(인덱스) 없는 경우 
    if up_inx < start or up_inx > end:
        return
    # 구간안에 수정할 구간(인덱스) 있는 경우 수정 해야할 diff(차이) 만큼 갱신해줌
    tree[index] += diff 
    # 리프노드까지 바꿔주었으면 다시 전 단계로 돌아감
    if start == end:
        return
    # 한 단계 아래로 내려가서 탐색함
    mid = (start + end)//2
    update(start, mid, index*2, up_inx, diff) # 왼쪽 서브트리
    update(mid+1, end, index*2+1, up_inx, diff) # 오른쪽 서브트리


# N개의 수, M: 변경이 일어나는 횟수, K: 구간의 합을 구하는 횟수
N, M, K = map(int, input().split())

data = [0] + [int(input()) for _ in range(N)]
tree = [0] * (N * 4)

# 트리 생성
init(1, N, 1)

for _ in range(M+K):
    a, b, c = map(int, input().split())
    # a가 1인경우 b번째 수를 c로 바꿈
    if a == 1:
        val = c - data[b]	# 변경할값 - 원래값 = 차이를 구해줌
        data[b] = c			# 원래 배열에서 값 수정
        update(1, N, 1, b, val)
    # a가 2인경우 b번째 수부터 c번째 수까지의 합을 구함
    elif a == 2:
        print(interval_sum(1, N, 1, b, c))

코드 & 설명 주석 미포함

import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline


def init(start, end, index):
    if start == end:
        tree[index] = data[start]
        return tree[index]
        
    mid = (start + end)//2
    tree[index] = init(start, mid, index*2) + init(mid+1, end, index*2+1)
    return tree[index]


def interval_sum(start, end, index, left, right):
    if left > end or right < start:
        return 0
    if left <= start and end <= right:
        return tree[index]

    mid = (start + end)//2
    return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid+1, end, index*2+1, left, right)


def update(start, end, index, up_inx, diff):
    if up_inx < start or up_inx > end:
        return
    
    tree[index] += diff 
    if start == end:
        return
    
    mid = (start + end)//2
    update(start, mid, index*2, up_inx, diff)
    update(mid+1, end, index*2+1, up_inx, diff)
    
    
N, M, K = map(int, input().split())
data = [0] + [int(input()) for _ in range(N)]
tree = [0] * (N * 4)

init(1, N, 1)
for _ in range(M+K):
    a, b, c = map(int, input().split())
    
    if a == 1:
        val = c - data[b]
        data[b] = c
        update(1, N, 1, b, val)
        
    elif a == 2:
        print(interval_sum(1, N, 1, b, c))

보충 설명

def update(start, end, index, up_inx, diff):
    if up_inx < start or up_inx > end:
        return
    
    tree[index] += diff 
    
    if start == end:
        return

    mid = (start + end)//2
    update(start, mid, index*2, up_inx, diff)
    update(mid+1, end, index*2+1, up_inx, diff)

👉 어떤 방식으로 값이 변경되는 것인가?

  • 위 함수는 원래의 배열에 있는 원소의 값을 수정했을때 트리도 수정된 값으로 구간마다 갱신해주는 함수이다.

  • 처음 세그먼트 트리를 공부하면서 tree[index] += diff 다음과 같은 코드를 봤을 때 살짝 멈칫 했다. 한 번 손으로 따라가보면서 확인해보니 해당 코드는 트리내에서 동작할때 원래 원소와 수정할 원소간의 차이만큼 업데이트 해주는 것이였다. 처음에는 원소자체를 업데이트 한다고 생각을 했었다.


혹시나 설명이 잘못된 부분이 있으면 댓글 부탁드립니다.

profile
주니어 데이터 엔지니어 꿈나무

0개의 댓글