[BOJ 27551] - Vrsta (세그먼트 트리, 오프라인 쿼리, 좌표 압축, C++, Python)

보양쿠·2024년 4월 5일
0

BOJ

목록 보기
237/260
post-custom-banner

BOJ 27551 - Vrsta 링크
(2024.04.05 기준 P4)

문제

학생들이 키에 따라 정렬되어 줄을 서 있을 때, 줄 중앙에 서 있는 학생이 준비 운동을 이끌게 된다. 만약 중앙에 서 있는 학생이 두 명이면, 키가 더 작거나 같은 학생이 준비 운동을 이끌게 된다.

viv_i의 키를 가진 학생이 aia_i명 오는 사건이 총 nn번 발생한다. 이때 사건이 발생할 때마다, 준비 운동을 이끌게 되는 학생의 키를 출력.

알고리즘

카운팅 배열을 이용한 세그먼트 트리 + 압축 테크닉

풀이

중앙값 문제, 사탕상자 문제와 굉장히 유사한 문제이다.

구간합 세그먼트 트리에, viv_i, aia_i 쿼리가 주어질 때마다 인덱스 viv_iaia_i를 추가하는 업데이트를 한다고 생각을 해보자.

일단 ii번째 쿼리 때 학생들의 총합은 a1a_1, \dots, aia_i의 합일 것이다. 이는 ctct에 누적하면서 저장하자.

중앙값은 (ct+1)/2(ct + 1) / 2번째 학생이다. 이 중앙값에 따라 세그먼트 트리에 어떻게 쿼리를 날릴 것인지 생각을 해보자.

현재 [st,en][st, en] 구간을 나타내는 ndnd번째 노드를 살펴보고 있다고 가정해보자. stenst \ne en이면 현재 노드엔 왼쪽 자식과 오른쪽 자식이 달려 있다. 만약 이 구간에서 ii번째 인덱스를 찾는다고 생각해보자.

왼쪽 자식이 나타내는 구간 합 TlT_{l}ii보다 크거나 같다면? 현재 구간의 ii번째 인덱스는 왼쪽 자식에 포함되어 있다.
왼쪽 자식이 나타내는 구간 합 TlT_{l}ii보다 작다면? 현재 구간의 ii번째 인덱스는 오른쪽 자식에 포함되어 있다.

이를 이용해서 현재 구하고자 하는 (ct+1)/2(ct + 1) / 2를 전체 구간 쿼리에 날려서, st=enst = en를 만족할 때까지 왼쪽 자식이 나타내는 구간 합에 따라 왼쪽이나 오른쪽으로 파고 들어가면 된다. 단, 오른쪽 자식으로 들어간다면 당연히 구하고자 하는 인덱스는 iTli - T_{l}가 된다. 왼쪽 자식의 구간 합이 제외되기 때문이다.


그런데 문제가 있다. viv_i는 최대 10910^9이므로 그대로 카운팅 배열을 사용하지 못하게 된다.

nn은 최대 200000200\,000이다. 그럼 viv_i를 압축시켜버리면? 범위 [0,199999][0, 199\,999]의 수로 나타낼 수 있다. 그러면 카운팅 배열을 사용할 수 있게 된다.

코드

  • C++
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;

typedef pair<int, int> pii;
typedef long long ll;

const int MAXN = 200'000, MAXM = 1 << (int)ceil(log2(MAXN) + 1);

ll T[MAXM];

void update(int nd, int st, int en, int idx, ll val){
    if (st == en){
         T[nd] += val;
         return;
    }
    int mid = (st + en) >> 1;
    if (idx <= mid) update(nd << 1, st, mid, idx, val);
    else update(nd << 1 | 1, mid + 1, en, idx, val);
    T[nd] = T[nd << 1] + T[nd << 1 | 1];
}

