[PS] PS 메모장

주재완·2026년 5월 15일

기본 입출력

#include <bits/stdc++.h>
#define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;

int main() {
    FASTIO;
    int x;
    cin >> x;
    cout << x << '\n';
    return 0;
}

문자열 파싱 기본

1. 입출력과 파싱의 관계

C++에서 cingetline을 함께 사용할 경우, cin이 남긴 개행 문자(\n)로 인해 getline이 빈 문자열을 읽는 문제가 발생할 수 있습니다.

int n;
cin >> n;
getline(cin, line); // line == "" (입력 오류 발생)

이 문제를 해결하려면 cin.ignore()를 사용하여 버퍼에 남아 있는 개행 문자를 제거해야 합니다.

cin >> n;
cin.ignore(); // 개행 문자 제거
getline(cin, line); // 정상 입력 처리

ios::sync_with_stdio(false);cin.tie(NULL);을 설정하면 C++ 표준 입출력 속도가 향상됩니다. 다만, 이 설정을 사용할 경우 scanf, printf와 같은 C 스타일 입출력과 혼용하지 않는 것이 좋습니다.

ios::sync_with_stdio(false);
cin.tie(NULL);

2. 문자열 파싱 방법

2.1 std::stringstream을 이용한 파싱

std::stringstream ss(line);
std::string part1, part2;
std::getline(ss, part1, '#');
std::getline(ss, part2, '#');

직관적이고 안전하지만, 대량의 데이터를 처리할 때는 성능이 떨어질 수 있습니다.

2.2 findsubstr을 이용한 파싱

size_t pos = line.find('#');
std::string part1 = line.substr(0, pos);
std::string part2 = line.substr(pos + 1);

빠르고 객체 생성이 적으며, 대량의 로그 데이터를 처리할 때 유리합니다.

2.3 scanf 스타일 포맷 지정 파싱

char date[11], time[9];
int level;
sscanf("2021-04-05 17:17:11#1", "%10s %8s#%d", date, time, &level);

속도가 매우 빠르지만, 포맷이 정확하지 않으면 쉽게 오류가 발생할 수 있습니다.

3. 시간 비교 방법

3.1 문자열 직접 비교

std::string t1 = "2021-04-05 17:17:11";
std::string t2 = "2021-04-05 17:18:11";

if (t1 < t2) {
    // t1이 더 이른 시각
}

YYYY-MM-DD hh:mm:ss 형식은 문자열 비교만으로도 정확한 시간 순서를 판단할 수 있습니다.

3.2 std::tmmktime을 이용한 비교

#include <ctime>
#include <sstream>
#include <iomanip>

std::tm tm = {};
std::istringstream ss(datetime);
ss >> std::get_time(&tm, "%Y-%m-%d %H:%M:%S");
time_t t = std::mktime(&tm);

시간 차이를 초 단위로 계산하거나 시간 연산이 필요한 경우에 적합합니다. std::get_time은 일부 컴파일러에서 동작하지 않을 수 있으므로 환경에 따라 검토가 필요합니다.

예시

[BOJ] 21774 / 가희와 로그 파일

  • 코드
    #include <iostream>
    #include <string>
    #include <vector>
    #include <algorithm>
    
    using namespace std;
    
    int main() {
        ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
        int n, q; cin >> n >> q; cin.ignore();
    
        int lv, idx;
        string l, t; 
        vector<string> logs[7];
        for(int i = 0; i < n; ++i) {
            getline(cin, l);
    
            idx = l.find('#');
            t = l.substr(0, idx);
            lv = stoi(l.substr(idx + 1));
    
            logs[lv].emplace_back(t);
        }
    
        for(int i = 1; i < 7; ++i) {
            if (logs[i].empty()) continue;
            sort(logs[i].begin(), logs[i].end());
        }
    
        string query, s, e;
        while (q--) {
            getline(cin, query);
            int p1 = query.find('#');
            int p2 = query.find('#', p1 + 1);
    
            s = query.substr(0, p1);
            e = query.substr(p1 + 1, p2 - p1 - 1);
            lv = stoi(query.substr(p2 + 1));
            
            int res = 0;
            for(int i = lv; i < 7; ++i) {
                if (logs[i].empty()) continue;
                res += upper_bound(logs[i].begin(), logs[i].end(), e) - lower_bound(logs[i].begin(), logs[i].end(), s);
            }
            cout << res << '\n';
        }
    
        return 0;
    }

Ford-Fulkerson 알고리즘

  • 유량 네트워크에 있는 모든 유량을 0으로 초기화 시킨 뒤, source에서 sink로 유량을 더 보낼 수 있는 경로를 찾아 흘리는 동작을 반복
  • 다음과 같은 그래프가 있다고 가정
  • 유량을 흘러 보내면 아래와 같음
  • 유량을 보내는 경로 : 증가 경로(augmenting path)
    • 경로에 해당하는 모든 간선에 잔여 용량(residual capacity, r[u][v]) 존재
    • r[u][v] = c[u][v] - f[u][v]
  • 하지만 경로에 따라 기존의 완전 탐색의 방법으로는 다음과 같이 더 이상 흐르는 유량을 추가할 수 없는 경우도 발생
  • 해당 경우에는 f[B][A] = -f[A][B] = -1 인 대칭성 활용

Edmonds-Karp Algorithm

  • BFS를 활용한 Ford-Fulkerson 알고리즘 구현
  • O(VE^2)

Pseudocode

function EdmondsKarp(source, sink):
    maxFlow ← 0
    // 모든 간선의 초기 유량(flow)을 0으로 설정
    for each edge (u, v) in graph:
        flow[u][v] ← 0

    while true:
        // BFS를 이용하여 증가 경로(augmenting path)를 찾는다.
        for each vertex v in graph:
            parent[v] ← NIL
        create an empty queue Q
        enqueue(source, Q)

        while Q is not empty and parent[sink] is NIL:
            u ← dequeue(Q)
            for each neighbor v of u:
                if parent[v] is NIL and (capacity[u][v] - flow[u][v] > 0):
                    parent[v] ← u
                    enqueue(v, Q)

        // sink에 도달할 수 없는 경우, 더 이상 증가 경로가 없으므로 종료
        if parent[sink] is NIL:
            break

        // 찾은 경로에서 추가로 보낼 수 있는 최대 흐름(병목 용량)을 계산
        path_flow ← ∞
        v ← sink
        while v ≠ source:
            u ← parent[v]
            path_flow ← min(path_flow, capacity[u][v] - flow[u][v])
            v ← u

        // 경로를 따라 흐름을 갱신
        v ← sink
        while v ≠ source:
            u ← parent[v]
            flow[u][v] ← flow[u][v] + path_flow
            flow[v][u] ← flow[v][u] - path_flow
            v ← u

        maxFlow ← maxFlow + path_flow

    return maxFlow

