Educational Codeforces Round 133 (Div. 2) - Swap and Maximum Block 링크
(2022.08.29 기준 Difficulty *2500)
(No cheating Yes study)
2^n 크기의 수열이 있을 때, q개의 k 값인 쿼리가 주어지면
range(0, 2^n - 2^k) 안에 있는 a(i) 원소를 a(i + 2^k) 원소와 교환한다. 이 때, 교환한 적이 있으면 교환하지 않고 넘어간다. 하나의 쿼리가 끝나면 수열의 최대 부분합을 출력한다.
일단 누가 봐도, 세그먼트 트리.
근데 비트마스킹도 쓰인다. 나도 처음엔 이해가 안갔는데,(지금도 이해가 잘 안감 ㅋㅅㅋ)
풀이에서 비트마스킹이 쓰이는 방법을 후술하겠다.
어떻게 해야 할지 머리론 이해가 간다. 최대 부분합을 담은 세그먼트 트리를 만들어서 원소 교환하고 최대 부분합 출력하고.
하지만 세그먼트 트리를 공부한지 얼마 되지 않아 어떻게 최대 부분합을 구해야 하는지를 몰랐고, 쿼리의 개수가 200000개가 될 수 있어 정직하게 원소 교환하고 최대 부분합 출력하는 방식은 시간 초과 날 것 같아 코드를 어떻게 짜야 할지 감이 오지 않았다. 그래서 이 문제의 Tutorial과 AC를 받은 파이썬 코드(고작 4개 밖에 안된다..)를 보며 보기 쉽게 코드를 구현해 보았다.세그먼트 트리에 부분 최대합 정보를 담으려면 총 4개의 수를 담아야 한다고 한다.
Lval = 노드가 담당하는 구간의 왼쪽 값을 포함 하는 최대 부분합
Rval = 노드가 담당하는 구간의 오른쪽 값을 포함 하는 최대 부분합
val = 노드가 담당하는 구간의 최대 부분합
all = 노드가 담당하는 구간의 전체 합그리고 세그먼트 만드는 과정에서 왼쪽 자식 노드 L과 오른쪽 자식 노드 R을 합칠 때에
아래 그림처럼 합치면 된다.
이렇게 하는 이유는
"어떤 노드에 그냥 최대 부분합만 담고 있다고 생각해보자. 아래 두 개의 노드를 합칠때 그냥 둘 중 최댓값을 고르게 되면 왼쪽 절반 안에 속해있는 부분합과 오른쪽 절반안에 속해있는 부분합 밖에 알 수 없다. 두 구간을 걸친 부분들의 합은 고려해 줄 수 없는 것이다. 따라서 아래와 같은 방식을 채택한다." 라고 한다.출처 링크 - [돌이 코딩하는 방:티스토리]
아무튼 이렇게 최대 부분합을 담은 세그먼트 트리를 만들면 이제 쿼리의 개수가 200000개나 되는 것을 어떻게 감당해야 할까 고민해야 한다.
원소 교환이 될 때, i번째 원소와 i+2^k번째 원소가 교환되는데
i의 k번째 비트가 0이면 2^k만큼 증가 하고, 1이면 2^k만큼 감소하게 된다.
그러므로 교환은 i xor 2^k랑 하는 것이다.n이 3인 수열 [0, 1, 2, 3, 4, 5, 6, 7] 이 있다면
k가 0인 쿼리가 들어오면 수열은
[1, 0, 3, 2, 5, 4, 7, 6]이 왼다.
여기서 k가 1인 쿼리가 들어오면 수열은
[3, 2, 1, 0, 7, 6, 5, 4]가 된다.
여기서 k가 2인 쿼리가 들어오면 수열은
[7, 6, 5, 4, 3, 2, 1, 0]이 된다.무엇이 보이지 않는가? 2^k만큼의 뭉텅이 만큼 다음 뭉텅이랑 바뀐다.
그리고 만약 처음 상태 값이 0이었다면.
0 xor (1 << 0) xor (1 << 1) xor (1 << 2) = 7이 된다.
여기서 쿼리 값이 만약 1, 0, 2가 순서대로 들어온다면
[5, 4, 7, 6, 1, 0, 3, 2] ->
[4, 5, 6, 7, 0, 1, 2, 3] ->
[0, 1, 2, 3, 4, 5, 6, 7]
이렇게 되므로 원래 상태로 돌아온다.
그리고 상태 값은
7 xor (1 << 1) xor (1 << 0) xor (1 << 2) = 0이 된다.아직 정확하게 타당성을 증명하지 못하겠지만
상태값이 (0 ~ 2^n - 1)에 따라 원소 교환이 된 상태를 전처리를 해주면 된다.
tutorial에 적혀 있는 전처리 하는 방법인데 뭔 소린지 정확하게 모르겠다 ㅋ..
구간마다 왼쪽 자식 L과 오른쪽 자식 R의 최대 부분합을 구하고
L과 R을 교환하여 최대 부분합을 구하여 부모 노드로 넣어주는 식으로 전처리 느낌으로 세그먼트 트리를 만들면 되는 것 같다.나도 아직 제대로 이해를 못한 거라 설명이 아주 부족할 것이다.
일단 손으로 그려보면서 이해를 해볼 것!
import sys; input = sys.stdin.readline
'''
lb, rb : 구간 최대 부분합
ls, rs : 오른쪽 값을 포함하는 최대 부분합
lp, rp : 왼쪽 값을 포함하는 최대 부분합
lS, rS : 구간 전체 합
'''
def seg(start, end):
if start == end:
val = max(arr[start], 0)
return [(val, val, val, arr[start])]
mid = (start + end) // 2
l = seg(start, mid)
r = seg(mid + 1, end)
result = []
for i in range((end - start + 1) // 2):
lb, ls, lp, lS = l[i]
rb, rs, rp, rS = r[i]
result.append((max(lb, rb, ls + rp), max(rs, rS + ls), max(lp, lS + rp), lS + rS))
l, r = r, l # 교환
for i in range((end - start + 1) // 2):
lb, ls, lp, lS = l[i]
rb, rs, rp, rS = r[i]
result.append((max(lb, rb, ls + rp), max(rs, rS + ls), max(lp, lS + rp), lS + rS))
return result
n = int(input())
l = 1 << n
arr = list(map(int, input().split()))
tree = seg(0, l - 1)
i = 0 # 교환되지 않은 상태인 0
for _ in range(int(input())):
i ^= (1 << int(input())) # xor 연산
print(tree[i][0])
진짜 어렵다. 세그먼트 공부한지 얼마 안돼서 그런건가, 아직도 이해가 잘 가지 않는다.
이 다음 문제 F번은 풀이 포기.... ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ
답도 없다.난 아직 갈 길이 먼 것 같다.
앞으로도 열심히 해서 코드포스 색깔 아주 찐~하게 만들어 보겠다.