https://www.acmicpc.net/problem/1182
어제 본 카카오커머스 채용 코딩 테스트에서 나온 문제의 기본이다.
어제 본 코딩 테스트에 대해 자세히 이야기할 수는 없지만 대충 부분 수열을 구해 문제의 조건에 맞게 답을 구하는 문제이다.
어제 문제를 풀 때는 부분 수열을 어떻게 구해야할 지 난감해 조금 불편한 방법으로 구했었는데 테스트를 끝내고 복기하는 과정에서 비트마스크 방법으로 풀면 편하게 풀이할 수 있었겠다라는 생각이 들어 간단한 문제를 풀며 비트 마스크 방법에 대해 공부했다.
[1,2,3,4,5] 배열의 부분 수열을 표현할 때,
[1,3], [1,4,5] 이런 식으로 표현도 가능하지만
[1,3] -> [1,0,1,0,0] 이런 식으로 부분 수열의 값을 전체 배열의 인덱스와 일치하게 함으로서 표현을 단순하게 할 수 있다.
1이 있으니 1의 인덱스인 0번째에 1을 저장하고 3도 마찬가지이다.
각 요소는 인덱스처럼 표현할 수 있다.
즉, 집합 i 번째 요소가 부분집합에 존재한다면 1 을 의미하고, 그렇지 않으면 0 을 의미한다.
비트마스크 방법을 적용해 풀기 위해서는 주어진 배열에서 모든 경우의 수에 대한 비트바스크가 이루어진 배열을 구해야했는데 이 과정이 조금 힘들었다.
만들어질 수 있는 수열의 경우의 수는 2^n - 2와 같다. (0과 2^n은 제외)
여기서 어떻게 부분 집합을 구해서 비트 마스크를 할까하는 고민이 있었다.
처음 생각은
부분집합 생성 -> 비트마스킹
이었는데 코드가 복잡하고 문제가 있어도 수정하기가 쉽지않았다.
그런데 생각을 바꿔보니 간단했다. 어짜피 모든 경우의 수를 다 구해봐야 하기 때문에 1부터 2^n-1까지 반복하면서 &연산을 통해 인덱스를 추출해내 부분집합을 구하는 것이다.
if ((i & (1 << j)) != 0)
코드 상으로는 위와 같이 구현했는데 i는 1부터 2^n-1까지 증가하고 j는 0부터 n-1까지 증가한다.
만약 i == 3일 때 j가 0이면 3은 011이고 1은 001이기 때문에 &연산을 통해 1이 나오고 0이 아니기 때문에 부분 집합의 0번 인덱스에 1이 저장된다.
즉 연산결과가 0이 아니면 j번째 인덱스가 1이 되는 부분집합을 구할 수 있다.
이후는 문제 조건대로 합을 구해 s값과 비교해서 count를 증가시키는 방식으로 구현했다.
import java.io.*;
public class Main {
/*static int count = -1;
static int[] dx = {-2, -2, -1, -1, 1, 1, 2, 2};
static int[] dy = {1, -1, 2, -2, 2, -2, 1, -1};*/
static boolean visited[];
static int n, s;
//static long graph[][];
static long min = Long.MAX_VALUE;
static int arr[];
static int nums[];
static int count = 0;
static int mask[];
public static void main(String[] args) throws IOException {
// write your code here
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
String[] num = br.readLine().split(" ");
n = Integer.parseInt(num[0]);
s = Integer.parseInt(num[1]);
nums = new int[n];
visited = new boolean[n];
String[] tmp = br.readLine().split(" ");
for (int i = 0; i < n; i++)
nums[i] = Integer.parseInt(tmp[i]);
int m = (int) Math.pow(2, n);
for (int i = 1; i < m; i++) {
int[] temp = new int[n];
for (int j = 0; j < n; j++) {
if ((i & (1 << j)) != 0) {
temp[j] = nums[j];
}
}
int sum = 0;
for (int k = 0; k < n; k++) {
if(temp[k] != 0) {
sum = sum + temp[k];
}
}
if(sum == s)
count++;
}
bw.write(count + "\n");
br.close();
bw.close();
}
}