예시

[BOJ] 11375 / 열혈강호

  • 문제 강호네 회사에는 직원이 N명이 있고, 해야할 일이 M개가 있다. 직원은 1번부터 N번까지 번호가 매겨져 있고, 일은 1번부터 M번까지 번호가 매겨져 있다. 각 직원은 자신이 할 수 있는 일들 중 한 개의 일만 담당할 수 있고, 각각의 일을 담당하는 사람은 1명이어야 한다. 각각의 직원이 할 수 있는 일의 목록이 주어졌을 때, M개의 일 중에서 최대 몇 개를 할 수 있는지 구하는 프로그램을 작성하시오.
    첫째 줄에 직원의 수 N과 일의 개수 M이 주어진다. (1 ≤ N, M ≤ 1,000) 둘째 줄부터 N개의 줄의 i번째 줄에는 i번 직원이 할 수 있는 일의 개수와 할 수 있는 일의 번호가 주어진다.
    첫째 줄에 강호네 회사에서 할 수 있는 일의 개수를 출력한다.
    5 5
    2 1 2
    1 1
    2 2 3
    3 3 4 5
    1 1
    4
  • 예시 답안
    #include <iostream>
    #include <vector>
    #include <queue>
    #include <cstring>
    #define SZ 2002
    
    using namespace std;
    
    int n, m;
    int s = 0, e = SZ - 1;
    vector<int> g[SZ];
    int c[SZ][SZ];
    int f[SZ][SZ];
    int p[SZ];
    
    int ek() {
    	int res = 0;
    	while (1) {
    		memset(p, -1, sizeof(p));
    		queue<int> q;
    		q.emplace(s);
    		while (!q.empty()) {
    			int cur = q.front(); q.pop();
    			if (cur == e) break;
    			for (int nxt : g[cur]) {
    				if (p[nxt] == -1 && c[cur][nxt] - f[cur][nxt] > 0) {
    					q.emplace(nxt);
    					p[nxt] = cur;
    				}
    			}
    		}
    		if (p[e] == -1) break;
    
    		int cost = INT32_MAX;
    		for (int cur = e; cur != s; cur = p[cur]) {
    			cost = min(cost, c[p[cur]][cur] - f[p[cur]][cur]);
    		}
    
    		for (int cur = e; cur != s; cur = p[cur]) {
    			f[p[cur]][cur] += cost;
    			f[cur][p[cur]] -= cost;
    		}
    
    		res += cost;
    	}
    	return res;
    }
    
    int main() {
    	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    	cin >> n >> m;
    	int cnt, work;
    	for (int i = 1; i <= n; ++i) {
    		cin >> cnt;
    		g[s].emplace_back(i);
    		g[i].emplace_back(s);
    		c[s][i] = 1;
    		for (int j = 0; j < cnt; ++j) {
    			cin >> work;
    			work += 1000;
    			g[i].emplace_back(work);
    			g[work].emplace_back(i);
    			c[i][work] = 1;
    		}
    	}
    
    	for (int i = 1001; i <= m + 1000; ++i) {
    		g[i].emplace_back(e);
    		g[e].emplace_back(i);
    		c[i][e] = 1;
    	}
    
    	cout << ek() << '\n';
    	return 0;
    }

[BOJ] 11376 / 열혈강호 2

  • 문제 강호네 회사에는 직원이 N명이 있고, 해야할 일이 M개가 있다. 직원은 1번부터 N번까지 번호가 매겨져 있고, 일은 1번부터 M번까지 번호가 매겨져 있다. 각 직원은 자신이 할 수 있는 일들 중 최대 두 개의 일을 담당할 수 있고, 각각의 일을 담당하는 사람은 1명이어야 한다. 각각의 직원이 할 수 있는 일의 목록이 주어졌을 때, M개의 일 중에서 최대 몇 개를 할 수 있는지 구하는 프로그램을 작성하시오.
    첫째 줄에 직원의 수 N과 일의 개수 M이 주어진다. (1 ≤ N, M ≤ 1,000) 둘째 줄부터 N개의 줄의 i번째 줄에는 i번 직원이 할 수 있는 일의 개수와 할 수 있는 일의 번호가 주어진다.
    첫째 줄에 강호네 회사에서 할 수 있는 일의 개수를 출력한다.
    5 5
    2 1 2
    2 1 2
    2 1 2
    2 4 5
    0
    4
  • 예시 답안
    #include <iostream>
    #include <vector>
    #include <queue>
    #include <cstring>
    #define SZ 2002
    
    using namespace std;
    
    int n, m;
    int s = 0, e = SZ - 1;
    vector<int> g[SZ];
    int c[SZ][SZ];
    int f[SZ][SZ];
    int p[SZ];
    
    int ek() {
    	int res = 0;
    	while (1) {
    		memset(p, -1, sizeof(p));
    		queue<int> q;
    		q.emplace(s);
    		while (!q.empty()) {
    			int cur = q.front(); q.pop();
    			if (cur == e) break;
    			for (int nxt : g[cur]) {
    				if (p[nxt] == -1 && c[cur][nxt] - f[cur][nxt] > 0) {
    					q.emplace(nxt);
    					p[nxt] = cur;
    				}
    			}
    		}
    		if (p[e] == -1) break;
    
    		int cost = INT32_MAX;
    		for (int cur = e; cur != s; cur = p[cur]) {
    			cost = min(cost, c[p[cur]][cur] - f[p[cur]][cur]);
    		}
    
    		for (int cur = e; cur != s; cur = p[cur]) {
    			f[p[cur]][cur] += cost;
    			f[cur][p[cur]] -= cost;
    		}
    
    		res += cost;
    	}
    	return res;
    }
    
    int main() {
    	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    	cin >> n >> m;
    	int cnt, work;
    	for (int i = 1; i <= n; ++i) {
    		cin >> cnt;
    		g[s].emplace_back(i);
    		g[i].emplace_back(s);
    		c[s][i] = 2;
    		for (int j = 0; j < cnt; ++j) {
    			cin >> work;
    			work += 1000;
    			g[i].emplace_back(work);
    			g[work].emplace_back(i);
    			c[i][work] = 1;
    		}
    	}
    
    	for (int i = 1001; i <= m + 1000; ++i) {
    		g[i].emplace_back(e);
    		g[e].emplace_back(i);
    		c[i][e] = 1;
    	}
    
    	cout << ek() << '\n';
    	return 0;
    }

