[Softeer] 11002/Lv.3/CPTI/Python/파이썬/다이나믹프로그래밍/비트마스크

·2025년 1월 29일
0

문제풀이

목록 보기
25/56
post-thumbnail

⌛️ 언어별 시간/메모리

💡문제

코드런 나라에는 총 N명의 사람이 살고 있습니다. 이 나라에는 사람들의 성격을 나타내는 CPTI(CodeRun Person Type Indicator)라는 지표가 존재합니다.

CPTI는 길이 M의 이진 문자열로 표현되며, 각 자리의 값은 해당 성격 영역에 대해 그 사람이 긍정형(1)인지 부정형(0)인지를 나타냅니다. 예를 들어, 성격 영역이 세 개인 경우, 첫 번째와 세 번째 지표가 긍정형이고 두 번째 지표가 부정형이라면, 101으로 표현됩니다.

두 사람의 CPTI를 비교했을 때, 최대 두 가지 영역에서만 성격이 다르면 두 사람은 친밀감을 느낀다고 합니다. 예를 들어, M=3인 경우, CPTI가 각각 000, 101인 사람들은 성격이 다른 영역이 2개이므로 친밀감을 느낍니다. 그렇지만, CPTI가 각각 010, 101인 사람들은 세 개의 영역 모두에서 성격이 모두 다르므로, 친밀감을 느끼지 않습니다.

코드런 나라의 왕인 James는 이 나라에 살고 있는 사람들 중에서 서로 친밀감을 느끼는 사람 쌍이 얼마나 되는지 알고 싶어합니다.

James를 위해 서로 친밀감을 느끼는 사람 쌍의 수를 계산하는 프로그램을 작성하세요. 사람 쌍의 경우 순서를 고려하지 않습니다.

제약조건

[문제 제약 조건]

[조건 1] 1≤N≤30,000

[조건 2] 1≤M≤30

[서브 태스크별 제약 조건]

Subtask1 (12점): N≤1,000, M≤10

Subtask2 (18점): M≤10

Subtask3 (70점): 문제 조건 외에 별도의 제한이 없습니다.

입력

첫 번째 줄에 사람의 수 N과 CPTI를 나타내는 문자열의 길이 M이 주어집니다.

두 번째 줄부터 다음 N개의 줄 중 i번째 줄에는 i번 사람의 성격을 나타내는 문자열이 주어집니다. 각 문자열은 길이 M이며 0과 1로만 이루어져 있습니다.

출력

친밀감을 느끼는 사람 쌍의 수를 첫 번째 줄에 출력합니다.

예제입력

3 3
001
010
100

예제출력

3

📖내가 작성한 Code

import sys
from collections import Counter


'''
최대 2개까지 다른 거 확인
완전 같은거, 1개 같은거, 2개 같은거 확인 필요
2진수 처럼 보여서 비트밀어보면 될듯
'''


def zero_length(dic, lst):
    count = 0
    for mask in lst:
        mask_freq = dic[mask]
        if mask_freq > 1:
            count += mask_freq * (mask_freq - 1) // 2

    return count


def one_length(dic, lst, length):
    count = 0
    for mask in lst:
        for i in range(length):
            neighbor = mask ^ (1 << i)
            if neighbor in dic and neighbor > mask:
                count += dic[mask] * dic[neighbor]

    return count


def two_length(dic, lst, length):
    count = 0
    for mask in lst:
        for i in range(length):
            for j in range(i + 1, length):
                neighbor = mask ^ (1 << i) ^ (1 << j)
                if neighbor in dic and neighbor > mask:
                    count += dic[mask] * dic[neighbor]

    return  count


def count_friendly_couple(length, lst):
    count = 0
    freq = Counter(lst)
    masks = list(freq.keys())

    count += zero_length(freq, masks)
    count += one_length(freq, masks, length)
    count += two_length(freq, masks, length)

    return count


def main():
    speed_input = sys.stdin.readline
    N, M = map(int, speed_input().split())
    people = list(int(speed_input(), 2) for _ in range(N))
    print(count_friendly_couple(M, people))


if __name__ == '__main__':
    main()

✍️풀이과정

비트마스크라는 개념을 공부하고 오랜시간 사용하지 않아서 좀 찾아보고 풀었다. 2진수처럼 주는 것에서 힌트를 받음.

소프티어에서는 문제를 잘만드는데,

  1. 아주 아슬아슬하게 시간 초과를 낸다 -> 따라서 설계할 때 최적화 실시하기
  2. 메모리 초과가 자주 일어난다 -> 따라서 자료구조 및 알고리즘 설계 잘하기

