트리 (Tree)

Heejin·2023년 8월 14일

Coding Test

목록 보기
11/12

트리(Tree)는 계층적인 구조를 나타내며, 노드(node)라고 불리는 요소들의 모음으로 구성된다. 각 노드는 하나의 부모(parent) 노드와 0개 이상의 자식(children) 노드를 가질 수 있다.

Baekjoon - Tree

트리의 주요 주제 및 개념은 다음과 같다.


Trie

트라이(Trie)는 트리 자료 구조의 일종으로, 키-값 쌍의 집합을 저장하고 검색하는 데 사용되는 효율적인 자료 구조이다. "Trie"라는 이름은 retrieval(검색)의 첫 글자를 딴 것으로, 주로 문자열 데이터를 저장하고 검색하는 데 사용된다.

Baekjoon - Trie

트라이 알고리즘의 기본적인 작동 원리는 다음과 같다:

  1. 삽입 (Insertion): 문자열을 삽입할 때, 문자열의 각 글자를 트라이의 노드에 연결하면서 하나씩 따라 내려간다. 이미 존재하는 노드를 만날 경우 해당 노드로 이동하고, 존재하지 않는 노드를 만나면 새로운 노드를 생성하여 연결한다. 문자열의 마지막 글자까지 이 작업을 반복하며, 마지막 노드에 해당 문자열의 종료 플래그를 설정한다. 이때, n 길이의 문자열 m개를 넣어 트라이를 생성하는 시간복잡도는 O(mn)이다.

  2. 검색 (Search): 검색할 문자열의 각 글자를 트라이의 노드에서 차례로 검색하며 내려간다. 만약 노드가 존재하지 않는다면 해당 문자열은 트라이에 존재하지 않는 것이고, 모든 글자를 찾아서 마지막 노드에 도달하면 해당 문자열이 트라이에 존재하는 것이다. 이때, 시간 복잡도는 O(n)이다.

  3. 접두사 검색 및 자동 완성: 접두사 검색이나 자동 완성을 수행할 때, 시작 문자열을 트라이에서 찾아가면서 가능한 후보들을 찾아낸다. 시작 노드에서부터 다음 문자로 이동하면서 가능한 모든 문자열을 생성해낼 수 있다.

class Node:
    def __init__(self, key, data=None):
        self.key = key # 문자 하나
        self.data = data # 문자열의 종료를 알리는 flag
        self.children = {} # 자식노드를 저장

class Trie:
    def __init__(self):
        self.head = Node(None)

    # trie에 string의 각 문자(key)를 하나씩 넣고 마지막 node에 string(data)를 넣는 함수
    def insert(self, string):
        current_node = self.head

        for char in string:
            if char not in current_node.children:
                current_node.children[char] = Node(char)
            current_node = current_node.children[char]
        current_node.data = string

    # string이 저장되어 있는지 확인
    def search(self, string):
        current_node = self.head

        for char in string:
            if char in current_node.children:
                current_node = current_node.children[char]
            else:
                return False

        if current_node.data:
            return True
        else:
            return False

    # prefix로 시작하는 문자열 리스트 반환
    def starts_with(self, prefix):
        current_node = self.head
        words = []

        for p in prefix:
            if p in current_node.children:
                current_node = current_node.children[p]
            else:
                return words

        current_node = [current_node]
        next_node = []

        while True:
            for node in current_node:
                if node.data:
                    words.append(node.data)
                next_node.extend(list(node.children.values()))
            if len(next_node) != 0:
                current_node = next_node
                next_node = []
            else:
                break

        return words
    
trie = Trie()
word_list = ["frodo", "front", "firefox", "fire"]
for word in word_list:
    trie.insert(word)
    
result = trie.search("friend") # False
result = trie.search("frodo") # True
result = trie.starts_with("fire") # ['fire', 'firefox']
result = trie.starts_with("f") # ['fire', 'frodo', 'front', 'firefox']

Binary Tree

이진 트리(Binary Tree)는 각 노드가 최대 두 개의 자식 노드를 가지는 트리 구조이다. 각 노드는 왼쪽 자식 노드와 오른쪽 자식 노드로 구성될 수 있다.

이진 트리는 이진 탐색 트리(Binary Search Tree, BST)라는 특별한 형태의 이진 트리를 포함하며, BST는 특정 규칙을 가지고 데이터를 저장하여 검색이나 삽입, 삭제 작업 등을 효율적으로 수행할 수 있다.

이진 트리는 다양한 순회 방법을 통해 노드를 방문하고, 데이터를 탐색할 수 있다. 대표적인 순회 방법으로는 전위(preorder), 중위(inorder), 후위(postorder) 순회가 있으며, 레벨 순서 순회(level order traversal) 등도 있다.

시간복잡도(Average)시간복잡도(Worst)
삽입(Insertion)O(log n)O(n)
삭제(Deletion)O(log n)O(n)
접근(Access)O(log n)O(n)

Baekjoon - 이진 검색 트리

from collections import deque

class Node:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None

