[1517번 버블소트]
https://www.acmicpc.net/problem/1517
비재귀 세그먼트 트리 방식으로 풀었다.
비재귀 세그먼트 구현에 관한 풀이는 아래의 링크 참고
https://velog.io/@statco19/segment-tree-non-recursive
가장 핵심이 되는 로직은 swap이 발생하는 횟수는 자신보다 오른쪽에 위치한 숫자 중 자신보다 작은 숫자의 개수와 같다는 점이다. 예를 들어 A = {4,3,2,1} 일 때, A[0] = 4의 경우, A[1] ~ A[3] 중에서 4보다 작은 수가 3개다. A[1] = 3의 경우 2개, A[2] = 2의 경우 1개, A[3] = 1의 경우 0개가 되어 총 합을 구하면 6이 된다. 이 값이 swap 발생 횟수와 같아진다.
따라서 세그먼트 트리를 구성할 때 노드에 들어가게 되는 값은 어떤 값의 오른쪽에 위치한 값들 중 자신보다 작은 값들의 개수가 된다.
트리를 구성하기 이전에 {value, index}와 같은 pair 형태로 입력값을 저장하고, 값과 인덱스 모두 오름차순으로 정렬한다.
bool cmp(pint &a, pint &b) {
if(a.first < b.first) return true;
else if(a.first == b.first) {
return a.second < b.second;
} else return false;
}
arr.resize(N); // vector<int> arr;
for(int i=0;i<N;++i) {
scanf("%d", &x);
arr[i] = {x,i};
}
sort(arr.begin(), arr.end(), cmp);
트리의 모든 노드 값이 0인 상태에서 query, update 메서드만 사용하여 원하는 값을 구할 수 있다.
#define ll long long
ll ans;
const int MAX = 500000;
int N;
ll t[2*MAX];
vector<pint> arr;
void update(int p, ll val) {
for(t[p+=N]=val;p>1;p>>=1) t[p>>1] = t[p] + t[p^1];
}
ll query(int l, int r) {
ll res = 0;
for(l+=N,r+=N;l<r;l>>=1,r>>=1) {
if(l&1) res += t[l++];
if(r&1) res += t[--r];
}
return res;
}
void sol() {
for(int i=0;i<N;++i) {
ans += query(arr[i].second+1, N);
update(arr[i].second, 1);
}
}
A = {4,3,2,1} 예시를 그림으로 한 단계 씩 살펴보면 다음과 같다.
{value, index} pair 형태로 입력 받아 정렬한다.
정렬 후 값이 작은 것부터 해당 값의 원래 위치(index)보다 오른쪽에 있는 수 중 해당 값보다 작은 숫자의 개수를 세그먼트 트리에 쿼리를 날려 알아낸다(세그먼트 트리의 각 노드는 [l,r) 범위에서 A[l-1] 값보다 작은 수의 개수를 가지고 있다).
i = 0,
정렬 후이기 때문에 값은 1이 되며, 1의 index는 3이다. 따라서 [4,4) 범위에 해당하는 노드의 값을 가져오면 된다. 비재귀 세그먼트 트리 구현에서 [N,N)의 노드는 없으므로 자연스럽게 0을 리턴한다.
ans += 0
쿼리를 날려 값을 알아낸 이후 update를 통해 index 3을 나타내는 노드(leaf node)의 값을 1로 설정한다. 그 이유는 정렬 이후 앞으로 나올 값들은 모두 1보다 크기 때문에 1을 오른쪽에 두는 값(예제에서는 2,3,4)이 쿼리를 날릴 때 카운트 되어야 하기 때문이다.
i = 1,
값은 2가 되며, 2의 index는 2이다. [3,4) 범위에 해당하는 노드 값인 1을 ans에 더해준다.
ans += 1
update를 통해 index 2에 해당하는 노드의 값을 1로 설정하고 index 2를 포함하는 범위의 값을 모두 변경해준다.
i = 2,
값은 3이 되며, 3의 index는 1이다. [2,4) 범위에 해당하는 노드 값인 2를 ans에 더해준다.
ans += 2
update를 통해 index 1에 해당하는 노드의 값을 1로 설정하고 index 1을 포함하는 범위의 값을 모두 변경해준다.
i = 3,
값은 4가 되며, 4의 index는 0이다. [1,4) 범위에 해당하는 노드 값인 1+2 = 3을 ans에 더해준다.
ans += 3
update를 통해 index 0에 해당하는 노드의 값을 1로 설정하고 index 0을 포함하는 범위의 값을 모두 변경해준다.
이로써 연산이 완료되었고, ans
의 값은 1 + 2 + 3 = 6
이 되었다.
#include <bits/stdc++.h>
#define ll long long
using namespace std;
typedef pair<int,int> pint;
typedef vector<int> vint;
const int INF = 0x3f3f3f3f; const int mINF = 0xc0c0c0c0;
const ll LINF = 0x3f3f3f3f3f3f3f3f; const ll mLINF = 0xc0c0c0c0c0c0c0c0;
ll ans;
const int MAX = 500000;
int N;
ll t[2*MAX];
vector<pint> arr;
bool cmp(pint &a, pint &b) {
if(a.first < b.first) return true;
else if(a.first == b.first) {
return a.second < b.second;
} else return false;
}
void update(int p, ll val) {
for(t[p+=N]=val;p>1;p>>=1) t[p>>1] = t[p] + t[p^1];
}
ll query(int l, int r) {
ll res = 0;
for(l+=N,r+=N;l<r;l>>=1,r>>=1) {
if(l&1) res += t[l++];
if(r&1) res += t[--r];
}
return res;
}
void sol() {
for(int i=0;i<N;++i) {
ans += query(arr[i].second+1, N);
update(arr[i].second, 1);
}
}
int main() {
// ios_base :: sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
scanf("%d", &N);
int x;
arr.resize(N);
for(int i=0;i<N;++i) {
scanf("%d", &x);
arr[i] = {x,i};
}
sort(arr.begin(), arr.end(), cmp);
sol();
printf("%lld\n", ans);
return 0;
}