[BOJ] 11377 / 열혈강호 3

  • 문제 강호네 회사에는 직원이 N명이 있고, 해야할 일이 M개가 있다. 직원은 1번부터 N번까지 번호가 매겨져 있고, 일은 1번부터 M번까지 번호가 매겨져 있다. 각 직원은 자신이 할 수 있는 일들 중 한 개의 일만 담당할 수 있고, 각각의 일을 담당하는 사람은 1명이어야 한다. 단, N명 중에서 K명은 일을 최대 2개 담당할 수 있다. 각각의 직원이 할 수 있는 일의 목록이 주어졌을 때, M개의 일 중에서 최대 몇 개를 할 수 있는지 구하는 프로그램을 작성하시오.
    첫째 줄에 직원의 수 N과 일의 개수 M, 일을 2개할 수 있는 직원의 수 K가 주어진다. (1 ≤ N, M ≤ 1,000, 1 ≤ K ≤ N) 둘째 줄부터 N개의 줄의 i번째 줄에는 i번 직원이 할 수 있는 일의 개수와 할 수 있는 일의 번호가 주어진다.
    첫째 줄에 강호네 회사에서 할 수 있는 일의 개수를 출력한다.
    5 5 1
    3 1 2 3
    3 1 2 3
    1 5
    1 5
    1 5
    4
  • 예시 답안
    #include <iostream>
    #include <vector>
    #include <queue>
    #include <cstring>
    #define SZ 2003
    
    using namespace std;
    
    int n, m, k;
    int s = 0, mid = SZ - 2, e = SZ - 1;
    vector<int> g[SZ];
    int c[SZ][SZ];
    int f[SZ][SZ];
    int p[SZ];
    
    int ek() {
    	int res = 0;
    	while (1) {
    		memset(p, -1, sizeof(p));
    		queue<int> q;
    		q.emplace(s);
    		while (!q.empty()) {
    			int cur = q.front(); q.pop();
    			if (cur == e) break;
    			for (int nxt : g[cur]) {
    				if (p[nxt] == -1 && c[cur][nxt] - f[cur][nxt] > 0) {
    					q.emplace(nxt);
    					p[nxt] = cur;
    				}
    			}
    		}
    		if (p[e] == -1) break;
    
    		int cost = INT32_MAX;
    		for (int cur = e; cur != s; cur = p[cur]) {
    			cost = min(cost, c[p[cur]][cur] - f[p[cur]][cur]);
    		}
    
    		for (int cur = e; cur != s; cur = p[cur]) {
    			f[p[cur]][cur] += cost;
    			f[cur][p[cur]] -= cost;
    		}
    
    		res += cost;
    	}
    	return res;
    }
    
    int main() {
    	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    	cin >> n >> m >> k;
    	int cnt, work;
    	g[s].emplace_back(mid);
    	g[mid].emplace_back(s);
    	c[s][mid] = k;
    	for (int i = 1; i <= n; ++i) {
    		cin >> cnt;
    		g[s].emplace_back(i);
    		g[i].emplace_back(s);
    		c[s][i] = 1;
    		g[mid].emplace_back(i);
    		g[i].emplace_back(mid);
    		c[mid][i] = 1;
    		for (int j = 0; j < cnt; ++j) {
    			cin >> work;
    			work += 1000;
    			g[i].emplace_back(work);
    			g[work].emplace_back(i);
    			c[i][work] = 1;
    		}
    	}
    
    	for (int i = 1001; i <= m + 1000; ++i) {
    		g[i].emplace_back(e);
    		g[e].emplace_back(i);
    		c[i][e] = 1;
    	}
    
    	cout << ek() << '\n';
    	return 0;
    }

이분매칭

  • 이분 그래프에서 최대 매칭 수를 구하는 알고리즘
  • DFS 기반의 Hungarian Algorithm을 사용하며, 왼쪽 정점 집합(1번부터 n번까지)과 오른쪽 정점 집합(1번부터 m번까지)을 연결하는 간선을 입력받아 매칭을 수행
  • 핵심 아이디어
    • 각 왼쪽 정점에서 DFS를 통해 오른쪽 정점과의 매칭을 시도
    • 오른쪽 정점이 이미 다른 왼쪽 정점과 매칭되어 있더라도, 그 매칭을 재귀적으로 바꿔줄 수 있다면 현재 정점과 매칭시킨다.
    • 이러한 탐색을 모든 왼쪽 정점에 대해 반복하며 가능한 최대 매칭 수를 구함

수도코드

p[1...n+m] ← -1
ans ← 0

function dfs(cur):
    for nxt in g[cur]:
        if v[nxt] == true:
            continue
        v[nxt] ← true
        if p[nxt] == -1 or dfs(p[nxt]) == true:
            p[nxt] ← cur
            return true
    return false

for i = 1 to n:
    v[1...n+m] ← false
    if dfs(i) == true:
        ans ← ans + 1

[BOJ] 2188 / 축사 배정

  • 문제 링크
  • 예시 답안
    #include <iostream>
    #include <vector>
    #include <cstring>
    
    using namespace std;
    
    int n, m;
    vector<int> g[402];
    int p[402];
    bool v[402];
    
    bool dfs(int cur) {
    	for (int nxt : g[cur]) {
    		if (v[nxt]) continue;
    		v[nxt] = true;
    		if (p[nxt] == -1 || dfs(p[nxt])) {
    			p[nxt] = cur;
    			return true;
    		}
    	}
    	return false;
    }
    
    int main() {
    	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    	cin >> n >> m;
    	int s, num;
    	for (int i = 1; i <= n; ++i) {
    		cin >> s;
    		for (int j = 0; j < s; ++j) {
    			cin >> num;
    			num += 200;
    			g[i].emplace_back(num);
    		}
    	}
    	int ans = 0;
    	memset(p, -1, sizeof(p));
    	for (int i = 1; i <= n; ++i) {
    		memset(v, false, sizeof(v));
    		if (dfs(i)) ++ans;
    	}
    	cout << ans;
    	return 0;
    }

좌표 압축

  • 좌표 압축 문제는 스위핑 등 종종 쓰이므로 반드시 숙지
  • 코드
    #include <iostream>
    #include <vector>
    #include <algorithm>
    #define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
    using namespace std;
    
    struct Pos {
        int id, x;
        Pos() {}
        Pos(int id, int x) : id(id), x(x) {}
        bool operator < (const Pos& p) const {
            return x < p.x;
        }
    };
    
    vector<Pos> pos;
    int ans[1'000'000];
    
    int main() {
        FASTIO;
        int n; cin >> n;
        for (int i = 0; i < n; i++) {
            int x; cin >> x;
            pos.emplace_back(i, x);
        }
        sort(pos.begin(), pos.end());
        int bf = pos[0].x;
        int idx = 0;
        for (int i = 0; i < n; ++i) {
            Pos p = pos[i];
            if (bf != p.x) {
                bf = p.x;
                ++idx;
            }
            ans[p.id] = idx;
        }
        for (int i = 0; i < n; ++i) {
            cout << ans[i] << ' ';
        }
        cout << '\n';
        return 0;
    }

