안녕하세요. 오늘은 2-SAT알고리즘을 배워볼 거예요.
sat문제는 충족 가능성 문제입니다. (satisfiability problem)
충족 가능성 문제는 여러 boolean 변수들이 있는 식이 있을 때 그 식을 참(true)으로 만드는 boolean값들 (true, false)이 있는지 구하는 문제입니다.
이런게 sat문제입니다. 아래로 뾰족한게 or 연산자 위로 뾰족한게 and 연산자입니다. 변수 앞에 막대기가 있는것은 not 연산자입니다.
위 그림에서 괄호로 묶인 것들처럼 or 연산자들과 boolean변수들로만 이루어진 식을 절(clause)라고 하며 and 연산자들과 절들로만 이루어진 식을 CNF(Conjunctive Normal Form) 라고 합니다.
CNF의 특징은 각 절은 모두 참이여야하고 각 절 앞에 있는 변수들중 하나이상 참이여야한다는 것입니다.
이때 절 앞에 있는 boolean변수들의 개수의 최댓값이 k이면 그런 CNF를 푸는 문제를 k-SAT라고 부르는데 위의 경우는 최대 2개이므로 2-SAT가 됩니다. 1-SAT는 직관적이며 선형적으로 풀 수 있고 3-SAT이상부터는 변형을 하여 3-SAT문제로 바꿀 수 있으나.. NP-HARD문제입니다. 하지만 2-SAT는 다항시간안에 풀 수 있습니다.
위 그림에서 첫번째 절은 ~x1과 x2둘 중 하나는 참이여야합니다. 이말은 ~x1이 거짓이면 x2는 참이여야하고 x2가 거짓이면 ~x1은 참이여야한다는 말입니다. 이를 SCC로 구현할 수 있습니다.
N개의 종류의 boolean변수가 나온다면 2N개의 정점을 만드는 것입니다. 각 boolean변수의 기본값과 not값 두개 입니다.
그렇게 모든 절에 대해서 위와같이 선을 이으면 그래프가 나타나게 되니다. 이때 모순이 되는 경우, 즉 불가능한 경우는 x이면 ~x인데 ~x이면 x인 경우입니다. 즉, x와 ~x가 같은 SCC안에 있다는 것이지요. (하지만 단순히 x이면 ~x이다 하나만 있으면 성립하지 않습니다.)
2N개의 변수에 대해서 SCC를 진행하고 검사를 해주면 됩니다.
가능한지 구했으니까 값들도 구해야합니다.
재밌는 사실이 있습니다.
SCC를 구하는 대표적인 방법으로 코사라주 알고리즘과 타잔 알고리즘이 있습니다. 이 두 알고리즘다 그룹을 위상정렬한 값과 관련이 크게 있다는것입니다. 코사라주에서의 그룹번호는 위상정렬한 값과 동일하고 타잔은 그 역순입니다. 그래서 굳이 따로 위상정렬을 하지 않아도 됩니다.
전체 값들을 확정시킬때 x_n과 ~x_n이 있을 때 먼저 만나는것을 false로 해두면 나중에 만나는게 뭐가 되어도 상관이 없습니다. 만약 x_n을 false로 바꾸면 x_n -> ~x_n이 있어도 false -> true이니까 상관이 없고 ~x_n을 false로 바꿔도 ~x_n -> x_n이고 false -> true이므로 전혀 상관이 없습니다.
2-SAT - 3 (가능한지 여부만)
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
using namespace std;
vector <int> v1[20202], v2[20202];
bool visit[20202] = { 0 };
stack <int> s;
vector <int> scc;
int GroupNum[20202];
void dfs1(int node)
{
visit[node] = true;
for (int next : v1[node])
if (visit[next] == false)
dfs1(next);
s.push(node);
}
void dfs2(int node)
{
visit[node] = true;
scc.push_back(node);
for (int next : v2[node])
if (visit[next] == false)
dfs2(next);
}
int Not(int x) //~x
{
if (x % 2 == 1) return x + 1;
return x - 1;
}
int main(void)
{
ios_base::sync_with_stdio(false); cin.tie(NULL);
int N, M, i, a, b;
cin >> N >> M;
for (i = 0; i < M; i++)
{
//x가 홀수 ~x가 짝수
cin >> a >> b;
a = abs(a) * 2 + (a < 0) - 1; //a가 보다 작으면 1을 더해서 저장 아니면 그냥 2만 곱해서 저장
b = abs(b) * 2 + (b < 0) - 1; //인덱스가 1부터 시작하기 위해서
v1[Not(a)].push_back(b);
v1[Not(b)].push_back(a);
v2[b].push_back(Not(a));
v2[a].push_back(Not(b));
}
for (i = 1; i <= 2 * N; i++)
if (visit[i] == false)
dfs1(i);
for (i = 1; i <= 2 * N; i++) visit[i] = false;
int cnt = 0;
while (s.size())
{
int node = s.top();
s.pop();
if (visit[node] == true) continue;
dfs2(node);
cnt++;
for (int cur : scc)
GroupNum[cur] = cnt;
scc.clear();
}
for (i = 1; i <= N; i++)
{
if (GroupNum[i * 2 - 1] == GroupNum[i * 2]) //i와 ~i가 같은 그룹에 속해있으면
{
cout << 0; //불가능
return 0;
}
}
cout << "1\n";
}
2-SAT - 4 (값들까지 다 출력)
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#include <queue>
using namespace std;
vector <int> v1[20202], v2[20202];
bool visit[20202] = { 0 };
stack <int> s;
vector <int> scc;
int GroupNum[20202], GroupCnt;
void dfs1(int node)
{
visit[node] = true;
for (int next : v1[node])
if (visit[next] == false)
dfs1(next);
s.push(node);
}
void dfs2(int node)
{
visit[node] = true;
scc.push_back(node);
for (int next : v2[node])
if (visit[next] == false)
dfs2(next);
}
int Not(int x) //~x
{
if (x % 2 == 1) return x + 1;
return x - 1;
}
int main(void)
{
ios_base::sync_with_stdio(false); cin.tie(NULL);
int N, M, i, a, b;
cin >> N >> M;
for (i = 0; i < M; i++)
{
//x가 홀수 ~x가 짝수
cin >> a >> b;
a = abs(a) * 2 + (a < 0) - 1; //a가 보다 작으면 1을 더해서 저장 아니면 그냥 2만 곱해서 저장
b = abs(b) * 2 + (b < 0) - 1; //인덱스가 1부터 시작하기 위해서
v1[Not(a)].push_back(b);
v1[Not(b)].push_back(a);
v2[b].push_back(Not(a));
v2[a].push_back(Not(b));
//cout << a << ' ' << b << "\n";
}
for (i = 1; i <= 2 * N; i++)
if (visit[i] == false)
dfs1(i);
for (i = 1; i <= 2 * N; i++) visit[i] = false;
int cnt = 0;
while (s.size())
{
int node = s.top();
s.pop();
if (visit[node] == true) continue;
dfs2(node);
cnt++;
for (int cur : scc)
GroupNum[cur] = cnt;
scc.clear();
}
GroupCnt = cnt;
for (i = 1; i <= N; i++)
{
if (GroupNum[i * 2 - 1] == GroupNum[i * 2]) //i와 ~i가 같은 그룹에 속해있으면
{
cout << 0; //불가능
return 0;
}
}
cout << "1\n";
for (i = 1; i <= N; i++)
cout << (GroupNum[2 * i - 1] > GroupNum[2 * i]) << ' ';
}
2-SAT - 3 (가능한지 여부만)
#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>
using namespace std;
vector <int> v[20202];
int visit[20202] = { 0 }, finish[20202] = { 0 };
stack <int> s;
int id = 0, cnt = 0;
int GroupNum[20202] = { 0 };
int dfs(int node)
{
visit[node] = ++id;
s.push(node);
int parent = id;
for (int next : v[node])
if (!visit[next])
parent = min(parent, dfs(next));
else if (!finish[next])
parent = min(parent, visit[next]);
if (parent == visit[node])
{
cnt++;
while (true)
{
int cur = s.top();
s.pop();
finish[cur] = 1;
GroupNum[cur] = cnt;
if (cur == node) break;
}
}
return parent;
}
int NOT(int x)
{
if (x % 2 == 0) return x + 1;
return x - 1;
}
int main(void)
{
ios_base::sync_with_stdio(false); cin.tie(NULL);
int N, M, i, a, b;
cin >> N >> M;
for (i = 0; i < M; i++)
{
cin >> a >> b;
a = 2 * (abs(a) - 1) + (a < 0);
b = 2 * (abs(b) - 1) + (b < 0);
v[NOT(a)].push_back(b);
v[NOT(b)].push_back(a);
}
for (i = 0; i < 2 * N; i++)
if (!visit[i])
dfs(i);
for (i = 1; i <= N; i++)
if (GroupNum[i * 2 - 2] == GroupNum[i * 2 - 1]) //불가능하면
{
cout << 0;
return 0;
}
cout << 1;
}
2-SAT - 4 (값들까지 다 출력)
#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>
using namespace std;
vector <int> v[20202];
int visit[20202] = { 0 }, finish[20202] = { 0 };
stack <int> s;
int id = 0, cnt = 0;
int GroupNum[20202] = { 0 };
int dfs(int node)
{
visit[node] = ++id;
s.push(node);
int parent = id;
for (int next : v[node])
if (!visit[next])
parent = min(parent, dfs(next));
else if (!finish[next])
parent = min(parent, visit[next]);
if (parent == visit[node])
{
cnt++;
while (true)
{
int cur = s.top();
s.pop();
finish[cur] = 1;
GroupNum[cur] = cnt;
if (cur == node) break;
}
}
return parent;
}
int NOT(int x)
{
if (x % 2 == 0) return x + 1;
return x - 1;
}
int main(void)
{
ios_base::sync_with_stdio(false); cin.tie(NULL);
int N, M, i, a, b;
cin >> N >> M;
for (i = 0; i < M; i++)
{
cin >> a >> b;
a = 2 * (abs(a) - 1) + (a < 0);
b = 2 * (abs(b) - 1) + (b < 0);
v[NOT(a)].push_back(b);
v[NOT(b)].push_back(a);
}
for (i = 0; i < 2 * N; i++)
if (!visit[i])
dfs(i);
for (i = 1; i <= N; i++)
if (GroupNum[i * 2 - 2] == GroupNum[i * 2 - 1]) //불가능하면
{
cout << 0;
return 0;
}
cout << "1\n";
for (i = 1; i <= N; i++)
if (GroupNum[i * 2 - 2] > GroupNum[i * 2 - 1])
cout << "0 ";
else
cout << "1 ";
}
감사합니다.