GNN은 CNN과 유사하게 convolution 연산을 수행한다. 다만, 그래프의 불규칙한 구조를 반영할 수 있도록 기존의 1D 또는 2D convolution가 아닌 graph convolution 연산을 사용한다. 아래 그림 출처
Graph convolution 연산의 핵심은 노드를 임베딩함에 있어 엣지로 연결된 노드들, 즉 이웃 노드들의 정보를 활용하는 것이다. 하나의 중심 노드에 대해, 그 노드가 이웃하는 노드들의 정보를 하나로 모아 중심 노드를 표현할 수 있는 벡터로 출력한다. 이러한 과정은 이웃 노드들이 엣지를 따라 중심 노드로 정보를 전달한다는 측면에서 message passing이라고도 표현된다.
이를 수식으로 표현한다면 다음과 같다.
: 노드 의 feature 벡터, 초기()에는 input feature와 같지만 레이어를 지나며 임베딩 벡터로 변환
: 노드 에서 로 향하는 엣지의 feature 벡터, 반드시 존재하지는 않음
: aggregation 함수
: 중심 노드와 이웃 노드, 엣지의 feature 벡터로 message를 계산하는 함수
: aggregation의 결과를 업데이트하는 함수
: 레이어 인덱스
수식에 따르면, 노드 를 임베딩하는 과정은 엣지를 기준으로 중심 노드 와 이웃하는 노드들 의 정보를 가공해 message로 만들고, 이들을 aggregate한 결과를 다시 가공하여 최종 결과를 도출하는 것으로 정리할 수 있다.
이러한 과정에서 aggregation 함수 는 합이나 평균, 최댓값과 같이 출력이 입력 순서에 영향을 받지 않는 방식으로 설정되어야 한다. 또한, 모델이 학습을 통해 파라미터를 업데이트할 수 있도록 와 는 MLPs처럼 미분 가능한 형태로 정의되어야 한다.
MessagePassing
클래스PyG에는 torch_geometric.nn.conv.MessagePassing
, 줄여서 MessagePassing
클래스가 구현되어 있다. 클래스의 이름처럼 MessagePassing
은 이웃한 노드 간 message 전파(propagation), 즉 그래프 신경망의 message passing을 관장한다. 위의 수식에서 와 , 만 사용자가 정의하면 연산적인 부분은 알아서 실행되도록 구현되어 있는 것이다.
class MessagePassing(torch.nn.Module):
...
MessagePassing
은 torch.nn.Module
을 상속 받기 때문에, 모델의 학습을 위한 기능들을 사용 및 구현할 수 있다. Message passing을 기반으로 하는 그래프 신경망 모델들은 다시 MessagePassing
을 상속 받아 정의된다.
__init__
def __init__(
self,
aggr: Optional[Union[str, List[str], Aggregation]] = "add",
*,
aggr_kwargs: Optional[Dict[str, Any]] = None,
flow: str = "source_to_target",
node_dim: int = -2,
decomposed_layers: int = 1,
**kwargs,
):
...
클래스를 생성하고 초기화하는데 있어 필요한 인자는 다음과 같다.
"add"
, "mean"
, "max"
등의 키워드 또는 torch_geometric.nn.Aggregation
객체 사용 가능None
으로도 설정 가능 - MessagePassing
객체의 aggregate()
메소드를 통해 구현"add"
aggregate()
로 전달하기 위한 인자None
"source_to_target"
와 "target_to_source"
의 이지선다"source_to_target"
propagate()
에 입력되는 edge_index와 관련-2
1
propagate()
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
...
out = self.message(**msg_kwargs)
...
out = self.aggregate(**aggr_kwargs)
...
out = self.update(**update_kwargs)
...
return out
본격적인 message passing이 실행을 담당하며 propagate()
를 통해 시작된다. 호출 시에는 edge_index와 size, 기타 노드 임베딩 과정에 필요한 요소(**kwargs
)들을 인자로 받는다.
여기서 edge_index는 그래프를 구성하는 연결 관계로 message가 전달되는 경로를 설정하는 인자이다. torch.Tensor
또는 SparseTensor
중 어느 형태로 입력되는지에 따라 분기점이 존재한다.
torch.Tensor
로 입력되는 경우, torch.long
, 크기는 [2, 엣지 수]
를 만족해야 하며, 2개의 행이 source_node와 target_node의 순서로 구성되어야 함 (flow="source_to_target"
).SparseTensor
로 입력되는 경우,SparseTensor
SparseTensor
로 나타내면,이후, message()
, aggregate()
, update()
를 순차적으로 호출하여 노드의 임베딩 벡터를 업데이트한다. 사용자가 구현하는 방식에 따라 message()
와 aggregate()
는 message_and_aggregate()
로 합치는 것도 가능하다.
message()
: def message(self, x_j: Tensor) -> Tensor:
return x_j
노드 에서 노드 로 전달하는 메시지를 생성하는 함수로, propagate()
내에서 호출되기 때문에 propagate()
에 전달된 인자들을 사용하는 것이 가능하다.
기본적으로는 입력된 텐서 x_j
를 그대로 출력하게끔 작성되어 있다. 메시지를 생성하는 방식에 따라 그래프 신경망이 구분되므로, 구체적인 방식은 오버라이딩을 통해 구현되도록 한다.
aggregate()
: def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
dim=self.node_dim)
torch.nn.aggr.Aggregation
을 통해 message aggregation을 수행한다. message()
의 출력값이 입력되며, message()
와 마찬가지로 propagate()
에 전달된 인자들을 사용하는 것이 가능하다.
update()
: def update(self, inputs: Tensor) -> Tensor:
return inputs
aggregate()
의 출력값을 최종 노드 임베딩 벡터로 출력하기 전에 업데이트하는 역할을 한다. message()
처럼 구체적인 방식은 오버라이딩을 통해 구현되며, propagate()
에 전달된 인자들을 사용할 수 있다.
MessagePassing
을 이용해서 그래프 신경망 모델을 만들어보자.
GCN 레이어를 만든다고 하면, message passing 연산은 다음과 같이 정의할 수 있다.
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # Step 3: 이웃 노드들의 벡터들을 더하는 방식으로 aggregation
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.Tensor(out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: 입력된 이전 레이어의 feature 벡터를 선형 변환
x = self.lin(x)
# Step 3: 인접 행렬에 self-loop 추가
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 정규화를 위해 노드가 갖는 이웃의 수 계산
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# 주어진 조건으로 message passing
out = self.propagate(edge_index, x=x, norm=norm)
# Step 4: bias 벡터 추가
out += self.bias
return out
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 2: 노드가 갖는 이웃의 수로 정규화
return norm.view(-1, 1) * x_j
위 코드에서 GCNConv
는 MessagePassing
을 상속 받아 정의된다.
__init__()
에서는 message passing에 있어 aggregation 방식을 "add"
로 설정하고,
선형 변환 self.lin = Linear(...)
와 bias 벡터 self.bias=Parameter(...)
를 생성하여 파라미터를 초기화한다.
레이어 객체의 작동 매커니즘은 forward()
메소드를 통해 구현한다. 입력된 x
와 edge_index
에 대해 선형 변환 및 self-loop를 추가하는 과정이 포함되어 있다. propagate()
를 호출하여 message passing을 진행하며 이에 필요한 message 생성 함수는 message()
를 오버라이딩하여 구현한 것을 확인할 수 있다.
이렇게 만든 GCNConv
는 다음과 같이 선언하여 사용하는 것이 가능하다.
conv = GCNConv(in_channels=16, out_channels=32)
x = conv(x, edge_index)