쿼리 정렬

  • 기본이지만 가끔 헷갈릴 때 있으니 정리
  • 정답을 출력할 ans 배열, 쿼리 구조체를 만들어 둔다. 이 때 쿼리 구조체의 idx반드시 삽입한다.
    ll ans[101010];
    
    struct Query{
    	  int idx, s, e;
    	  Query() {};
    	  Query(int idx, int s, int e) : idx(idx), s(s), e(e) {}
    	  bool operator < (const Query& q) const {
    		  // this < q : this 다음 q
    		  // q < this : q 다음 this
    	  }
    };
  • 쿼리를 정렬하고, 정렬된 순서대로 쿼리를 처리한다. 처리가 완료되면 해당 결과는 ans 에 저장한다.
    vector<Query> query;
    
    sort(query.begin(), query.end());
    
    for (int i = 0; i < q; ++i) {
        // 작업 수행
        ans[query[i].idx] = res; // 결과 대입
    }
    
    for (int i = 0; i < q; ++i) {
        cout << ans[i] << '\n';
    }

세그먼트 트리

  • 일반 세그
    typedef long long ll;
    
    struct SegTree {
        int size;
        vector<ll> tree;
        
        SegTree(int n) {
            size = n;
            tree.resize(4 * n);
        }
        
        void build(int node, int s, int e, const vector<ll>& arr) {
            if (s == e) {
                tree[node] = arr[s];
                return;
            }
            int m = (s + e) >> 1;
            build(node << 1, s, m, arr);
            build(node << 1 | 1, m + 1, e, arr);
            tree[node] = tree[node << 1] + tree[node << 1 | 1];
        }
        
        void update(int node, int s, int e, int idx, ll val) {
            if (idx < s || idx > e) return;
            if (s == e) {
                tree[node] = val;
                return;
            }
            int m = (s + e) >> 1;
            update(node << 1, s, m, idx, val);
            update(node << 1 | 1, m + 1, e, idx, val);
            tree[node] = tree[node << 1] + tree[node << 1 | 1];
        }
        
        ll query(int node, int s, int e, int ts, int te) {
            if (te < s || e < ts) return 0;
            if (ts <= s && e <= te) return tree[node];
            int m = (s + e) >> 1;
            return query(node << 1, s, m, ts, te) + query(node << 1 | 1, m + 1, e, ts, te);
        }
        
        void build(const vector<ll>& arr) {
            build(1, 0, size - 1, arr);
        }
        
        void update(int idx, ll val) {
            update(1, 0, size - 1, idx, val);
        }
        
        ll query(int l, int r) {
            return query(1, 0, size - 1, l, r);
        }
    };
  • lazy propagation
    typedef long long ll;
    
    struct SegTree {
        int size;
        vector<ll> tree, lazy;
        
        SegTree(int n) {
            size = n;
            tree.resize(4 * n, 0);
            lazy.resize(4 * n, 0);
        }
        
        void update_lazy(int node, int s, int e) {
            if (lazy[node] != 0) {
                tree[node] += (e - s + 1) * lazy[node];
                if (s != e) {
                    lazy[node << 1] += lazy[node];
                    lazy[node << 1 | 1] += lazy[node];
                }
                lazy[node] = 0;
            }
        }
        
        void build(int node, int s, int e, const vector<ll>& arr) {
            if (s == e) {
                tree[node] = arr[s];
                return;
            }
            int m = (s + e) >> 1;
            build(node << 1, s, m, arr);
            build(node << 1 | 1, m + 1, e, arr);
            tree[node] = tree[node << 1] + tree[node << 1 | 1];
        }
        
        void update(int node, int s, int e, int ts, int te, ll val) {
            update_lazy(node, s, e);
            if (te < s || e < ts) return;
            if (ts <= s && e <= te) {
                tree[node] += val;
                if (s != e) {
                    lazy[node << 1] += val;
                    lazy[node << 1 | 1] += val;
                }
                return;
            }
            int m = (s + e) >> 1;
            update(node << 1, s, m, ts, te, val);
            update(node << 1 | 1, m + 1, e, ts, te, val);
            tree[node] = tree[node << 1] + tree[node << 1 | 1];
        }
        
        ll query(int node, int s, int e, int ts, int te) {
            update_lazy(node, s, e);
            if (te < s || e < ts) return 0;
            if (ts <= s && e <= te) return tree[node];
            int m = (s + e) >> 1;
            return query(node << 1, s, m, ts, te) + query(node << 1 | 1, m + 1, e, ts, te);
        }
        
        void build(const vector<ll>& arr) {
            build(1, 0, size - 1, arr);
        }
        
        void update(int ts, int te, ll val) {
            update(1, 0, size - 1, ts, te, val);
        }
        
        ll query(int ts, int te) {
            return query(1, 0, size - 1, ts, te);
        }
    };
  • 비재귀 세그
    typedef long long ll;
    
    struct SegTree {
        int n;
        vector<ll> tree;
    
        SegTree(int _n) {
            n = _n;
            tree.resize(2 * _n);
        }
    
        // [0, n)
        void update(ll x, ll val) {
            for (tree[x += n] = val; x > 1; x >>= 1) tree[x >> 1] = tree[x] + tree[x ^ 1];
        }
        
        // [l, r)
        ll query(ll l, ll r) {
            ll res = 0;
            for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
                if (l & 1) res += tree[l++];
                if (r & 1) res += tree[--r];
            }
            return res;
        }
    };

