[Graph] 4장. Graph Neural Networks: Algorithms

서쿠·2024년 7월 17일
1

GNN 시리즈

목록 보기
4/4
post-thumbnail
post-custom-banner

1. Introduction

그래프 구조 데이터는 복잡한 관계와 상호작용을 모델링하는 데 매우 유용합니다. 이러한 데이터를 효과적으로 분석하고 학습하기 위해 그래프 신경망(Graph Neural Networks, GNN)과 그래프 임베딩(Graph Embedding) 기법이 개발되었습니다.

1.1 그래프 신경망 모델 vs 그래프 임베딩

  • 그래프 신경망 모델: 그래프의 구조와 노드의 특성을 동시에 고려하여 학습하는 신경망 모델입니다. 대표적으로 GCN, GRN, GAT가 있습니다.
  • 그래프 임베딩: 그래프의 구조적 정보를 저차원 벡터 공간에 매핑하는 기법입니다. DeepWalk, Node2Vec, GraphSAGE 등이 있습니다.

1.2 주요 차이점

  1. 학습 방식: GNN은 다양한 학습 방식을 사용하며, 그래프 임베딩은 주로 비지도 학습을 사용합니다.
  2. 특성 정보 활용: GNN은 노드의 특성 정보를 직접 활용하지만, 그래프 임베딩은 주로 구조적 정보만 사용합니다.
  3. 모델 복잡성: GNN은 더 복잡한 구조를 가지며, 그래프 임베딩은 상대적으로 단순합니다.
  4. 귀납적 학습: GNN은 귀납적 학습이 가능하지만, 대부분의 그래프 임베딩은 변환적 방식을 사용합니다.
  5. 동적 그래프 처리: GNN은 동적 그래프 처리에 더 적합한 모델이 있습니다.
  6. 표현력: GNN은 지역적 구조와 전역적 구조를 모두 포착할 수 있지만, 그래프 임베딩은 주로 지역적 구조에 초점을 맞춥니다.

2. Graph Convolutional Networks (GCNs)

GCN은 CNN의 개념을 그래프 데이터에 확장한 모델로, 노드의 특성과 그래프 구조를 동시에 고려합니다. GCN은 주로 스펙트럴 기반과 공간 기반 방법으로 나뉩니다.

2.1 스펙트럴 기반 방법

스펙트럴 기반 GCN은 그래프 라플라시안(Laplacian) 행렬의 고유벡터를 사용하여 그래프의 주파수 도메인에서 합성곱 연산을 수행합니다. 기본 아이디어는 그래프 신호를 주파수 도메인에서 변환하고, 필터링한 후 다시 공간 도메인으로 변환하는 것입니다.

  • 그래프 라플라시안 행렬: L=DAL = D - A (여기서 DD는 차수 행렬, AA는 인접 행렬)
  • 스펙트럴 합성곱: gθx=Ugθ(Λ)UTxg_{\theta} \star x = U g_{\theta}(\Lambda) U^T x
    • UU: 라플라시안의 고유벡터
    • Λ\Lambda: 라플라시안의 고유값

2.2 공간 기반 방법

공간 기반 GCN은 직접 이웃 노드들의 특성을 사용하여 합성곱 연산을 수행합니다. Kipf와 Welling의 GCN이 대표적입니다. 이 모델은 인접 행렬을 사용하여 이웃 노드의 정보를 집계하고, 이를 통해 각 노드의 특성을 업데이트합니다.

  • Kipf와 Welling의 GCN:
    • 노드 특성 행렬 XX와 인접 행렬 AA 사용
    • 정규화 인접 행렬 A^=D12AD12\hat{A} = D^{-\frac{1}{2}} A D^{-\frac{1}{2}}
    • 업데이트 식: H(l+1)=σ(A^H(l)W(l))H^{(l+1)} = \sigma(\hat{A} H^{(l)} W^{(l)})

2.3 GCN 구현 예시 (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, A):
        X = torch.mm(A, X)
        X = self.linear(X)
        return F.relu(X)

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(GCN, self).__init__()
        self.layer1 = GCNLayer(in_features, hidden_features)
        self.layer2 = GCNLayer(hidden_features, out_features)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X

