[Algorithm] Trie (Python code)

Ziggy Stardust·2024년 9월 24일
0

Algorithm

목록 보기
1/2
post-thumbnail

Trie

많은 문자열 데이터들이 있을 때 이를 효율적으로 다룰 수 있게 해줍니다.

단편적 예시로는 영어사전이 있습니다.

아래와 같은 데이터가 있습니다.

APPLE
APPLY
BANANA
KIWI

이 데이터에서 특정 문자열이 이미 존재하는지 찾는다 가정할 때

단순한 방식으로는 특정 문자열을 모든 문자열과 비교하는 방법이 있습니다.

이 때 시간복잡도는 (국어사전의 모든 문자열의 수 x 특정 문자열의 길이) 이 됩니다.

하지만 Trie 를 사용하게 되면 특정 문자열의 길이만으로 평가할 수 있습니다.

Trie 는 최상위 노드부터 시작해 공통적인 경로는 하나의 경로로 관리합니다. (다른 표현으로는 중복되는 접두사는 공통 노드로 관리합니다.)

APPLE, APPLY 로 예를 들자면 APPLE 과 APPLY 는 APPL 이라는 공통경로를 가집니다.

Trie 의 최대 깊이는 가장 긴 문자열의 길이만큼 됩니다.



동작

동작은 크게 insert, find 두 가지가 있습니다.

기본적으로 한 노드 당 가질 수 있는 자식 노드들은 알파벳밖에 없습니다.
따라서 한 노드는 알파벳 크기만큼 배열을 가집니다. (자식노드들을 가리키게 됩니다.)

그래서 두 동작을 살펴보자면

  • insert : 현재 노드를 다루며 이 현재 노드는 최상위 노드부터 시작해 문자열의 각 문자들을 자식 노드로 생성하며 Trie를 형성합니다. 최종결과는 위 예시 이미지와 같습니다.

  • find : 현재 노드를 다루며 이 현재 노드는 최상위 노드부터 시작해 자식 노드에 문자열의 각 문자들이 있는지 확인합니다.

find 시 문자열의 존재판단 과정 중 신경쓸 부분이 있습니다.

  1. 찾고자 하는 문자에 달하기 전 Trie 의 경로가 끝나버림
    2.Trie 에서 일치한 문자열이 우연의 일치로 생성된 경로일 때 입니다.

이러한 케이스를 다루기 위해 각 노드가 Trie 에서 존재하는 문자열인지 판단하면 됩니다.
이는 단순히 각 노드가 insert 당시 문자열의 마지막 단어로 생성된 노드에 관리됨을 나타내는 값을 가지고 있으면 됩니다.



구현 코드 (Python)

단순 배열 응용

해당 코드는 다음 문제의 해답코드이기도 합니다.

from sys import stdin

n, p = map(int, stdin.readline().split())

ROOT = 0
unused = 1
mx = 500 * 10000 + 5
chk = [0] * (mx)
mem = [[-1] * 26 for _ in range(mx)]

def insert(s):
    global ROOT, unused, chk, mem
    cur = ROOT
    for c in s:
        cidx = ord(c) - ord('a')
        if mem[cur][cidx] == -1:
            mem[cur][cidx] = unused
            unused += 1
        cur = mem[cur][cidx]
    chk[cur] +=  1

def find(s):
    global ROOT, unused, chk, mem
    cur = ROOT
    for c in s:
        cidx = ord(c) - ord('a')
        if mem[cur][cidx] == -1:
            return 0
        cur = mem[cur][cidx]
    if chk[cur] > 0: return 1
    return 0

for _ in range(n):
    s = stdin.readline()
    insert(s)
    
ans = 0
for _ in range(p):
    s = stdin.readline()
    ans += find(s)
    
print(ans)


Hash 응용

위의 구현을 따르면 시간복잡도와 공간복잡도는 아래와 같습니다.

|S| = 찾고자하는 문자열의 길이
N= 모든 노드의 개수

  • 시간 복잡도 : O(|S|)
  • 공간 복잡도 : O(26 * N)

이 문제의 경우 삽입하는 문자열의 최대 길이가 500이라는 작은 수와 10,000이라는 적당한 삽입 횟수가 있었기에 Trie 로 시도할 수 있었습니다.

현재는 알파벳 소문자만 다루고 있기에 자식 노드를 찾을 때 26개의 상수값을 가집니다.
만약에 더 다양한 문자들이 다뤄진다면 까다로워질 수 있습니다.

그럴땐 Hash 를 사용하는 것이 적합할 수 있습니다. (물론 Hash 를 통해 자식 노드 조회의 성능을 올릴 수 있지만 다루는 크기에 따라 기존 방식을 적용하는게 더 나을 수 있습니다.)

from sys import stdin

n, p = map(int, stdin.readline().split())

ROOT = 0
unused = 1
mx = 500 * 10000 + 5
chk = [0] * (mx)
mem = [dict() for _ in range(mx)]

def insert(s):
    global ROOT, unused, chk, mem
    cur = ROOT
    for c in s:
        if c not in mem[cur]:
            mem[cur][c] = unused
            unused += 1
        cur = mem[cur][c]
    chk[cur] +=  1

def find(s):
    global ROOT, unused, chk, mem
    cur = ROOT
    for c in s:
        if c not in mem[cur]:
            return 0
        cur = mem[cur][c]
    if chk[cur] > 0: return 1
    return 0

for _ in range(n):
    s = stdin.readline()
    insert(s)
    
ans = 0
for _ in range(p):
    s = stdin.readline()
    ans += find(s)
    
print(ans)

사실 Hash 로만 구현하여 문자가 존재하는지 확인하는 방법도 있습니다.

d = dict()
n, i = map(int, input().split())

for _ in range(n):
    d[input()] = True

ans = 0
for _ in range(i):
    if input() in d:
        ans += 1

print(ans)

Class 로 구현

다음 문제의 해답 코드이기도 합니다.

from collections import defaultdict

class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, words):
        node = self.root
        for word in words:
            node = node.children[word]

    def print_trie(self, node=None, depth=0):
        if node is None:
            node = self.root
        for word in sorted(node.children.keys()):
            print("--" * depth + word)
            self.print_trie(node.children[word], depth + 1)

n = int(input())
trie = Trie()

for _ in range(n):
    path = input().split()[1:]
    trie.insert(path)

trie.print_trie()

단순히 클래스를 사용하면 코드 작성이나 관리에 있어서 편리할 수 있습니다. 신경쓸 부분이 있다면 64비트 아키텍쳐에서 포인터의 크기는 8바이트 이기에 기존 정수형(4바이트)을 다루는 코드보다 낭비적인 상황이 올 수도 있다는 점입니다.

문자열 삭제 시

chk 에서 해당 노드를 False 로 다루면 됩니다.

사용처

문자열을 다루기 용이한 자료구조..이다보니 빠른 속도가 필요한 검색 엔진이나 기타 문자열 처리 프로그램에서 자주 사용됩니다.

참고자료

종만북
https://blog.encrypted.gg/1059

profile
spider from mars

0개의 댓글