SQRT Decomposition

  • 원소를 O(N)O(\sqrt N) 로 나누어서 처리하는 방법
  • init
    • sq 를 구해서 대푯값(bucket)을 채우기
      void init(){
      	  sq = sqrt(n);
      	  for(int i=1; i<=n; i++){
      		    bucket[i/sq] += arr[i];
      	  }
      }
  • update
    • 해당 원소를 직접 업데이트 해주고, 그 원소가 속한 그룹에 있는 모든 원소의 합을 구해 대표값을 갱신
      void update(int idx, int val) {
          arr[idx] = val;
          int id = idx / sq; // sq = sqrt(n);
          int s = id * sq; // 시작
          int e = s + sq; // 끝
          bucket[id] = 0;
          for(int i = s; i < e; ++i) bucket[id] += arr[i];
      }
  • query
    • 왼쪽 몇 개, 오른쪽 몇 개 + 중간은 대푯값(bucket) 통째로
      ll query(int l, int r) {
          ll res = 0;
          while(l % sq != 0 && l <= r) {
              res += arr[l++];
          }
          while((r + 1) % sq != 0 && l <= r) {
      	      res += arr[r--];
          }
          while(l <= r) {
      				res += bucket[l / sq];
      				l += sq;
          }
          return res;
      }
  • 구간합 구하기 풀이
    • segment tree
      #include <iostream>
      #include <vector>
      #define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
      using namespace std;
      typedef long long ll;
      
      struct SegTree {
          int size;
          vector<ll> tree;
      
          SegTree(int n) {
              size = n;
              tree.resize(4 * n);
          }
      
          void build(int node, int s, int e, const vector<ll>& arr) {
              if (s == e) {
                  tree[node] = arr[s];
                  return;
              }
              int m = (s + e) >> 1;
              build(node << 1, s, m, arr);
              build(node << 1 | 1, m + 1, e, arr);
              tree[node] = tree[node << 1] + tree[node << 1 | 1];
          }
      
          void update(int node, int s, int e, int idx, ll val) {
              if (idx < s || idx > e) return;
              if (s == e) {
                  tree[node] = val;
                  return;
              }
              int m = (s + e) >> 1;
              update(node << 1, s, m, idx, val);
              update(node << 1 | 1, m + 1, e, idx, val);
              tree[node] = tree[node << 1] + tree[node << 1 | 1];
          }
      
          ll query(int node, int s, int e, int ts, int te) {
              if (te < s || e < ts) return 0;
              if (ts <= s && e <= te) return tree[node];
              int m = (s + e) >> 1;
              return query(node << 1, s, m, ts, te) + query(node << 1 | 1, m + 1, e, ts, te);
          }
      
          void build(const vector<ll>& arr) {
              build(1, 0, size - 1, arr);
          }
      
          void update(int idx, ll val) {
              update(1, 0, size - 1, idx, val);
          }
      
          ll query(int l, int r) {
              return query(1, 0, size - 1, l, r);
          }
      };
      
      int main() {
          FASTIO;
          int n, m, k; cin >> n >> m >> k;
          vector<ll> arr(n);
          for (int i = 0; i < n; ++i) {
              cin >> arr[i];
          }
      
          SegTree seg_tree(n);
          seg_tree.build(arr);
      
          int a;
          ll b, c;
          for (int i = 0; i < m + k; ++i) {
              cin >> a >> b >> c;
              if (a == 1) {
                  seg_tree.update(b - 1, c);
              } else {
                  cout << seg_tree.query(b - 1, c - 1) << '\n';
              }
          }
          return 0;
      }
      
    • sqrt decomposition
      #include <iostream>
      #include <cmath>
      #define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
      using namespace std;
      
      typedef long long ll;
      
      int n, m, k, sq;
      ll arr[1'000'001];
      ll bucket[1'001];
      
      void init(){
          sq = sqrt(n);
          for(int i = 1; i <= n; i++){
              bucket[i/sq] += arr[i];
          }
      }
      
      void update(ll idx, ll val) {
          arr[idx] = val;
          int id = idx / sq;
          int s = id * sq;
          int e = s + sq;
          bucket[id] = 0;
          for(int i = s; i < e; ++i) bucket[id] += arr[i];
      }
      
      ll query(ll l, ll r) {
          ll res = 0;
          while(l % sq != 0 && l <= r) {
              res += arr[l++];
          }
          while((r + 1) % sq != 0 && l <= r) {
              res += arr[r--];
          }
          while(l <= r) {
              res += bucket[l / sq];
              l += sq;
          }
          return res;
      }
      
      int main() {
          FASTIO;
          cin >> n >> m >> k;
          for (int i = 1; i <= n; ++i) {
              cin >> arr[i];
          }
      
          init();
          int a;
          ll b, c;
      
          for (int i = 0; i < m + k; ++i) {
              cin >> a >> b >> c;
              if (a == 1) {
                  update(b, c);
              } else {
                  cout << query(b, c) << '\n';
              }
          }
          return 0;
      }

Mo’s

  • update가 없는 구간 쿼리 문제를 처리하는 알고리즘
  • 쿼리들의 순서를 재배치 후 일부 구간 재활용 → 효율적으로 쿼리 수행
  • SQRT Decomposition 기반
  • 쿼리 Q1, Q2에 대해서 다음 중 하나를 만족하는 경우 Q1을 Q2 보다 먼저 처리
    • [s1/sq]<[s2/sq][s_1/sq] < [s_2/sq]
    • e1<e2([s1/sq]=[s2/sq])e_1 < e_2 ([s_1/sq] = [s_2/sq])
  • 수열과 쿼리 5 풀이
    • 기본적으로 슬라이딩 윈도우 문제, but 쿼리 정렬만 추가
    • cnt 배열로 (숫자, 개수) 를 관리
    • -- / ++ 위치 주의할 것
    • 코드
      #include <iostream>
      #include <vector>
      #include <algorithm>
      #include <cmath>
      #define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
      using namespace std;
      
      int n, m, sq;
      
      struct Query {
          int id, s, e;
          Query() {}
          Query(int id, int s, int e) : id(id), s(s), e(e) {}
          bool operator < (const Query& p) const {
              if (s / sq != p.s / sq) return s / sq < p.s / sq;
              return e < p.e;
          }
      };
      
      vector<Query> queries;
      int arr[100'001];
      int ans[100'001];
      int c, cnt[1'000'001];
      
      int main() {
          FASTIO;
          cin >> n; sq = sqrt(n);
          for (int i = 1; i <= n; ++i) {
              cin >> arr[i];
          }
          cin >> m;
          int s, e;
          for (int i = 0; i < m; ++i) {
              cin >> s >> e;
              queries.emplace_back(i, s, e);
          }
          sort(queries.begin(), queries.end());
          s = queries[0].s;
          e = queries[0].e;
          for (int i = s; i <= e; ++i) if (cnt[arr[i]]++ == 0) ++c;
          ans[queries[0].id] = c;
          for (int i = 1; i < m; ++i) {
              int ns = queries[i].s;
              int ne = queries[i].e;
              while (s < ns) if (--cnt[arr[s++]] == 0) --c;
              while (ns < s) if (cnt[arr[--s]]++ == 0) ++c;
              while (e < ne) if (cnt[arr[++e]]++ == 0) ++c;
              while (ne < e) if (--cnt[arr[e--]] == 0) --c;
              ans[queries[i].id] = c;
          }
          for (int i = 0; i < m; ++i) {
              cout << ans[i] << '\n';
          }
          return 0;
      }
      

Strongly Connected Component (SCC)

  • 방향 그래프에서 정점 간 u → v, v → u 경로가 모두 존재하는 정점 집합을 강한 연결 요소(SCC)라 정의
  • SCC 분리는 위상 정렬, 2-SAT, 도달성 판단 등에서 핵심 전처리 과정으로 사용
  • 각 정점이 속한 SCC 정보를 구한 뒤 SCC 간 DAG를 구성하여 전체 그래프를 간결하게 압축 가능
  • 대표 알고리즘으로는 Kosaraju 알고리즘과 Tarjan 알고리즘 존재

Tarjan 알고리즘

  • DFS 기반으로 한 번의 탐색으로 SCC를 구하는 알고리즘
  • 정점마다 고유 번호(id)를 매기고, 자신 및 후손을 통해 도달 가능한 가장 빠른 조상 정점 번호를 parent로 기록
  • DFS 탐색 도중 방문한 정점을 스택에 저장하고, parent == id 조건이 만족될 경우 해당 시점까지 쌓인 스택의 정점들을 하나의 SCC로 분리
  • 각 정점이 이미 SCC에 속했는지 여부는 finished 배열로 관리
  • 역방향 그래프 생성 없이 SCC 추출 가능
  • 수행 시간은 O(V + E)

Tarjan 알고리즘의 핵심 절차

  • DFS 시작 시 정점 cur에 고유 id를 부여하고 스택에 저장
  • 모든 인접 정점 nxt에 대해
    • 방문하지 않았다면 재귀 호출을 통해 parent 갱신
    • 방문한 정점이면서 아직 SCC로 분리되지 않았다면 parent 값 최소 갱신
  • parent == id 조건 성립 시, 현재 스택에서 해당 정점까지 pop하여 하나의 SCC 구성
  • 각 SCC는 정점 번호 기준으로 정렬하여 출력

수도코드

id ← 0
p[1...V] ← 0
finished[1...V] ← false
stack ← empty
scc ← []

function dfs(cur):
    p[cur] ← ++id
    stack.push(cur)
    parent ← p[cur]

    for nxt in g[cur]:
        if p[nxt] == 0:
            parent ← min(parent, dfs(nxt))
        else if not finished[nxt]:
            parent ← min(parent, p[nxt])

    if parent == p[cur]:
        component ← []
        do:
            nxt ← stack.pop()
            finished[nxt] ← true
            component.add(nxt)
        while nxt ≠ cur
        sort(component)
        scc.add(component)

    return parent

for i = 1 to V:
    if p[i] == 0:
        dfs(i)

[BOJ] 2150 / Strongly Connected Component

  • 문제 https://www.acmicpc.net/problem/2150
  • 예시 답안
    #include <iostream>
    #include <vector>
    #include <stack>
    #include <algorithm>
    #include <cstring>
    #define SZ 10'005
    
    using namespace std;
    
    int v, e, id;
    int p[SZ];
    bool finished[SZ];
    vector<int> g[SZ], scc[SZ];
    stack<int> stk;
    int scc_cnt;
    
    int dfs(int cur) {
    	p[cur] = ++id;
    	stk.push(cur);
    	int par = p[cur];
    
    	for (int nxt : g[cur]) {
    		if (p[nxt] == 0) par = min(par, dfs(nxt));
    		else if (!finished[nxt]) par = min(par, p[nxt]);
    	}
    
    	if (par == p[cur]) {
    		while (1) {
    			int t = stk.top(); stk.pop();
    			finished[t] = 1;
    			scc[scc_cnt].emplace_back(t);
    			if (t == cur) break;
    		}
    		sort(scc[scc_cnt].begin(), scc[scc_cnt].end());
    		++scc_cnt;
    	}
    
    	return par;
    }
    
    int main() {
    	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    	cin >> v >> e;
    	for (int i = 0; i < e; ++i) {
    		int a, b;
    		cin >> a >> b;
    		g[a].emplace_back(b);
    	}
    
    	for (int i = 1; i <= v; ++i) {
    		if (p[i] == 0) dfs(i);
    	}
    
    	vector<vector<int>> res(scc, scc + scc_cnt);
    	sort(res.begin(), res.end(), [](const vector<int>& a, const vector<int>& b) {
    		return a[0] < b[0];
    	});
    
    	cout << scc_cnt << '\n';
    	for (auto& s : res) {
    		for (int x : s) cout << x << ' ';
    		cout << -1 << '\n';
    	}
    	return 0;
    }
    

Heavy Light Descomposition

서론

우리가 처리해야하는 2개의 쿼리

우리는 정점 100,000개로 이루어진 트리에서 아래 2가지 쿼리를 총 100,000번 처리해야 합니다.

  1. Update v w : 정점 v의 가중치에 w를 더해준다.
  2. Query s e : s에서 e로 가는 경로에 있는 모든 정점의 가중치의 합을 출력한다.

선형이라면 2EZ

1, 2번 쿼리를 트리가 아닌 배열에서 처리한다고 생각해보면, BOJ2042 구간 합 구하기 문제와 동일한 문제가 됩니다.

세그먼트 트리/펜윅 트리 등을 이용해 O(QlogN)O(QlogN) 시간에 풀 수 있습니다.

트리에서는 불가능

트리는 선형이 아니기 때문에 세그먼트 트리같은 자료구조를 사용할 수 없습니다. 트리에서도 세그먼트 트리를 사용할 수 있는 형태로 만들기 위해 트리를 적당히 잘라서 여러 개의 체인들로 만들어 줄 것 입니다. 각 체인 안에서의 쿼리는 세그먼트 트리를 사용하면 O(logN)O(logN)에 처리할 수 있습니다.

만약 모든 (u, v)쌍에 대해 u에서 v로 가는 경로에 O(logN)O(logN)개의 체인만 존재하게 할 수 있다면 트리에서 경로에 대한 쿼리를 O(logN)O(logN)번의 구간 쿼리로 바꿀 수 있고, 세그먼트 트리 등의 자료구조를 이용하면 O(log2N)O(log2N)에 처리할 수 있습니다.

HLD가 뭔데?

HLD는 Heavy Light Decomposition의 약자입니다. 트리의 간선들을 Heavy Edge(무거운 간선)와 Light Edge(가벼운 간선)로 구분하는 것을 의미합니다.

보통 무겁다/가볍다를 “무게”라는 척도를 이용해 구분하듯이, Heavy Edge와 Light Edge는 “서브 트리의 크기”를 기준으로 구분합니다.

부모 정점 u에서 자식 정점 v로 가는 간선 (u, v)가 있을 때 v의 서브 트리 크기가 u의 서브 트리 크기의 절반 이상인 경우(sz[son] ≥ sz[parent]/2) 그 간선을 heavy edge라고 하고, 나머지 간선들은 light edge라고 합니다. 한 정점에서 내려가는 heavy edge는 최대 한 개만 존재해야 합니다.

이렇게 heavy edge와 light edge를 잘 구분해놓으면 좋은 점은, light edge를 타고 올라가면 무조건 트리의 크기가 2배 이상이 됩니다. 그러면 당연히 어떤 정점에서 루트로 갈 때 최대 O(logN)O(logN)개의 light edge만 거치게 됩니다.

정점 u에서 정점 v로 가는 경로는 루트, 혹은 루트보다 아래에 있는 노드를 거쳐서 가기 때문에 최대 2∗O(logN)=O(logN)2∗O(logN)=O(logN)개의 light edge만 거치게 됩니다.

보통 HLD를 코드로 구현할 때는 구현의 편의를 위해 sz[son] ≥ sz[parent]/2 대신 sz[son]이 가장 큰 간선을 heavy edge로 잡습니다. 이렇게 해도 복잡도 등의 분석은 크게 달라지지 않습니다.

이 트리에서 heavy edge들을 표시하면 아래 그림과 같이 됩니다.

위에서 이야기했듯이, 각 정점에서 아래로 뻗어나가는 heavy edge는 최대 한 개이기 때문에 인접한 heavy edge들은 한 개의 체인으로 묶어줄 수 있습니다. light edge들은 그 자체를 하나의 체인으로 보면 됩니다.

인접한 heavy edge들을 체인으로 묶어주었으니, 이것들에 대해서는 세그먼트 트리등의 자료구조를 사용할 수 있고, 모든 경로에서는 light edge를 최대 O(logN)O(logN)개 거치기 때문에 당연히 O(logN)O(logN)개의 체인만 보면 됩니다.

구현은 어떻게 하죠?

구현은 정말 다양한 방법이 존재합니다. 그러나 맞는 구현은 단 한 가지라는 말이 있을 정도로 많이 사용하고 있고, 그만큼 매우 간단한 구현 방법을 소개하고자 합니다.

int sz[MAXV], dep[MAXV], par[MAXV], top[MAXV], in[MAXV], out[MAXV];
vector<int> g[MAXV];
/*
sz[i] = i를 루트로 하는 서브트리의 크기
dep[i] = i의 깊이
par[i] = i의 부모 정점
top[i] = i가 속한 체인의 가장 위에 있는 정점
in[i], out[i] = dfs ordering
g[i] = i의 자식 정점
*/

void dfs1(int v = 1){
	sz[v] = 1;
	for(auto &i : g[v]){
		dep[i] = dep[v] + 1; par[i] = v;
		dfs1(i); sz[v] += sz[i];
		if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);
	}
}

void dfs2(int v = 1){
	in[v] = ++pv;
	for(auto i : g[v]){
		top[i] = i == g[v][0] ? top[v] : i;
		dfs2(i);
	}
	out[v] = pv;
}

코드를 하나씩 살펴봅시다.

dfs1에서는 sz, dep, par배열을 채워주고 있습니다. 그러면서 동시에 서브트리가 가장 큰 자식을 맨 앞으로 보내는 역할을 해주고 있습니다.(if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);)

dfs1가 끝난다면 v에서 뻗어나가는 heavy edge는 (v,g[v][0])(v,g[v][0])일 것입니다.

dfs2에서는 in, out 배열을 채워주고 있습니다. 이는 dfs를 돌면서 i번에 들어가는 시점을 in[i]에, i번에서 빠져나가는 시점을 out[i]에 저장합니다. dfs1에서 heavy edge를 인접리스트의 가장 앞으로 옮겨주었기 때문에, 인접한 heavy edge에 속한 정점들은 dfs ordering 상에서도 인접합니다!

top[i]는 체인의 가장 위에 있는 정점인데, 만약 i가 v의 0번째 자식이라면 같은 체인에 속하므로 top[v]를 물려받고, 그렇지 않다면 새로운 체인이 시작하는 것이기 때문에 i로 설정해주면 됩니다.

각 정점의 in값은 유일합니다. 또한, dfs1과 dfs2를 이용해 같은 체인에 속한 정점들은 in값도 인접하게 만들어주었기 때문에 세그먼트 트리에서 각 정점을 관리하는 인덱스를 in[i]로 해줄 수 있습니다.

만약 어떤 정점 v부터 v가 속한 체인의 가장 위에 있는 정점까지의 구간을 알고 싶다면 [ in[top[v]], in[v] ]구간을 보면 됩니다.

dfs2에서 in과 더불어 out까지 구해놓았기 때문에 서브트리에 대한 쿼리도 처리를 해줄 수 있습니다!

Update 처리

한 정점에 대해서만 갱신을 하는 경우에는 seg.update(in[v], w) 형식으로 해주면 됩니다.

만약 경로에 대해 갱신을 할 때는 바로 아래에서 다룰 Query 처리를 참고하셔서 Lazy Propagation을 잘 섞어주시면 됩니다.

Query 처리

Query만 처리하면 끝납니다!

경로에 대한 쿼리를 처리하는 기본적인 아이디어는 경로를 여러 개의 체인으로 나눠서, 각 체인에 대해 쿼리를 날려준 뒤 모두 합치는 것입니다.

만약 두 정점이 같은 체인에 속한다면 아래 코드처럼 단순히 세그먼트 트리에 쿼리를 한 번 날리는 것으로 끝납니다.

if(dep[a] > dep[b]) swap(a, b);
ret += seg.query(1, 1, n, in[a], in[b]);
return ret;

만약 두 정점이 서로 다른 체인에 속한다면, 같은 체인에 속할 때까지 체인을 타고 올라가야합니다.

u에서 v로 가는 경로를 처리할 때, u, v의 lca를 기준으로 나눠서 보면 이해하기 쉽습니다.

경로를 처리하기 위해서는 lca까지 모두 봐야하고, u v가 서로 다른 체인에 있으면 각자 체인을 타고 쭉쭉 올라와서 lca와 같은 체인에서 만나는 결말이 나와야 합니다.

그것을 이루기 위해서, u와 v 중 더 아래에 있는 정점 x를 선택해서 top[x]부터 x, 즉 x부터 x가 속한 체인의 끝까지 모두 쿼리를 처리해준 뒤, x를 par[st]로 올려줍니다. 이런 방식으로 체인들을 하나씩 떼어나가면 결국 마지막에는 lca와 같은 체인에서 만나게 되고, 같은 체인에 속한 쿼리는 쉽게 해결할 수 있습니다.

구현은 아래와 같이 하면 됩니다.

int query(int a, int b){
    int ret = 0;
    while(top[a] != top[b]){
        if(dep[top[a]] < dep[top[b]]) swap(a, b);
        int st = top[a];
        ret += seg.query(in[st], in[a]);
        a = par[st];
    }
    if(dep[a] > dep[b]) swap(a, b);
    ret += seg.query(in[a], in[b]);
    return ret;
}

전체 코드

#include <bits/stdc++.h>
using namespace std;

struct Seg{
    int tree[1 << 18];
    int sz = 1 << 17;

    void update(int x, int v){
        x |= sz; tree[x] += v;
        while(x >>= 1){
            tree[x] = tree[x << 1] + tree[x << 1 | 1];
        }
    }

    int query(int l, int r){
        l |= sz, r |= sz;
        int ret = 0;
        while(l <= r){
            if(l & 1) ret += tree[l++];
            if(~r & 1) ret += tree[r--];
            l >>= 1, r >>= 1;
        }
        return ret;
    }
}seg;

int sz[101010], dep[101010], par[101010], top[101010], in[101010], out[101010];
vector<int> g[101010];
vector<int> inp[101010]; //입력 / 양방향 그래프

int chk[101010];
void dfs(int v = 1){
	chk[v] = 1;
	for(auto i : inp[v]){
		if(chk[i]) continue;
		chk[i] = 1;
		g[v].push_back(i);
		dfs(i);
	}
}

void dfs1(int v = 1){
	sz[v] = 1;
	for(auto &i : g[v]){
		dep[i] = dep[v] + 1; par[i] = v;
		dfs1(i); sz[v] += sz[i];
		if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);
	}
}