# 데이터 준비
X = torch.randn(5, 10)  # 5개의 노드, 10차원 특성
A = torch.eye(5) + torch.rand(5, 5)  # 간단한 인접 행렬 예시

# 모델 초기화 및 학습
model = GCN(10, 16, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Forward pass
output = model(X, A)

3. Graph Recurrent Networks (GRNs)

Graph Recurrent Networks (GRNs)은 RNN의 개념을 그래프에 적용한 모델로, 주로 시간적인 순서나 계층적 구조를 가지는 그래프 데이터에 유용합니다.

3.1 게이트된 그래프 신경망 (Gated Graph Neural Networks, GGNN)

GGNN은 GRU(Gated Recurrent Unit)를 그래프에 적용하여 시퀀스 출력이 필요한 문제에 적합합니다.

  • GGNN 구조: 각 노드는 GRU를 통해 자신의 상태를 업데이트하며, 인접 노드로부터의 메시지를 받아 이를 통합합니다.
  • 업데이트 식:
    • 메시지 전달: Mt=AHtM_t = A H_t
    • GRU 업데이트: Ht+1=GRU(Ht,Mt)H_{t+1} = \text{GRU}(H_t, M_t)

3.2 GGNN 구현 예시 (PyTorch)

class GGNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GGNNLayer, self).__init__()
        self.gru = nn.GRU(in_features, out_features)
    
    def forward(self, X, A):
        M = torch.mm(A, X)
        X, _ = self.gru(M.unsqueeze(0))
        return X.squeeze(0)

class GGNN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(GGNN, self).__init__()
        self.layer1 = GGNNLayer(in_features, hidden_features)
        self.layer2 = GGNNLayer(hidden_features, out_features)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X

3.3 트리 LSTM (Tree LSTM)

Tree LSTM은 LSTM을 트리 구조에 적용한 모델로, 계층적 구조를 가진 데이터에 적합합니다. 트리 구조에서는 부모 노드와 자식 노드 간의 관계를 모델링합니다.

Tree LSTM 구조

각 노드는 부모 노드와 자식 노드의 상태를 통합하여 자신의 상태를 업데이트합니다. 이를 통해 트리 구조를 따라 정보를 전파하고, 트리의 루트 노드에서 전체 트리의 정보를 집약할 수 있습니다.

업데이트 식

  • 자식 노드 상태 집합: C={hjjchildren(i)}C = \{h_j | j \in \text{children}(i)\}
  • 부모 노드 상태 집합: P={hpp=parent(i)}P = \{h_p | p = \text{parent}(i)\}
  • 상태 업데이트: hi=LSTM(C,P,hi)h_i = \text{LSTM}(C, P, h_i)

각 노드 ii의 상태는 다음과 같이 업데이트됩니다:

  • ii 노드의 자식 노드 상태 합: hjh_jii의 자식 노드 상태입니다.
  • ii 노드의 부모 노드 상태 합: hph_pii의 부모 노드 상태입니다.
  • ii 노드의 새로운 상태: hih_i는 LSTM을 통해 업데이트됩니다.

Tree LSTM의 핵심 수식

$$ it = \sigma(W_i x_t + \sum{j \in \text{children}(i)} Ui h_j + b_i) f{tj} = \sigma(Wf x_t + \sum{j \in \text{children}(i)} Uf h_j + b_f) o_t = \sigma(W_o x_t + \sum{j \in \text{children}(i)} Uo h_j + b_o) u_t = \tanh(W_u x_t + \sum{j \in \text{children}(i)} Uu h_j + b_u) c_t = i_t \odot u_t + \sum{j \in \text{children}(i)} f_{tj} \odot c_j h_t = o_t \odot \tanh(c_t) $$

여기서 it,ftj,ot,uti_t, f_{tj}, o_t, u_t는 각각 입력 게이트, 잊음 게이트, 출력 게이트, 업데이트 게이트를 나타내며, ctc_thth_t는 각각 셀 상태와 히든 상태를 나타냅니다.

