[SWEA] 4837 부분집합의 합 (비트연산)

Heejin Ryu·2021년 2월 16일
1

Algorithm

목록 보기
5/14

모든 문제의 저작권은 swea에 있습니다.
4837 부분집합의 합 (비트연산)

이 문제는 비트연산을 할 줄 알면 바로 슥삭슥삭 풀 수 있는데, 조금이라도 헷갈리면 쓰는데 어려움이 있다.
check함수는 부분집합의 합과 길이가 맞는지 확인하고 결과값을 도출해 내는 함수다.

for i in range(1 << n):  # 원소의 개수 12개
        tmp_arr = []
        for j in range(n + 1):  # 원소의 수만큼 비트를 비교함
            if i & (1 << j):  # i의 j번째 비트가 1이면 j번째 원소 출력
                count += 1
                tmp_arr.append(arr[j])

주요로 볼 곳은 이 위에 코드인데, 아래 있는 코드 그대로 부분만 복사해왔다.
한 줄 한줄 보면, n이 3이라고 가정하자.
1 << n 을 range값에 넣으면 어떻게 될까? 2진수는 굵은 글씨로 쓰겠습니당
1을 n 번 shift 하는 것이기 때문에, 100(2)이 된다. 즉 10진수로는 8이다.
저렇게 1을 n번 shift하게 되면 2의 n승이 되어 8이 된다. n이 4면 1000(2), 10진수로 16이 된다.

그 다음줄에는 부분집합의 원소일 경우 넣어줄 임시 리스트를 선언한 것.

그 다음 for문에서는 range안에 n+1을 넣어주면서 j변수가 아까 위에서 shift한 만큼 생성된다. 아까 n을 3으로 가정했으므로 j는 [0, 1, 2]가 나올거다.

이후 가장 중요한 코드인데, 여기서는 i와 1을 j번 shift한 것을 & 연산하고있다.
비트연산을 아는 사람이면 알 텐데, & 연산은 각 자리수의 연산이 둘 다 True일 때만 True, 즉 1을 반환한다.
이 코드에서 i는 가장 바깥에서 돌고있는 0에서 8까지의 수, 이진수로 표현하면

000
001
010
011
100
101
110
111 

이 된다. 그러면 안쪽 포문에서 1을 j번 shift한 것과 각각 & 연산이 되는데 즉, 예를들어

001 & 001 (1을 0번 shift) # j = 0
001 & 010 (1을 1번 shift) # j = 1
001 & 100 (1을 2번 shift) # j = 2

이렇게 된다!
만약 연산을 했는데, T가 나온 값이 있다면, 그럼 그 자리에 있는 값을 부분으로 갖는다는 의미가 된다. 위에 코드에서는 001(2) & 001(2)의 연산값이 001(2)이 되므로 1은 1을 부분집합으로 갖게 된다는 뜻이다.

예시를 한 가지 더 들어보자면

011 & 001 = 001
011 & 010 = 010
011 & 100 = 000

을 확인해보면 결과값에서 1번과 2번에서 true값이 나왔다. 즉 001(2)는 1이고 010(2)는 2이기 때문에, 011(2), 즉 3의 부분집합이 1과 2가 있다는 것이다.

결론적으로 부분집합인 숫자들을 tmp_arr에 넣어주고 원하는 계산을 한 것!!!

느낀점: 비트연산을 잊지 말것!!

import sys

sys.stdin = open("sample_input.txt")

T = int(input())


def check(arr, N, K):
    total = 0
    for i in range(len(arr)):
        total += arr[i]
    if total == K and len(arr) == N:
        return True
    else:
        return False


for tc in range(1, T + 1):
    count = 0
    result = 0
    N, K = map(int, input().split())  # N개의 원소를 갖고 원소의 합이 K인 부분집합의 개수

    n = 12
    arr = list(range(1, n + 1))
    for i in range(1 << n):  # 원소의 개수
        tmp_arr = []
        for j in range(n + 1):  # 원소의 수만큼 비트를 비교함
            if i & (1 << j):  # i의 j번째 비트가 1이면 j번째 원소 출력
                count += 1
                tmp_arr.append(arr[j])
        if check(tmp_arr, N, K):
            result += 1

    print("#{} {}".format(tc, result))



profile
Chocolate lover🍫 & Junior Android developer🤖

0개의 댓글