int pv;
void dfs2(int v = 1){
	in[v] = ++pv;
	for(auto i : g[v]){
		top[i] = i == g[v][0] ? top[v] : i;
		dfs2(i);
	}
	out[v] = pv;
}

void update(int v, int w){
    seg.update(in[v], w);
}

int query(int a, int b){
    int ret = 0;
    while(top[a] ^ top[b]){
        if(dep[top[a]] < dep[top[b]]) swap(a, b);
        int st = top[a];
        ret += seg.query(in[st], in[a]);
        a = par[st];
    }
    if(dep[a] > dep[b]) swap(a, b);
    ret += seg.query(in[a], in[b]);
    return ret;
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    int n, q; cin >> n >> q; //정점 개수, 쿼리 개수
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        inp[s].push_back(e);
        inp[e].push_back(s);
    }
    dfs(); dfs1(); dfs2();
    while(q--){
        //1 v w : update v w
        //2 s e : query s e
        int op, a, b; cin >> op >> a >> b;
        if(op == 1) update(a, b);
        else cout << query(a, b) << "\n";
    }
}

가장 가까운 두 점

Brute Force

  • 모든 두 점의 거리를 계산, O(n2)O(n^2)

Divide & Conquer

  • 두 영역으로 분할 후 각자의 영역에서 최솟값을 구함
  • 합칠 때 중간 영역 고려 필요
  • 최근접 쌍의 거리인 dd 이내의 중간 영역 안에 포함된 점들 중 거리가 더욱 짧은 근접 쌍이 존재하는지 확인
  • d=min(dleft,dright)d = min(d_{left}, d_{right})
  • 중간 영역에 속한 점의 범위 : xleft,maxd<xmid<xright,min+dx_{left, max} - d < x_{mid} < x_{right, min} + d