코드 예시 (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeLSTMCell(nn.Module):
    def __init__(self, in_features, out_features):
        super(TreeLSTMCell, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W_i = nn.Linear(in_features, out_features)
        self.U_i = nn.Linear(out_features, out_features)
        self.W_f = nn.Linear(in_features, out_features)
        self.U_f = nn.Linear(out_features, out_features)
        self.W_o = nn.Linear(in_features, out_features)
        self.U_o = nn.Linear(out_features, out_features)
        self.W_u = nn.Linear(in_features, out_features)
        self.U_u = nn.Linear(out_features, out_features)
    
    def forward(self, x, child_h, child_c):
        child_h_sum = torch.sum(child_h, dim=0)
        child_c_sum = torch.sum(child_c, dim=0)
        
        i = torch.sigmoid(self.W_i(x) + self.U_i(child_h_sum))
        f = torch.sigmoid(self.W_f(x) + self.U_f(child_h_sum))
        o = torch.sigmoid(self.W_o(x) + self.U_o(child_h_sum))
        u = torch.tanh(self.W_u(x) + self.U_u(child_h_sum))
        
        c = i * u + f * child_c_sum
        h = o * torch.tanh(c)
        
        return h, c

class TreeLSTM(nn.Module):
    def __init__(self, in_features, out_features):
        super(TreeLSTM, self).__init__()
        self.cell = TreeLSTMCell(in_features, out_features)
    
    def forward(self, x, children_h, children_c):
        h, c = self.cell(x, children_h, children_c)
        return h, c

3.4 그래프 LSTM (Graph LSTM)

그래프 LSTM은 LSTM을 일반 그래프에 확장한 모델로, 노드 순서를 결정하는 방법이 중요합니다. 그래프 구조에서는 순환 경로가 있을 수 있으므로, 순서를 정의하는 것이 핵심입니다.

그래프 LSTM 구조

그래프 LSTM에서는 각 노드가 자신의 이웃 노드로부터 정보를 받아 상태를 업데이트합니다. 노드 순서를 결정하는 다양한 방법이 있으며, 이는 모델의 성능에 중요한 영향을 미칩니다.

업데이트 식

각 노드 ii의 상태는 다음과 같이 업데이트됩니다:

$$ hi = \text{LSTM}(h{\text{neighbors}(i)}, x_i) $$

여기서 hneighbors(i)h_{\text{neighbors}(i)}는 노드 ii의 이웃 노드들의 상태를 나타내며, xix_i는 노드 ii의 입력 특성을 나타냅니다.

그래프 LSTM의 핵심 수식

$$ it = \sigma(W_i x_t + U_i h{t-1} + bi) f_t = \sigma(W_f x_t + U_f h{t-1} + bf) o_t = \sigma(W_o x_t + U_o h{t-1} + bo) u_t = \tanh(W_u x_t + U_u h{t-1} + bu) c_t = f_t \odot c{t-1} + i_t \odot u_t $$
$$ h_t = o_t \odot \tanh(c_t) $$

코드 예시 (PyTorch)

class GraphLSTMCell(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphLSTMCell, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W_i = nn.Linear(in_features, out_features)
        self.U_i = nn.Linear(out_features, out_features)
        self.W_f = nn.Linear(in_features, out_features)
        self.U_f = nn.Linear(out_features, out_features)
        self.W_o = nn.Linear(in_features, out_features)
        self.U_o = nn.Linear(out_features, out_features)
        self.W_u = nn.Linear(in_features, out_features)
        self.U_u = nn.Linear(out_features, out_features)
    
    def forward(self, x, neighbor_h, neighbor_c):
        neighbor_h_sum = torch.sum(neighbor_h, dim=0)
        neighbor_c_sum = torch.sum(neighbor_c, dim=0)
        
        i = torch.sigmoid(self.W_i(x) + self.U_i(neighbor_h_sum))
        f = torch.sigmoid(self.W_f(x) + self.U_f(neighbor_h_sum))
        o = torch.sigmoid(self.W_o(x) + self.U_o(neighbor_h_sum))
        u = torch.tanh(self.W_u(x) + self.U_u(neighbor_h_sum))
        
        c = i * u + f * neighbor_c_sum
        h = o * torch.tanh(c)
        
        return h, c

class GraphLSTM(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphLSTM, self).__init__()
        self.cell = GraphLSTMCell(in_features, out_features)
    
    def forward(self, x, neighbors_h, neighbors_c):
        h, c = self.cell(x, neighbors_h, neighbors_c)
        return h, c

4. Graph Attention Networks (GAT)

Graph Attention Networks (GAT)는 주의 메커니즘을 그래프에 적용한 모델입니다. GAT는 자기 주의(self-attention) 메커니즘을 사용하여 각 노드의 이웃 노드들에 서로 다른 가중치를 부여합니다.

4.1 주요 특징

  • self-attention 메커니즘: 각 노드의 이웃 노드들에 대해 가중치를 계산하여, 중요한 이웃 노드의 정보를 더 많이 반영합니다.
  • multi-head-attention 메커니즘: 여러 주의 헤드를 사용하여 안정성을 향상시키고, 각 헤드의 출력을 결합하여 최종 출력을 만듭니다.

4.2 GAT 구현 예시 (PyTorch)

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads=1):
        super(GATLayer, self).__init__()
        self.num_heads = num_heads
        self.attention_heads = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(num_heads)])
        self.attention_coeffs = nn.ModuleList([nn.Linear(2 * out_features, 1) for _ in range(num_heads)])
    
    def forward(self, X, A):
        outputs = []
        for head, coeff in zip(self.attention_heads, self.attention_coeffs):
            H = head(X)
            N = X.size(0)
            H_repeated_in_chunks = H.repeat_interleave(N, dim=0)
            H_repeated_alternating = H.repeat(N, 1)
            all_combinations_matrix = torch.cat([H_repeated_in_chunks, H_repeated_alternating], dim=1)
            e = F.leaky_relu(coeff(all_combinations_matrix).view(N, N))
            attention = F.softmax(e.masked_fill(A == 0, -1e9), dim=1)
            outputs.append(torch.matmul(attention, H))
        return torch.cat(outputs, dim=1)

