정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다.
A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.
첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 228이다.
합이 0이 되는 쌍의 개수를 출력한다.
4개의 배열에서 나올 수 있는 모든 쌍은 4000^4개다. 이렇게 많은 경우의 수를 줄여주기 위해서 A, B의 쌍을 따로 C, D의 쌍을 따로 구하면 된다. -> (중간에서 만나기 활용 AB의 쌍이 CD쌍의 영향을 주기 때문에 가능함) 그러면 AB쌍과 CD쌍은 각각 1600만개씩 존재하고, AB쌍을 돌면서 합이 0이되는 값을 CD쌍에서 찾으면 된다.
그러면 CD쌍을 찾으면 되는데 처음에는 table이나 map을 사용해서 구현했다. 하지만 시간 초과가 났다. 시간복잡도를 계산했을 때 충분히 12초안에 가능한 풀이라고 생각했는데 찾아보니 map, table의 연산이 매우 느리다는 것을 알게 됐다. 그래서 생각한 방법이 lower bound와 upper bound를 사용한 풀이이다. sum_1, sum_2를 정렬하고, lower bound와 upper bound를 이용해서 현재 sum_1에 값의 범위를 구하고, 그 sum_1을 0값으로 만드는 sum_2의 범위를 구해서 ans를 계산했다. 그리고 출력하니 통과할 수 있었다.
// sum_1은 AB의 쌍
// sum_2는 CD의 쌍
import java.io.*;
import java.util.*;
public class Main {
static int N;
static int arr[][];
static long sum_1[]; // A, B의 모든 합
static long sum_2[]; //C, D의 모든 합
static long ans = 0;
public static void main(String args[]) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
N = Integer.parseInt(br.readLine());
arr = new int[4][N];
sum_1 = new long[N*N+1];
sum_2 = new long[N*N+1];
sum_1[N*N] = 1000000000000L;
sum_2[N*N] = 1000000000000L;
for(int i=0; i<N; i++) {
StringTokenizer n_st = new StringTokenizer(br.readLine());
for(int j=0; j<4; j++) {
arr[j][i] = Integer.parseInt(n_st.nextToken());
}
}
for(int i=0; i<N; i++) {
for(int j=0; j<N; j++) {
sum_1[i*N+j] = (long) arr[0][i] + (long) arr[1][j];
}
}
for(int i=0; i<N; i++) {
for(int j=0; j<N; j++) {
sum_2[i*N+j] = (long) arr[2][i] + (long) arr[3][j];
}
}
Arrays.sort(sum_1);
Arrays.sort(sum_2);
//합이 0이 되는 쌍의 개수 찾기
int start_index = 0;
int end_index = 0;
while(end_index != sum_1.length-1) {
end_index = upper_bound(sum_1[start_index], sum_1);
long v = 0 - sum_1[start_index];
ans += (long)(end_index - start_index) * ((long)upper_bound(v, sum_2) - (long)lower_bound(v, sum_2));
start_index = end_index;
}
System.out.println(ans);
}
static int lower_bound(long search_value, long sum[]) {
int min = 0;
int max = sum.length-1;
while(min<max) {
int mid = (min + max)/2;
if(sum[mid]>=search_value) {
max = mid;
} else {
min = mid + 1;
}
}
return min;
}
static int upper_bound(long search_value, long sum[]) {
int min = 0;
int max = sum.length-1;
while(min<max) {
int mid = (min + max)/2;
if(sum[mid]<=search_value) {
min = mid + 1;
} else {
max = mid;
}
}
return min;
}
}