[Tri] 백준 - 수열과 쿼리 20 16903번

황준승·2021년 4월 12일
1
post-thumbnail

문제 링크 수열과 쿼리 20

단순하게 문제에 접근해보자.
A라는 배열에 값을 추가하고 A라는 배열 안에 있는 항목 중 하나를 삭제하고 각 항목 각각에 대해서 xor연산을 한다고 한다면 엄청난 시간 복잡도가 걸릴 것이다.

특히 이 xor 연산은 기존에 입력받은 수를 이진수로 바꾸어 각 자리 수 각각에 대해서 xor연산을 하기 때문에 수가 커짐에 따라 상당한 시간이 소요될 수 있다.

심지어 쿼리의 개수는 1 <= M <= 200,000이고 x의 범위는 10^9보다 작거나 같기 때문에 단순하게 문제에 접근한다면?? 엄청난 시간 복잡도에 허덕일 것이다.

그렇다면 어떤 방법이 있을까?
그래서 각각의 입력받은 숫자들을 이진수로 변환하여 트라이에 집어넣어 문제를 해결하려고 한다.

노드 생성

노드를 생성시 삭제를 하기위해 얼마나 많이 지나왔는지 기록하는 data 변수, 그리고 0의 값을 나타내는 left 변수, 그리고 1의 값을 나타내는 right변수로 나누었다.

# 노드 생성
class Node(object):
    def __init__(self, data):
        self.data = data    #다녀왔다는 표시, 한번 다녀왔으면 0, 두번 1, ...
        self.left = {}      #좌측이 0
        self.right = {}     #우측이 1

삽입함수

포문을 통해 해당 단어들을 하나씩 훑으면서 만약 그 단어가 0일 시 왼쪽으로 이동,
그 단어가 1일 경우 오른쪽으로 이동하는 식으로 코드를 구현하였다.

# 트라이 생성
class Trie(object):
    #root노드에 Node 생성
    def __init__(self):
        self.head = Node(0)

    #삽입함수
    def insert(self,word):
        cur = self.head

        for ch in word:
            # 삽입 단어가 0일 시
            if ch == "0":
                if cur.left:
                    # left에 다녀왔다는 표시 += 1 추가 
                    cur.left.data += 1

                else:
                    # 다녀온 적이 없다면 노드 생성
                    cur.left = Node(0)

                cur = cur.left

            # 삽입 단어가 1일 시
            else:
                if cur.right:
                    cur.right.data += 1
                
                else:
                    cur.right = Node(0)              

                cur = cur.right  

삭제 함수

앞서 data 변수를 통해서 각 노드들이 얼만큼 지나왔는지 알 수 있었다. 그 data 변수들을 통해서 data > 0 일 경우 해당 data 값을 -1 한다. 만약 data == 0일경우 해당 데이터에 False값을 부여하여 해당 노드를 삭제한다.

    # 삭제 함수
    def delete(self, word):
        cur = self.head

        for ch in word:
            # 삭제할 값이 0이고
            if ch == "0":
                # 0을 다녀왔다는 표시(data)가 0보다 클 경우 -= 1
                if cur.left.data > 0:
                    cur.left.data -= 1

                # 0일 경우 해당 노드를 없앱니다. 
                else:
                    cur.left = False
                    break
                
                cur = cur.left

            # 삭제할 값이 1이고
            else:
                # 1을 다녀왔다는 표시(data)가 0보다 클 경우 -= 1
                if cur.right.data > 0:
                    cur.right.data -= 1
                
                # 0일 경우 해당 노드를 없앱니다. 
                else:
                    cur.right = False
                    break

                cur = cur.right

xor연산 함수