Pseudocode

  • Divide & Conquer Recursion
    CLOSEST_PAIR_REC(px, py):
        n = length(px)
        
        // 기저 조건: 점이 3개 이하인 경우 브루트 포스
        if n <= 3:
            return BRUTE_FORCE(px)
        
        // 분할: 중간점을 기준으로 두 영역으로 나눔
        mid = n / 2
        midpoint = px[mid]
        
        pyl = []  // 왼쪽 영역의 점들 (y좌표 기준 정렬)
        pyr = []  // 오른쪽 영역의 점들 (y좌표 기준 정렬)
        
        for each point in py:
            if point.x <= midpoint.x:
                pyl.add(point)
            else:
                pyr.add(point)
        
        // 정복: 각 영역에서 최근접 쌍 찾기
        dl = CLOSEST_PAIR_REC(px[0...mid], pyl)
        dr = CLOSEST_PAIR_REC(px[mid+1...n-1], pyr)
        
        // 두 영역의 최솟값 중 작은 값
        d = min(dl, dr)
        
        // 합치기: 중간 영역에서 더 가까운 쌍 찾기
        return min(d, CLOSEST_SPLIT_PAIR(px, py, d))
  • Middle Area
    CLOSEST_SPLIT_PAIR(px, py, d):
        n = length(px)
        mid = n / 2
        midpoint = px[mid]
        
        // 중간 영역에 속한 점들 추출
        // 범위: midpoint.x - d ≤ x ≤ midpoint.x + d
        strip = []
        for each point in py:
            if |point.x - midpoint.x| < d:
                strip.add(point)
        
        min_dist = d
        
        // 중간 영역 내 점들 간의 거리 계산
        for i = 0 to length(strip) - 1:
            j = i + 1
            // 최적화: y좌표 차이가 d 이상인 점들은 건너뛰기
            while j < length(strip) AND (strip[j].y - strip[i].y) < min_dist:
                dist = DISTANCE(strip[i], strip[j])
                if dist < min_dist:
                    min_dist = dist
                j = j + 1
        
        return min_dist

