Problem link: https://www.acmicpc.net/problem/7578
segment tree를 사용하는 inversion counting 문제로 reduction될 수 있다.
주어진 0행이 m[i-2] m[i-1] m[i] m[i+1] m[i+2] ...
와 같았다고 하자.
주어진 1행에서 각 m[i]
에 대응되는 지점들을 d[i]
로 아래와 같이 표현해보자.
m[i-2] m[i-1] m[i] m[i+1] m[i+2]
d[i] d[i+2] d[i-2] d[i+1] d[i-1]
이 때, m[i]
와 d[i]
를 선으로 이어보고 나서 m[i] <-> d[i]
와 겹치는 전선이 누구일까를 생각해보면 m[i]
보다 앞에서 출발해서, d[i]
보다 뒤에 도착하는 전선들임을 알 수 있다.
전형적인 inversion counting의 예임을 알 수 있고, 그대로 풀어주면 된다.
단, 이때, m[i-1] <-> d[i-1]
을 기준으로 생각해보면 m[i-1]
보다 뒤에서 출발해서, d[i-1]
보다 앞에 도착하는 전선들을 세야하는 게 아닌가 하는 걱정이 들 수 있다.
하지만, 우상향 전선 기준으로 세건 우하향 전선으로 세건 답에는 차이가 없으므로, 여기서는 일반성을 잃지 않고 우상향 전선을 기준으로만 세었다.
#include <iostream>
#include <vector>
#include <set>
#include <map>
using namespace std;
class SegmentTree
{
private:
size_t number_of_elements_;
vector<size_t> node_;
public:
SegmentTree(const size_t number_of_elements) : number_of_elements_(number_of_elements)
{
size_t size = 1;
while (size < number_of_elements_)
{
size *= 2;
}
node_.assign(2 * size, 0);
}
private:
size_t Query(const size_t root, const size_t root_left, const size_t root_right, const size_t query_left,
const size_t query_right)
{
if (root_right < query_left || query_right < root_left)
{
return 0;
}
if (query_left <= root_left && root_right <= query_right)
{
return node_[root];
}
size_t root_mid = (root_left + root_right) / 2;
return Query(2 * root, root_left, root_mid, query_left, query_right) +
Query(2 * root + 1, root_mid + 1, root_right, query_left, query_right);
}
size_t Update(const size_t root, const size_t root_left, const size_t root_right, const size_t update_index,
const size_t update_value)
{
if (root_right < update_index || update_index < root_left)
{
return node_[root];
}
if (root_left == root_right)
{
node_[root] = update_value;
return node_[root];
}
size_t root_mid = (root_left + root_right) / 2;
node_[root] = Update(2 * root, root_left, root_mid, update_index, update_value) +
Update(2 * root + 1, root_mid + 1, root_right, update_index, update_value);
return node_[root];
}
public:
size_t Query(const size_t left, const size_t right)
{
return Query(1, 1, number_of_elements_, left, right);
}
void Update(const size_t index, const size_t value)
{
Update(1, 1, number_of_elements_, index, value);
}
const size_t number_of_elements(void) const
{
return number_of_elements_;
}
};
SegmentTree* segment_tree = nullptr;
vector<size_t> machines;
map<size_t, size_t> machine_positions;
int main(void)
{
// For faster IO
ios_base::sync_with_stdio(false);
cout.tie(nullptr);
cin.tie(nullptr);
// Read input
size_t N;
cin >> N;
machines.assign(N, 0);
for (auto& machine : machines)
{
cin >> machine;
}
for (size_t pos = 1; pos <= N; ++pos)
{
size_t machine;
cin >> machine;
machine_positions[machine] = pos;
}
// Solve
segment_tree = new SegmentTree(N);
size_t ret = 0;
for (const auto& machine : machines)
{
ret += segment_tree->Query(machine_positions[machine], segment_tree->number_of_elements());
segment_tree->Update(machine_positions[machine], 1);
}
cout << ret << "\n";
return 0;
}