xor 연산의 경우 만들어진 트라이에서 입력받은 수 반대로 이동하면 된다. 입력받은 수가 0일 경우 1로, 1일 경우 0. 만약 1로 가야 하지만 만들어진 트라이에 1이 없다면 0으로 이동하면 된다. 그런 순서를 통해 코드를 구현하였다.

    #XOR 연산함수 - 만들어진 트라이에서 입력받은 수에서 반대로 가면 된다. 0 -> 1, 1 -> 0
    def xor(self, word):
        cur = self.head
        ans = "0b"

        for ch in word:
            # 0 -> 1
            if ch == "0":
                if cur.right :
                    ans += "1"
                    cur = cur.right
                #0 -> 1로 가고 싶지만 트라이에 해당 값이 없을 때(cur.right == {})
                else:
                    ans += "0"                    
                    cur = cur.left
            # 1 -> 0
            else:
                if cur.left :
                    ans += "1"
                    cur = cur.left 

                #1 -> 0로 가고 싶지만 트라이에 해당 값이 없을 때(cur.left == {})
                else:
                    ans += "0"
                    cur = cur.right           
 
        answer = int(ans,2)
        return answer

전체 코드

import sys
input = sys.stdin.readline

# 노드 생성
class Node(object):
    def __init__(self, data):
        self.data = data    #다녀왔다는 표시, 한번 다녀왔으면 0, 두번 1, ...
        self.left = {}      #좌측이 0
        self.right = {}     #우측이 1

# 트라이 생성
class Trie(object):
    #root노드에 Node 생성
    def __init__(self):
        self.head = Node(0)

    #삽입함수
    def insert(self,word):
        cur = self.head

        for ch in word:
            # 삽입 단어가 0일 시
            if ch == "0":
                if cur.left:
                    # left에 다녀왔다는 표시 += 1 추가 
                    cur.left.data += 1

                else:
                    # 다녀온 적이 없다면 노드 생성
                    cur.left = Node(0)

                cur = cur.left

            # 삽입 단어가 1일 시
            else:
                if cur.right:
                    cur.right.data += 1
                
                else:
                    cur.right = Node(0)              

                cur = cur.right  
    
    # 삭제 함수
    def delete(self, word):
        cur = self.head

        for ch in word:
            # 삭제할 값이 0이고
            if ch == "0":
                # 0을 다녀왔다는 표시(data)가 0보다 클 경우 -= 1
                if cur.left.data > 0:
                    cur.left.data -= 1

                # 0일 경우 해당 노드를 없앱니다. 
                else:
                    cur.left = False
                    break
                
                cur = cur.left

            # 삭제할 값이 1이고
            else:
                # 1을 다녀왔다는 표시(data)가 0보다 클 경우 -= 1
                if cur.right.data > 0:
                    cur.right.data -= 1
                
                # 0일 경우 해당 노드를 없앱니다. 
                else:
                    cur.right = False
                    break

                cur = cur.right

    #XOR 연산함수 - 만들어진 트라이에서 입력받은 수에서 반대로 가면 된다. 0 -> 1, 1 -> 0
    def xor(self, word):
        cur = self.head
        ans = "0b"

        for ch in word:
            # 0 -> 1
            if ch == "0":
                if cur.right :
                    ans += "1"
                    cur = cur.right
                #0 -> 1로 가고 싶지만 트라이에 해당 값이 없을 때(cur.right == {})
                else:
                    ans += "0"                    
                    cur = cur.left
            # 1 -> 0
            else:
                if cur.left :
                    ans += "1"
                    cur = cur.left 

                #1 -> 0로 가고 싶지만 트라이에 해당 값이 없을 때(cur.left == {})
                else:
                    ans += "0"
                    cur = cur.right           
 
        answer = int(ans,2)
        return answer

n = int(input())

trie = Trie()
t = format(0,'b').zfill(30)
trie.insert(t)

for _ in range(n):
    how, x = map(int, input().split())
    # 자릿수의 공정한 비교를 위해 30으로 다 채운다. 
    word = format(x,'b').zfill(30)
    
    # 1일 때 삽입
    if how == 1:
        trie.insert(word)

    # 2일 때 삭제
    elif how == 2:
        trie.delete(word)

    # 3일 때 xor
    elif how == 3:
        result = trie.xor(word)
        print(result)        
profile
다른 사람들이 이해하기 쉽게 기록하고 공유하자!!

0개의 댓글