Example

  • [BOJ] 2261 / 가장 가까운 두 점
    • 중간에 있는 점들 사이 거리 비교를 어떻게 최적화 시키는가 → 가장 어려운 부분

    • 관련 이론 링크

      #include <bits/stdc++.h>
      #define FASTIO ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
      using namespace std;
      typedef pair<int, int> pii;
      
      int n;
      vector<pii> v;
      
      int dist(pii p1, pii p2) {
          return (p1.first - p2.first) * (p1.first - p2.first) + (p1.second - p2.second) * (p1.second - p2.second);
      }
      
      int divide(int s, int e) {
          int res = 1e9;
          if (e - s < 3) {
              for (int i = s; i <= e; i++) {
                  for (int j = i + 1; j <= e; j++) {
                      res = min(res, dist(v[i], v[j]));
                  }
              }
              return res;
          }
      
          int mid = (s + e) / 2;
          int dl = divide(s, mid);
          int dr = divide(mid, e);
          int d = min(dl, dr);
      
          if (d == 0) return 0;
      
          vector<pii> area;
          for (int i = s; i <= e; i++) {
              int dx = v[mid].first - v[i].first;
              if (dx * dx < d) {
                  area.emplace_back(v[i]);
              }
          }
      
          res = d;
          sort(area.begin(), area.end(), [](const pii& p1, const pii& p2) {
              return (p1.second != p2.second) ? (p1.second < p2.second) : (p1.first < p2.first);
          });
      
          for (int i = 0; i < area.size(); ++i) {
              for (int j = i + 1; j < area.size(); ++j) {
                  int dy = abs(area[i].second - area[j].second);
                  if (dy * dy < d) {
                      res = min(res, dist(area[i], area[j]));
                  } else {
                      break;
                  }
              }
          }
      
          return res;
      }
      
      int main() {
          FASTIO;
          cin >> n; v.resize(n);
          for (int i = 0; i < n; ++i) {
              cin >> v[i].first >> v[i].second;
          }
          sort(v.begin(), v.end());
          cout << divide(0, n - 1) << '\n';
          return 0;
      }
profile
데이터베이스, 트랜잭션 구조 설계에 관심이 많은 백엔드 개발자입니다.

0개의 댓글