[백준/python/2042] 구간 합

bej_ve·2022년 4월 25일
0

python알고리즘

목록 보기
21/46

문제링크 : 구간 합

골드에서 쉬운 문제를 발견하면 항상 런타임 에러가 뜬다. 골드문제인 이유가 있었다. 처음 문제를 보고 단순 구현으로 푼 코드는 아래와 같다. 풀면서도 런타임 에러가 뜰 것 같았다.

import sys

input=sys.stdin.readline
n,m,k=map(int, input().split())
arr=[]

for _ in range(n):
    arr.append(int(input()))
for _ in range(m+k):
    a,b,c=map(int, input().split())
    if a==1:
        arr[arr.index(b)]=c
    elif a==2:
        sum=0
        for i in range(b-1,c):
            sum+=arr[i]
        print(sum)

그래서 찾아보니 세그먼트 트리를 사용해서 푸는 문제였다. 세그먼트 트리를 사용하면 하나의 조건을 수행하는데 O(logN)의 시간이 걸린다.

# 0. 입력받기
import sys
from math import ceil, log

input=sys.stdin.readline
N, M, K=map(int, input().split())
l=[]
segment_tree=[0]*(N*4)

# 1. 트리 만들기
def init(start, end, index):
    # start와 end가 같다면 리프노드이다.
    if start==end:
        segment_tree[index]=l[start-1]
        return segment_tree[index]
    # 현재 노드는 왼쪽 아래 노드와 오른쪽 아래 노드를 더한 값이다.
    mid=(start+end)//2
    segment_tree[index]=init(start, mid, index*2)+init(mid+1, end, index*2+1)
    return segment_tree[index]


# 2. 트리에서 값 찾기
def find(start, end, index, left, right):
    # 찾으려는 범위가 start~end 범위보다 클 경우
    if start>right or end<left:
        return 0
    # 찾으려는 범위가 segment tree 노드안에 구현되어 있을 경우
    if start>=left and end<=right:
        return segment_tree[index]
    # 코드를 동작시키기 위한 기본 코드
    # 현재 노드는 왼쪽아래 + 오른쪽아래 노드이다.
    mid=(start+end)//2
    sub_sum=find(start, mid, index*2, left, right) + find(mid+1, end, index*2+1, left, right)
    return sub_sum


# 3. 트리 값 바꿔주기
def update(start, end, index, update_idx, update_data):
    # update 하려는 범위가 초과될 경우
    if start > update_idx or end < update_idx:
        return
    segment_tree[index]+=update_data
    # 리프노드까지 바꿔주었으면 재귀함수를 끝낸다.
    if start==end:
        return
    # 내가 관여하고 있는 노드들을 찾아서 바꿔준다 -> 재귀함수로 구현
    mid=(start+end)//2
    update(start, mid, index*2, update_idx, update_data)
    update(mid+1, end, index*2+1, update_idx, update_data)

for _ in range(N):
    l.append(int(input()))
    
init(1, N, 1)

for _ in range(M+K):
    a, b, c=map(int, input().split())
    if a==1:
        temp=c-l[b-1]
        l[b-1]=c
        update(1, N, 1, b, temp)
    elif a==2:
        print(find(1, N, 1, b, c))

다른 블로그를 참조한 코드이다. 세그먼트 트리에 대해 잘 설명해주는 블로그들이 많아서 그림을 그려가며 공부했다. 각 노드의 start와 end, index의 개념을 잘 이해하는 것이 중요하다.

0개의 댓글