// 현재 구간에서 정확하게 ct번째인 인덱스를 찾아야 한다.
// 왼쪽 자식의 값이 ct보다 같거나 크면 ct번째는 왼쪽 자식에 있으며
// 그게 아니라면 ct번째는 오른쪽 자식에 속해 있다.
int query(int nd, int st, int en, ll ct){
    if (st == en) return st;
    int mid = (st + en) >> 1;
    if (T[nd << 1] >= ct) return query(nd << 1, st, mid, ct);
    return query(nd << 1 | 1, mid + 1, en, ct - T[nd << 1]);
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n; cin >> n;
    vector<pii> queries(n); for (int i = 0; i < n; i++) cin >> queries[i].x >> queries[i].y;

    // v_i는 최대 1e9이므로 카운팅 배열을 그대로 사용하기 부적합하다.
    // n이 총 2e5이므로 값 압축을 하면 [0, n-1] 구간의 수로 나타낼 수 있게 된다.

    vector<int> A; // v_i를 따로 담고
    for (auto [v, a]: queries) A.push_back(v);
    sort(A.begin(), A.end()); // 중복 제거
    A.erase(unique(A.begin(), A.end()), A.end());

    map<int, int> B; // 각 값마다 압축된 값을 부여
    for (int i = 0, sz = A.size(); i < sz; i++) B[A[i]] = i;

    // 세그먼트 트리 준비
    memset(T, sizeof(T), 0);

    ll ct = 0; // 현재 a_i의 합
    for (auto [v, a]: queries){
        update(1, 0, n - 1, B[v], a); // 쿼리 순서대로 v_i의 압축된 값의 인덱스에 a_i를 더한다.
        ct += a;

        // 중앙값은 (ct+1)/2번째다.
        // 압축된 값이 아닌 원래 값을 출력해야 한다.
        cout << A[query(1, 0, n - 1, (ct + 1) / 2)] << '\n';
    }
}
  • Python
import sys; input = sys.stdin.readline
from math import ceil, log2

def update(nd, st, en, idx, val):
    if st == en:
        T[nd] += val
        return
    mid = (st + en) >> 1
    if idx <= mid:
        update(nd << 1, st, mid, idx, val)
    else:
        update(nd << 1 | 1, mid + 1, en, idx, val)
    T[nd] = T[nd << 1] + T[nd << 1 | 1]


# 현재 구간에서 정확하게 ct번째인 인덱스를 찾아야 한다.
# 왼쪽 자식의 값이 ct보다 같거나 크면 ct번째는 왼쪽 자식에 있으며
# 그게 아니라면 ct번째는 오른쪽 자식에 속해 있다.
def query(nd, st, en, ct):
    if st == en:
        return st
    mid = (st + en) >> 1
    if T[nd << 1] >= ct:
        return query(nd << 1, st, mid, ct)
    return query(nd << 1 | 1, mid + 1, en, ct - T[nd << 1])

n = int(input())
queries = [tuple(map(int, input().split())) for _ in range(n)]

# v_i는 최대 1e9이므로 카운팅 배열을 그대로 사용하기 부적합하다.
# n이 총 2e5이므로 값 압축을 하면 [0, n-1] 구간의 수로 나타낼 수 있게 된다.

A = set() # v_i를 중복없이 담을 set
for v, a in queries:
    A.add(v)
A = sorted(A) # set을 정렬된 리스트로 변환

B = {v: i for i, v in enumerate(A)} # 각 값마다 압축된 값을 부여

# 세그먼트 트리 준비
T = [0] * (1 << ceil(log2(n) + 1))

ct = 0 # 현재 a_i의 합
for v, a in queries:
    update(1, 0, n - 1, B[v], a) # 쿼리 순서대로 v_i의 압축된 값의 인덱스에 a_i를 더한다.
    ct += a

    # 중앙값은 (ct+1)/2번째다.
    # 압축된 값이 아닌 원래 값을 출력해야 한다.
    print(A[query(1, 0, n - 1, (ct + 1) // 2)])
profile
GNU 16 statistics & computer science
post-custom-banner

0개의 댓글