class GAT(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_heads=1):
        super(GAT, self).__init__()
        self.layer1 = GATLayer(in_features, hidden_features, num_heads)
        self.layer2 = GATLayer(hidden_features * num_heads, out_features, num_heads)
    
    def forward(self, X, A):
        X = self.layer1(X, A)
        X = self.layer2(X, A)
        return X

5. 결론

이 챕터에서는 Graph Neural Networks의 주요 알고리즘인 GCN, GRN, GAT에 대해 살펴보았습니다. 각 알고리즘은 그래프 데이터를 처리하는 고유한 방식을 가지고 있으며, 다양한 그래프 관련 작업에 적용될 수 있습니다.

  • GCN은 노드의 특성과 그래프 구조를 동시에 고려하여 효과적인 노드 표현을 학습합니다.
  • GRN은 시간적 또는 계층적 구조를 가진 그래프 데이터를 처리하는 데 적합합니다.
  • GAT는 주의 메커니즘을 통해 이웃 노드들의 중요도를 학습하여 더 유연한 정보 aggregation을 가능하게 합니다.

이러한 Graph Neural Networks 알고리즘들은 각각의 장단점을 가지고 있으며, 문제의 특성에 따라 적절한 모델을 선택하는 것이 중요합니다.

  • GCN은 구현이 간단하고 효과적이지만, 깊은 레이어를 쌓기 어렵다는 단점이 있습니다.
  • GRN은 시퀀스 데이터나 트리 구조 데이터에 강점을 보이지만, 일반적인 그래프에 적용할 때는 노드 순서 결정이 어려울 수 있습니다.
  • GAT는 노드 간의 중요도를 학습할 수 있어 유연성이 높지만, 계산 복잡도가 상대적으로 높습니다.
profile
Always be passionate ✨
post-custom-banner

0개의 댓글