import sys
input=sys.stdin.readline
def init(node, start , end):
if start==end:
tree[node]=L[start]
return # 리턴 필수
init(node * 2, start, (start + end) // 2)
init(node * 2 + 1, (start + end) // 2 + 1, end)
tree[node]=tree[node*2]+tree[node*2+1]
def update(node, start , end , index , value):
if index>end or index<start:
return
if index<=start and index>=end: # 포함관계라면
tree[node]=value
return
update(node*2 , start , (start+end)//2 , index ,value)
update(node*2+1 , (start+end)//2+1 , end ,index , value)
tree[node]=tree[node*2]+tree[node*2+1]
def query(node , start , end , left , right):
if left>end or right<start:
return 0 # 범위밖이면 0을 반환한다.
if left<=start and right>=end: #포함관계라면
return tree[node]
return query(node*2 , start , (start+end)//2 , left , right) + query(node*2+1 , (start+end)//2+1 , end , left ,right)
N,Q=map(int,input().split())
tree=[0]*(1<<22)
L=list(map(int,input().split()))
init(1 , 0, N-1)
for i in range(Q):
x,y,a,b=map(int,input().split())
a-=1 ; x-=1 ; y-=1
if x>y:
x,y=y,x
print( query(1, 0, N-1, x , y))
update(1 , 0, N-1,a,b)
📌 어떻게 접근할 것인가?
세그먼트 트리를 사용했습니다.
처음 init 함수를 통해서 리스트 입력받은 값에 따라 초기 값을 할당해줍니다. 누적합 문제이기 때문에 tree[node]=tree[node*2]+tree[node*2+1] 를 해줍니다.
update 함수는 만약 범위를 포함하고 있다면 tree[node]=value 로 값을 변경해줍니다.
이후 재귀적으로 왼쪽자식과 오른쪽 자식을 탐색하고 마지막에 누적합을 위해서 tree[node]=tree[node*2]+tree[node*2+1] 을 해줍니다.
query 함수에서는 특정 범위의 누적합을 구하는 것이기 때문에 update 함수와 비슷하게 구성하면 되는데 다른점은 포함관계에 있다면 tree[node] 를 반환하고
합을 구하기때문에
return query(node*2 , start , (start+end)//2 , left , right) + query(node*2+1 , (start+end)//2+1 , end , left ,right) 와 같이 반환해줍니다.
이후 쿼리에서는 0번째 인덱스 부터 시작했으므로 a,x,y 값은 을 해줍니다.
쿼리는 a 번째 값을 b 로 바꾸는 것이기 때문에 한번만 실행해주면 됩니다.
또한 x>y 일 수 있으므로 x 가 y 보다 값이 큰 경우 값을 변경해줍니다.