확실히 수업 따라가느라 급급하고, 복습까지는 제대로 생각도 하지 못하고 있다. 그래도 끝까지 과정을 통과해서 이수하는데 의의를 둬야겠다.

01. RNN(Recurrent Neural Network)

  • 정의 : 순차적 데이터를 다루기 위한 뉴럴 네트워크 구조
  • 핵심 아이디어: Hidden state를 사용하여 Sequence 정보를 압축
    ht=tanh(Whhht1+Wxhxt),yt=Whyhth_{t}= tanh(W_{hh}h_{t-1} + W_{xh}x_t), y_t = W_{hy}h_t
  • 구조 유형
    • Many-to-Many (OCR 등)
    • Many-to-One (Video Classification)
    • One-to-Many (Generative Model, GPT류)
    • Multi-layer RNN

02. Vanilla RNN의 문제점

  • 긴 시퀀스 학습 시 Gradient Vanishing / Exploding 문제 발생
  • 동일한 가중치 W를 긴 horizon에 반복 곱하면서 발생
  • 따라서 장기 기억 학습이 불안정함
    ht=tanh(Whhht1+Wxhxt)=tanh([WhhWxh][ht1xt])h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t) = \tanh \left( \begin{bmatrix} W_{hh} & W_{xh} \end{bmatrix} \begin{bmatrix} h_{t-1} \\ x_t \end{bmatrix} \right)
import torch, torch.nn as nn
tanh = torch.tanh

W_hh = nn.Linear(128, 128, bias = False) # h_{t-1} -> h
W_xh = nn.Linear(64, 128, bias = False) # x_t -> h
W_hy = nn.Linear(128, 10, bias = False) # h_t -> y

h_prev = torch.zeros(32, 128) # (batch, hidden)
x_t = torch.randn(32, 64) # (batch, input)

h_t = tanh(W_hh(h_prev) + W_xh(x_t))
y_t = W_hy(h_t)

03. LSTM (Long Short-Term Memory)

  • 등장 배경: RNN의 Gradient Vanishing 문제를 해결
  • 핵심 아이디어: Cell state를 도입하여 gradient 흐름을 보존
    it=σ(Wi[ht1,xt])ft=σ(Wf[ht1,xt])ot=σ(Wo[ht1,xt])gt=tanh(Wg[ht1,xt])ct=ftct1+itgtht=ottanh(ct)\begin{aligned} i_t &= \sigma(W_i [h_{t-1}, x_t]) \\ f_t &= \sigma(W_f [h_{t-1}, x_t]) \\ o_t &= \sigma(W_o [h_{t-1}, x_t]) \\ g_t &= \tanh(W_g [h_{t-1}, x_t]) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned}
  • 게이트 역할
    • Forget gate ff = 이전 기억을 얼마나 잊을지 결정
    • Input gate ii = 새로운 입력을 얼마나 반영할지 결정
    • Gate gate gg = 입력을 cell state 공간으로 매핑
    • Output gate oo = hidden state에 반영
  • 변형 구조
    • Bi-directional LSTM: 앞/뒤 정보를 모두 활용
import torch, torch.nn as nn
sigmoid, tanh = torch.sigmoid, torch.tanh

class LSTMCellManual(nn.Module):
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.W_i = nn.Linear(in_dim + hid_dim, hid_dim)
        self.W_f = nn.Linear(in_dim + hid_dim, hid_dim)
        self.W_o = nn.Linear(in_dim + hid_dim, hid_dim)
        self.W_g = nn.Linear(in_dim + hid_dim, hid_dim)
        
    def forward(self, x_t, h_prev, c_prev):
        z = torch.cat([h_prev, x_t],dim = -1)
        i = sigmoid(self.W_i(z))
        f = sigmoid(self.W_f(z))
        o = sigmoid(self.W_o(z))
        g = tanh(self.W_g(z))
        c_t = f * c_prev + i * g
        h_t = o * tanh(c_t)
        return h_t, c_t

# 사용 예
cell = LSTMCellManual(64, 128)
h_prev = torch.zeros(32, 128)
c_prev = torch.zeros(32, 128)
x_t = torch.randn(32, 64)
h_t, c_t = cell(x_t, h_prev, c_prev)
# 양방향 LSTM(Bi-LSTM) 한 줄 요약 코드
import torch, torch.nn as nn

bilstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                 bidirectional=True, batch_first=True)
x = torch.randn(32, 50, 64)  # (batch, seq_len, input_size)
out, _ = bilstm(x)           # out: (32, 50, 256)  (2*hidden)

04. EasyOCR 구조

  • Recognition 방식
    • EasyOCR: CNN + RNN + CTC + Greedy Search
    • Tesseract: RNN + CTC + Beam Search
  • Framework
    • Text Detection: CRAFT (Clova AI, CVPR 2019)
    • Recognition: ResNet + LSTM + CTC
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lengths = torch.tensor([50, 45, 41, 30])           # 내림차순 정렬 가정
x = torch.randn(4, 50, 64)
x_packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=True)
out_packed, _ = bilstm(x_packed)
out, _ = pad_packed_sequence(out_packed, batch_first=True)  # (4, 50, 256)

05. CTC Loss (Connectionist Temporal Classification)

  • 사용 목적: 입력 시퀀스 길이와 출력 시퀀스 길이가 다르거나, 정렬이 정의되지 않은 경우
  • Blank Symbol (-): 사전에 정의되지 않은 경우를 처리
    CTC Loss 수식
    CTC Loss=logπB1(y)P(πx)\text{CTC Loss} = -\log \sum_{\pi\in B^{-1}(y)}P(\pi\mid x)
import torch, torch.nn as nn

T, N, C = 4, 2, 5   # time, batch, classes(문자+blank)
logits = torch.randn(T, N, C).log_softmax(2)  # (T, N, C), log-prob

# target: 배치 별 정답 라벨들을 일렬로 붙인 형태
targets = torch.tensor([1,2,3,  2,2])  # 예: 첫 배치 '1,2,3', 둘째 배치 '2,2'
target_lengths = torch.tensor([3,2])
input_lengths  = torch.tensor([T, T])

ctc = nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc(logits, targets, input_lengths, target_lengths)
loss.backward()

06. Decoding 방법

  • (1) Greedy Decoding
    • 각 시점에서 가장 확률이 높은 출력을 선택
    • 연속된 문자와 blank 제거
    • 예: "C-AAT-" → "CAT"
  • (2) Beam Search Decoding
    • 각 타임스텝마다 Beam width만큼 후보를 추려 유지
    • 이전 후보로부터 확장하여 최적의 경로 탐색
    • 장점: Greedy보다 성능이 높음 (하지만 계산량 증가)
import torch

def ctc_greedy_decode(log_probs, blank_id=0):
    # log_probs: (T, C) or (T, N, C) 중 (T, C) 가정
    path = log_probs.argmax(dim=-1).tolist()
    decoded = []
    prev = None
    for p in path:
        if p != blank_id and p != prev:
            decoded.append(p)
        prev = p
    return decoded

# 예시
T, C = 6, 5
log_probs = torch.randn(T, C).log_softmax(-1)
tokens = ctc_greedy_decode(log_probs, blank_id=0)
profile
2025화이팅!

0개의 댓글