from previous blog..
before persistant segment tree
프로젝트가 끝나고 살살 문제를 다시 풀어보려고 랜덤다이스를 굴렸다가 재미있는 문제를 발견했다.
그냥 풀기에는 대략 O(n^2) 정도의 시간이 걸리는 문제인 것 같고, 높은 난이도의 문제답게 inital code로는 당연히 실패..
알고리즘 분류를 보는데 되게 다양한 알고리즘이 쓰여있어서 가장 아래 persistant segment tree에 대한 공부를 하기 시작했고, 그 개념에 대해 명확하게 감이 잘 잡히지 않아 우선 segment tree를 공부하게 되었다.
segment tree의 개념은 결과적으로 배열과 연산(기본적으로 +, 개념에 대해 문제를 풀고 이해한 결과 곱연산도 충분히 가능할 듯?)이 주어졌을 때 binary tree의 구조를 이용해서 연속적인 배열의 부분을 연산했을 때의 정보를 미리 저장해놓는 방법이라고 할 수 있겠다.
한가지 예시를 들어보자.
10개의 원소를 가진 배열 l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 이 있고, 주어진 연산은 +라고 하자.
tree의 root에는 l의 모든 원소의 합 55가 원소로 들어간다.
root의 child를 살펴보자. left child의 원소에는 l을 반으로 쪼갰을 때 앞에 있는 배열의 합(15)가 들어가고, right child의 원소에는 뒤에 있는 배열의 합(40)이 들어간다.
다시 left child의 left child(편의상 각각 ll, lr, rl, rr child라 부르고 depth가 깊어질수록 rl sequence를 늘리는 식으로 쓴다., 즉, ll child)에는 1,2의 합 3이 들어가고 lr child에는 12, rl child에는 13, rr child에는 27이 들어가게 된다. 그리고 마지막 leaf에는 각 배열의 원래 값들이 들어가게 되고, 배열의 마지막 원소가 tree에 들어가면서 segment tree가 완성된다. 이런식으로 미리 연산을 하게 되었을 때 저장해야 되는 정보의 양은 약 2배가 늘어나게 되지만, 배열의 특정 구간부터 연속적인 구간의 연산들을 구할 때는 미리 연산한 값을 통해 그냥 더할 때( O(n) )와 비교해서 더 효율적으로( O(log n) ) 연산이 되는 시간을 줄일 수 있다.
이를 문제 하나를 예시로 들며 파이썬으로 구현해보자.
파이썬의 경우 class를 이용해서 binary tree를 만들 수 있기는 하지만 요즘 index를 익숙하게 쓰는 연습을 하고 있기 때문에 indexing으로 segment tree를 구현했다.
nums = [1,2,3,4,5,6,7,8,9,10]
segtree = [0] * 100 # nums의 크기에 따라 segtree의 크기도 커져야하지만 여기선 100으로도 충분하다.
N = len(nums)
def inittree(idx,start,end):
if start == end:
segtree[idx] = nums[start]
return segtree[idx]
else:
segtree[idx] = inittree(2*idx,start,(start+end)//2) + inittree(2*idx+1,(start+end)//2+1,end)
return segtree[idx]
inittree(1,0,N-1)
print(segtree[:30])
index를 이용한 binary tree의 원리는 간단하기에 설명은 생략하고, 여기서 segment tree의 형태를 보고싶다면 간단히 배열로 된 segment tree를 2의 배수(차례대로 1, 2, 4, 8 ...)으로 짜른 후 depth에 맞게 binary tree 형태로 숫자를 나열해보면 된다.
이제 문제로 돌아가보자.
해당 문제의 경우 초기 input은 3개의 숫자 N,M,K(차례대로 배열의 크기, 수의 변경 횟수, 합을 구하는 횟수)를 받은 후
다음의 N개의 줄에서는 하나의 숫자를 input으로 받는다.( 이 숫자들은 배열이 된다.)
이후 N+K개의 줄에는 3개의 숫자들을 input으로 받는다.(차례대로 a,b,c라 부르겠다.)
이 때
a가 1인 경우 : 배열의 b번째 숫자를 c로 바꾼다.
a가 2인 경우 : 배열의 b번째 숫자부터 c번째 숫자까지의 합을 구한다.
의 작업을 시행한다.
여기서 잠깐 big O time에 대한 계산을 해보자. 우리는 2가지 case(segment tree를 이용하지 않는 raw case, 이용하는 case)를 비교해볼 것이다.
첫 번째 케이스에서 a가 1인 경우 O(1)의 시간이 걸린다. 그리고 a가 2인 경우(당연히 worst case를 가정하기 때문에 모든 수를 다 더하는 과정이 수행되야 하므로) O(N)의 시간이 걸린다. 따라서 최종적으로 모든 결과물을 뽑는데 걸리는 시간은 O(NK)이 된다.
두 번째 케이스에서 a가 1인 경우 segment tree 위의 index를 포함하는 모든 배열의 수를 바꿔줘야 하므로 O(log n)의 시간이 걸린다는 점에서 첫 번째 케이스보다 오랜 시간이 걸리게 된다. 하지만 a가 2인 경우 합이 맞는 가장 큰 배열들을 더해주기만 하면 되기 때문에 O(log n)의 시간이 걸리게 되므로(물론 적당한 배열의 순서를 찾는다는 점에서 약간의 시간이 더 걸리긴 하지만 기껏해야 constant multiple 일 뿐이다..) 따라서 최종적으로 모든 결과물을 뽑는데 걸리는 시간은 O((M+K)log N)이 되고, N,M,K가 충분히 크다는 가정 하에 우리는 두 번째 케이스가 시간적으로 훨씬 효율적인 방법임을 알 수 있게 된다.(물론 각자의 값에 따라 첫 번째 케이스가 더 효율적일 수 있는 경우가 있다는 것은 기억하고 있으면 좋고 아님 말고..)
이제 a에 따른 두 가지 케이스에 대한 함수를 생각해보자.
def updatetree(idx,start,end,idx2,diff):
if idx2 < start or idx2 > end:
# ! start와 end가 idx와 같이 움직이기 때문에 변할지 말지를 결정하는 절대적인 기준(idx2)이 필요.
return
# print('change idx =',idx)
segtree[idx] += diff
if start!=end:
updatetree(idx*2,idx2,diff,start,(start+end)//2)
updatetree(idx*2+1,idx2,diff,(start+end)//2+1,end)
def value_sum(idx,start,end,stt,ed):
if end < stt or ed < start:
return 0
elif stt <= start and end <= ed:
return segtree[idx]
else:
return value_sum(idx*2,start,(start+end)//2,stt,ed) + value_sum(idx*2+1,(start+end)//2+1,end,stt,ed)
최종적으로 이 함수들을 조합할 경우 완전한 풀이가 완성된다.