[1517번 버블소트] - C++

Andrew·2022년 8월 26일
0

알고리즘연습

목록 보기
28/31

[1517번 버블소트]
https://www.acmicpc.net/problem/1517

비재귀 세그먼트 트리 방식으로 풀었다.
비재귀 세그먼트 구현에 관한 풀이는 아래의 링크 참고
https://velog.io/@statco19/segment-tree-non-recursive

풀이

가장 핵심이 되는 로직은 swap이 발생하는 횟수는 자신보다 오른쪽에 위치한 숫자 중 자신보다 작은 숫자의 개수와 같다는 점이다. 예를 들어 A = {4,3,2,1} 일 때, A[0] = 4의 경우, A[1] ~ A[3] 중에서 4보다 작은 수가 3개다. A[1] = 3의 경우 2개, A[2] = 2의 경우 1개, A[3] = 1의 경우 0개가 되어 총 합을 구하면 6이 된다. 이 값이 swap 발생 횟수와 같아진다.

따라서 세그먼트 트리를 구성할 때 노드에 들어가게 되는 값은 어떤 값의 오른쪽에 위치한 값들 중 자신보다 작은 값들의 개수가 된다.

트리를 구성하기 이전에 {value, index}와 같은 pair 형태로 입력값을 저장하고, 값과 인덱스 모두 오름차순으로 정렬한다.

bool cmp(pint &a, pint &b) {
	if(a.first < b.first) return true;
	else if(a.first == b.first) {
		return a.second < b.second;
	} else return false;
}

arr.resize(N);  // vector<int> arr;
for(int i=0;i<N;++i) {
	scanf("%d", &x);
	arr[i] = {x,i};
}
sort(arr.begin(), arr.end(), cmp);

트리의 모든 노드 값이 0인 상태에서 query, update 메서드만 사용하여 원하는 값을 구할 수 있다.

#define ll long long

ll ans;
const int MAX = 500000;
int N;
ll t[2*MAX];
vector<pint> arr;

void update(int p, ll val) {
	for(t[p+=N]=val;p>1;p>>=1) t[p>>1] = t[p] + t[p^1];
}

ll query(int l, int r) {
	ll res = 0;
	for(l+=N,r+=N;l<r;l>>=1,r>>=1) {
		if(l&1) res += t[l++];
		if(r&1) res += t[--r];
	}
	return res;
}

void sol() {
	for(int i=0;i<N;++i) {
		ans += query(arr[i].second+1, N);
		update(arr[i].second, 1);
	}
}

A = {4,3,2,1} 예시를 그림으로 한 단계 씩 살펴보면 다음과 같다.

1

{value, index} pair 형태로 입력 받아 정렬한다.

2

정렬 후 값이 작은 것부터 해당 값의 원래 위치(index)보다 오른쪽에 있는 수 중 해당 값보다 작은 숫자의 개수를 세그먼트 트리에 쿼리를 날려 알아낸다(세그먼트 트리의 각 노드는 [l,r) 범위에서 A[l-1] 값보다 작은 수의 개수를 가지고 있다).

i = 0,
정렬 후이기 때문에 값은 1이 되며, 1의 index는 3이다. 따라서 [4,4) 범위에 해당하는 노드의 값을 가져오면 된다. 비재귀 세그먼트 트리 구현에서 [N,N)의 노드는 없으므로 자연스럽게 0을 리턴한다.

ans += 0

쿼리를 날려 값을 알아낸 이후 update를 통해 index 3을 나타내는 노드(leaf node)의 값을 1로 설정한다. 그 이유는 정렬 이후 앞으로 나올 값들은 모두 1보다 크기 때문에 1을 오른쪽에 두는 값(예제에서는 2,3,4)이 쿼리를 날릴 때 카운트 되어야 하기 때문이다.

i = 1,
값은 2가 되며, 2의 index는 2이다. [3,4) 범위에 해당하는 노드 값인 1을 ans에 더해준다.

ans += 1

update를 통해 index 2에 해당하는 노드의 값을 1로 설정하고 index 2를 포함하는 범위의 값을 모두 변경해준다.

i = 2,
값은 3이 되며, 3의 index는 1이다. [2,4) 범위에 해당하는 노드 값인 2를 ans에 더해준다.

ans += 2

update를 통해 index 1에 해당하는 노드의 값을 1로 설정하고 index 1을 포함하는 범위의 값을 모두 변경해준다.

i = 3,
값은 4가 되며, 4의 index는 0이다. [1,4) 범위에 해당하는 노드 값인 1+2 = 3을 ans에 더해준다.

ans += 3

update를 통해 index 0에 해당하는 노드의 값을 1로 설정하고 index 0을 포함하는 범위의 값을 모두 변경해준다.

이로써 연산이 완료되었고, ans 의 값은 1 + 2 + 3 = 6 이 되었다.

C++ 코드

#include <bits/stdc++.h>

#define ll long long
using namespace std;
typedef pair<int,int> pint;
typedef vector<int> vint;
const int INF = 0x3f3f3f3f; const int mINF = 0xc0c0c0c0;
const ll LINF = 0x3f3f3f3f3f3f3f3f; const ll mLINF = 0xc0c0c0c0c0c0c0c0;

ll ans;
const int MAX = 500000;
int N;
ll t[2*MAX];
vector<pint> arr;

bool cmp(pint &a, pint &b) {
	if(a.first < b.first) return true;
	else if(a.first == b.first) {
		return a.second < b.second;
	} else return false;
}

void update(int p, ll val) {
	for(t[p+=N]=val;p>1;p>>=1) t[p>>1] = t[p] + t[p^1];
}

ll query(int l, int r) {
	ll res = 0;
	for(l+=N,r+=N;l<r;l>>=1,r>>=1) {
		if(l&1) res += t[l++];
		if(r&1) res += t[--r];
	}
	return res;
}

void sol() {
	for(int i=0;i<N;++i) {
		ans += query(arr[i].second+1, N);
		update(arr[i].second, 1);
	}
}

int main() {
	// ios_base :: sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL);
	scanf("%d", &N);
	int x;
	arr.resize(N);
	for(int i=0;i<N;++i) {
		scanf("%d", &x);
		arr[i] = {x,i};
	}
	sort(arr.begin(), arr.end(), cmp);
	sol();
	printf("%lld\n", ans);
	return 0;
}

실행결과

profile
조금씩 나아지는 중입니다!

0개의 댓글