백준 7453 합이 0인 네 정수 문제
백준 7453 합이 0인 네 정수 소스코드
Problem
정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다. A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.
Input
첫째 줄 : 배열의 크기 n(1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄 : A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 2^28이다.
Output
합이 0이 되는 쌍의 개수를 출력한다.
Example Input
6 -45 22 42 -16 -41 -27 56 30 -36 53 -37 77 -36 30 -75 -46 26 -38 -10 62 -32 -54 -6 45
Example Output
5
A B C D 배열에 포함되는 정수가 n줄에 걸쳐 주어지는데 각 배열에서 하나씩 뽑았을 때, 그 합이 0이 되는 case 개수를 출력하는 프로그램이다.
(-32) + (-54) + (56) + (30) = 0 이와 같이 합이 0이 되는 case를 찾으면 된다.
이 문제는 4중 for문으로 코드를 작성할 수 있지만 그럴 경우 시간 복잡도가 O(N⁴)이 되어 시간초과가 발생한다.
따라서 이 문제에서는 Two Pointers(투 포인터)로 해결해야 한다. 투 포인터로 이 문제를 풀이했을 때 시간 복잡도는 O(N²)이다.
▶ 투 포인터 (Two Pointers)
Two Pointers는 1차원 배열에서 두 개의 포인터를 조작하여 원하는 결과를 얻는 알고리즘이다. 두 개의 포인터를 사용하면 기존의 방식보다 시간 복잡도를 개선할 수있다.
이와 같이 위치를 가리키는 두 개의 포인터를 두고 상황에 맞게 조작하면 된다.
문제로 돌아와서,
이 문제는 A, B 배열의 합(+)배열 AB와 C, D 배열의 합(+)배열 CD를 만든 후,
AB (오름차순), CD (내림차순) 배열을 정렬하고 각 배열을 가리키는 포인터를 만들어 합해가면서 0이 되는 case를 찾을 것이다.
if AB[ptab] + CD[ptcd] > 0 → ptcd를 증가 else if AB[ptab] + CD[ptcd] < 0 → ptab를 증가
else if AB[ptab] + CD[ptcd] == 0 → 동일 값 case 주의(⚠) AB 배열에 AB[ptab]와 같은 값의 개수와 CD 배열에 CD[ptcd]와 같은 값의 개수를 고려해주어야 한다! (단순히 ptab++; ptcd++;를 해주면 논리 오류가 생긴다.)
예를 들어, 현재 ptab와 ptcd가 위의 위치를 가리키고 있을 때,
ptab와 ptcd는 이와 같이 변하게 되고, result += 2 (-95개수) * 3 (95개수) 으로 해주어야 한다.
⚠ 여기서 또 하나 주의할 점이 있는데, n의 크기는 최대 4000이므로 총 경우의 수는 (4000)⁴가 된다. 따라서 result의 자료형은 long long으로 선언해주어야 한다.
#include <bits/stdc++.h> using namespace std; bool comp(const int& a, const int& b) { return a > b; // 내림차순 } int main() { int n; scanf("%d", &n); vector<int> A, B, C, D; for (int i = 0; i < n; i++) { int a, b, c, d; scanf("%d %d %d %d", &a, &b, &c, &d); A.push_back(a); B.push_back(b); C.push_back(c); D.push_back(d); } vector<int> AB, CD; // AB for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { AB.push_back(A[i] + B[j]); } } // CD for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { CD.push_back(C[i] + D[j]); } } // sorting sort(AB.begin(), AB.end()); sort(CD.begin(), CD.end(), comp); // two pointer int ptab = 0, ptcd = 0; long long int result = 0; while (ptab < AB.size() && ptcd < CD.size()) { int currentAB = AB[ptab]; int target = -currentAB; if (CD[ptcd] == target) { int ab = 0, cd = 0; while(AB[ptab] == currentAB && ptab < AB.size()){ ab++; ptab++; } while(CD[ptcd] == target && ptcd < CD.size()){ cd++; ptcd++; } result += (long long int) ab * cd; } else if (CD[ptcd] > target) { ptcd++; } else { ptab++; } } printf("%lld\n", result); return 0; }