class BinaryTree:
    def __init__(self):
        self.root = None

    # insert
    def insert(self, key):
        if self.root is None:
            self.root = Node(key)
        else:
            self._insert_key(self.root, key)

    def _insert_key(self, node, key):
        if key < node.key:
            if node.left is None:
                node.left = Node(key)
            else:
                self._insert_key(node.left, key)
        else:
            if node.right is None:
                node.right = Node(key)
            else:
                self._insert_key(node.right, key)

    # find
    def find(self, key):
        return self._find_key(self.root, key)

    def _find_key(self, node, key):
        if node is None or node.key == key:
            return node
        if key < node.key:
            return self._find_key(node.left, key)
        return self._find_key(node.right, key)
    

    # delete
    def delete(self, key):
        self.root = self._delete_key(self.root, key)

    def _delete_key(self, node, key):
        if node is None:
            return node

        if key < node.key:
            node.left = self._delete_key(node.left, key)
        elif key > node.key:
            node.right = self._delete_key(node.right, key)
        else:
            if node.left is None:
                temp = node.right
                node = None
                return temp
            elif node.right is None:
                temp = node.left
                node = None
                return temp

            temp = self._min_value_node(node.right)
            node.key = temp.key
            node.right = self._delete_key(node.right, temp.key)

        return node
    
    def _min_value_node(self, node):
        current = node
        while current.left is not None:
            current = current.left
        return current
    
    # preorder
    def preorder_traversal(self):
        result = []
        self._preorder_recursive(self.root, result)
        return result

    def _preorder_recursive(self, node, result):
        if node:
            result.append(node.key)
            self._preorder_recursive(node.left, result)
            self._preorder_recursive(node.right, result)

    # inorder
    def inorder_traversal(self):
        result = []
        self._inorder_recursive(self.root, result)
        return result

    def _inorder_recursive(self, node, result):
        if node:
            self._inorder_recursive(node.left, result)
            result.append(node.key)
            self._inorder_recursive(node.right, result)
    
    # postorder
    def postorder_traversal(self):
        result = []
        self._postorder_recursive(self.root, result)
        return result

    def _postorder_recursive(self, node, result):
        if node:
            self._postorder_recursive(node.left, result)
            self._postorder_recursive(node.right, result)
            result.append(node.key)

    # levelorder
    def levelorder_traversal(self):
        result = []
        if self.root is None:
            return result

        queue = deque()
        queue.append(self.root)

        while queue:
            node = queue.popleft()
            result.append(node.key)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)

        return result

'''
          40
         /  \
        4    45
         \     \
         34    55
        /      /   
       14     48
      /  \   /  \
     13  15 47  49
'''

array = [40, 4, 34, 45, 14, 55, 48, 13, 15, 49, 47]

bst = BinaryTree()

# insert
for n in array:
    bst.insert(n)

# find
node1 = bst.find(15)
node2 = bst.find(17) # None

# delete
bst.delete(55)
bst.delete(14)
bst.delete(11)

# traversal
bst.preorder_traversal() # [40, 4, 34, 15, 13, 45, 48, 47, 49]
bst.inorder_traversal() # [4, 13, 15, 34, 40, 45, 47, 48, 49]
bst.postorder_traversal() # [13, 15, 34, 4, 47, 49, 48, 45, 40]
bst.levelorder_traversal() # [40, 4, 45, 34, 48, 15, 47, 49, 13]

LCA

최소 공통 조상(Lowest Common Ancestor, LCA) 알고리즘은 트리 구조에서 두 노드의 가장 가까운 공통 부모를 찾는 알고리즘이다. 이 알고리즘은 주로 계층적인 관계를 갖는 자료구조에서 특정한 노드들 간의 관계를 찾을 때 사용된다.

Baekjoon - Lowest Common Ancestor

최소 공통 조상을 찾기 위해서는 트리를 사전에 전처리하여, 미리 각 노드들의 조상 노드에 대한 정보를 저장한다. 이후 두 노드의 공통 조상을 찾을 때, 두 노드의 깊이(depth)를 맞춘 뒤, 같은 높이에 도달할 때까지 각 노드의 조상을 찾아 올라가면서 공통 조상을 찾는다. 이때, 시간복잡도는 일반적으로 O(log n)이며 최악의 경우 O(n)이 된다.

class Node:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None

class BinaryTree: # 최소 공통 조상 함수를 구현하기 위해 임시로 binary tree 생성
    def __init__(self):
        self.root = None

    def insert(self, key):
        if self.root is None:
            self.root = Node(key)
        else:
            self._insert_key(self.root, key, None)

    def _insert_key(self, node, key, parent):
        if key < node.key:
            if node.left is None:
                new_node = Node(key)
                new_node.parent = node
                node.left = new_node
            else:
                self._insert_key(node.left, key, node)
        else:
            if node.right is None:
                new_node = Node(key)
                new_node.parent = node
                node.right = new_node
            else:
                self._insert_key(node.right, key, node)

    def find(self, key):
        return self._find_key(self.root, key)

    def _find_key(self, node, key):
        if node is None or node.key == key:
            return node
        if key < node.key:
            return self._find_key(node.left, key)
        return self._find_key(node.right, key)

    ######################################################################
    def find_lca(self, key1, key2):
        node1 = self.find(key1)
        node2 = self.find(key2)

        height_diff = self._get_height(node1) - self._get_height(node2)

        # node1이 더 깊은 위치에 있을 경우, 두 노드의 높이를 맞춰주기
        if height_diff > 0:
            node1, node2 = node2, node1

        for _ in range(abs(height_diff)):
            node2 = node2.parent

        # 공통조상 찾기
        while node1 != node2:
            node1 = node1.parent
            node2 = node2.parent

        return node1 if node1 else None

    def _get_height(self, node):
        height = 0
        while node:
            height += 1
            node = node.parent
        return height
    ######################################################################

'''
         4
       /   \
      2     6
     / \   / \
    1   3 5   7
'''

array = [4, 2, 6, 1, 3, 5, 7]
bst = BinaryTree()
for n in array:
    bst.insert(n)

lca = bst.find_lca(5, 7).key # 6

0개의 댓글