여기서는 N이 30000개 이므로, 아무 생각 없이 3중 반복문을 하면 그대로 폭발한다.
따라서 최적화 하려고, 비트마스크로 변환해서 저장함

또한, M≤30면 2^30 ≈ 약 10억이므로, 32비트 체제 기준

10억 요소:
1,000,000,000 요소 ×4 바이트/요소(포인터 크기) = 4000MB = 4GB

따라서 리스트 보단 딕셔너리를 활용.

그러면서 이번에 비트마스크도 정리해봄. 링크 참조


🧠 코드 리뷰

  1. 조건 순서 변경 : 예를 들어, neighbor in dic and neighbor > mask 대신 neighbor > mask and neighbor in dic로 조건을 바꾸면, neighbor <= mask인 경우를 빠르게 걸러내어 딕셔너리 조회를 조금 더 줄일 수 있습니다.
    미묘한 최적화지만, 2비트 차이 계산에서는 수천만 단위의 반복이 일어날 수 있으므로 작은 최적화라도 누적 효과가 있을 수 있습니다.

  2. 미세 최적화 :
    루프 안에서 freq[mask] 값을 여러 번 쓰면 지역 변수에 담아 두어 접근 횟수를 줄일 수 있습니다.
    예시)

base_count = freq[mask]
for i in range(M):
    ...
    count += base_count * freq[neighbor]

로 freq[mask]를 매번 조회하지 않고 한 번만 변수에 저장해둡니다.

  1. 입출력 최적화 : Python에서 sys.stdin.readline을 통해 읽는 것은 이미 잘 하고 있지만, 변수를 짧게(예: readline = sys.stdin.readline) 가져와서 사용하는 식으로도 조금 더 줄일 수 있습니다. 문제에 따라서는 sys.setrecursionlimit, sys.stdout.write 등도 고려할 수 있습니다.

  2. 성능 상 주의 : 알고리즘 자체가 O(N×M2) 근처이므로, N=30,000, M=30에서 최악 2,700만 정도의 반복 + 딕셔너리 조회가 발생합니다.
    Python으로도 잘 작성하면 통과할 수도 있지만, 시간 제한이 빡빡하면 PyPy, C++ 등을 고려하는 것이 안전합니다.


🛠️AI 개선 코드

import sys
from collections import Counter

def count_0bit_diff(freq, masks):
    """
    Count the number of pairs that have exactly 0 bits difference.
    (i.e., the same bitmask)
    """
    count = 0
    for mask in masks:
        cnt = freq[mask]
        # 조합 공식 nC2 = n*(n-1)//2
        if cnt > 1:
            count += cnt * (cnt - 1) // 2
    return count

def count_1bit_diff(freq, masks, M):
    """
    Count the number of pairs that differ in exactly 1 bit.
    """
    count = 0
    for mask in masks:
        base_count = freq[mask]
        for i in range(M):
            neighbor = mask ^ (1 << i)
            # neighbor가 mask보다 큰 경우만 카운팅 -> 중복 방지
            if neighbor > mask and neighbor in freq:
                count += base_count * freq[neighbor]
    return count

def count_2bit_diff(freq, masks, M):
    """
    Count the number of pairs that differ in exactly 2 bits.
    """
    count = 0
    for mask in masks:
        base_count = freq[mask]
        for i in range(M):
            for j in range(i + 1, M):
                neighbor = mask ^ (1 << i) ^ (1 << j)
                if neighbor > mask and neighbor in freq:
                    count += base_count * freq[neighbor]
    return count

def count_friendly_pairs(bitmasks, M):
    """
    Count the total number of pairs that differ in at most 2 bits.
    (0-bit, 1-bit, or 2-bit differences)
    """
    freq = Counter(bitmasks)
    masks = list(freq.keys())

    total = 0
    # 0비트 차이
    total += count_0bit_diff(freq, masks)
    # 1비트 차이
    total += count_1bit_diff(freq, masks, M)
    # 2비트 차이
    total += count_2bit_diff(freq, masks, M)

    return total

def solve():
    readline = sys.stdin.readline
    N, M = map(int, readline().split())
    # 이진 문자열(M길이)을 정수(bitmask)로 변환해 저장
    bitmasks = [int(readline(), 2) for _ in range(N)]
    # 최대 2비트 차이 나는 (친밀한) 쌍의 개수 계산
    answer = count_friendly_pairs(bitmasks, M)
    print(answer)

if __name__ == '__main__':
    solve()

💻결과

소프티어문제 보러가기


🖱️참고 링크

bitmask 참고 자료
bitmask 개념

profile
우물 안에서 무언가 만드는 사람

0개의 댓글