<References>
💡 Not all neighbor nodes are equally important to a target node.
→ If some neighbors matter more than others, how can we learn their weights automatically?
⇒ GAT
Graph Attention Layer
Multi-head attention outputs are combined by:
Concatenation: in the hidden (intermediate) layers
Average: on the final prediction layer of the network
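A minimal pure-Python sketch of one graph attention head, assuming the formulation from the GAT paper: e_ij = LeakyReLU(aᵀ[W h_i ‖ W h_j]), α_ij = softmax of e_ij over j ∈ N(i), and h_i' = Σ_j α_ij W h_j. The toy weights, the negative slope 0.2, and the helper names (`gat_head`, `combine_heads`) are illustrative assumptions; a real layer learns W and a and applies an output nonlinearity, which is omitted here. The last part shows the two ways of combining K heads listed above.

```python
import math


def matvec(W, h):
    """Return W @ h for a matrix W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, h)) for row in W]


def leaky_relu(z, slope=0.2):
    return z if z > 0 else slope * z


def gat_head(h, neighbors, W, a):
    """One attention head (hypothetical helper).

    h: list of node feature vectors
    neighbors: dict node -> neighbor list (include the node itself as a self-loop)
    W: shared linear transform (F' x F); a: attention vector of length 2F'
    Returns (new node features, attention coefficients per node).
    """
    Wh = [matvec(W, hi) for hi in h]
    out, alphas = [], {}
    for i in range(len(h)):
        nbrs = neighbors[i]
        # unnormalized scores e_ij = LeakyReLU(a^T [W h_i || W h_j])
        scores = [leaky_relu(sum(ak * zk for ak, zk in zip(a, Wh[i] + Wh[j])))
                  for j in nbrs]
        # softmax over the neighborhood of i (subtract max for stability)
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alpha = [ex / total for ex in exps]
        alphas[i] = dict(zip(nbrs, alpha))
        # attention-weighted sum of transformed neighbor features
        out.append([sum(al * Wh[j][k] for al, j in zip(alpha, nbrs))
                    for k in range(len(Wh[i]))])
    return out, alphas


def combine_heads(head_outputs, mode):
    """'concat' for hidden layers, 'mean' for the final prediction layer."""
    n = len(head_outputs[0])
    if mode == "concat":
        return [sum((head[i] for head in head_outputs), []) for i in range(n)]
    K = len(head_outputs)
    return [[sum(vals) / K for vals in zip(*(head[i] for head in head_outputs))]
            for i in range(n)]


h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# each node attends to itself and its neighbors
neighbors = {0: [0, 1, 2], 1: [1, 0], 2: [2, 0]}
W1, a1 = [[1.0, 0.0], [0.0, 1.0]], [0.5, -0.5, 1.0, 0.0]
W2, a2 = [[0.5, 0.5], [1.0, -1.0]], [0.0, 1.0, 0.0, 1.0]
out1, alpha1 = gat_head(h, neighbors, W1, a1)
out2, alpha2 = gat_head(h, neighbors, W2, a2)
hidden = combine_heads([out1, out2], "concat")  # 2 heads x 2 features
final = combine_heads([out1, out2], "mean")     # averaged over heads
print(len(hidden[0]), len(final[0]))  # 4 2
```

Concatenation keeps every head's features (width K·F') so later layers can use them, while averaging keeps the output width at F', which suits a final layer whose width must match the number of classes.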
👍Advantages
Message Passing Implementation
Creating Message Passing Networks - pytorch_geometric documentation
torch_geometric.nn.conv.message_passing - pytorch_geometric documentation
PyTorch Geometric MessagePassing base class
PyTorch Geometric 탐구 일기 (PyTorch Geometric Exploration Diary) - Message Passing Scheme (1)
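Set message(), update(), and the aggregation scheme (aggr: "add", "mean", or "max").
- flow: the direction of message passing; decides whether a node receives information from its neighbors or sends information to them (either "source_to_target" or "target_to_source")
- node_dim: the node dimension, i.e. the axis of the feature tensor along which nodes are indexed and propagated
ex. Message Passing interface example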
```python
from torch_geometric.nn import MessagePassing


class MyOwnConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')  # "add", "mean" or "max" aggregation

    def forward(self, x, edge_index, e):
        # Pass everything needed for propagation (and message construction)
        return self.propagate(edge_index, x=x, e=e)

    def message(self, x_j, x_i, e):
        # Node features get automatically mapped to source (_j) and target (_i) nodes;
        # e holds one feature per edge and must broadcast against x_j, e.g. shape [E, 1]
        return x_j * e
```
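To make the flow argument concrete, this pure-Python sketch emulates what MyOwnConv computes: each edge's message is x_j * e, and messages are summed ("add") at the receiving node i. The function name `my_own_conv` and the scalar features are illustrative assumptions, not PyG API. With flow="source_to_target", edge_index[0] holds the sources j and edge_index[1] the targets i; "target_to_source" swaps the two roles.

```python
def my_own_conv(x, edge_index, e, flow="source_to_target"):
    """Emulate MyOwnConv (hypothetical helper): message = x_j * e, 'add' at node i.

    x: list of scalar node features
    edge_index: (row, col) lists, one entry per edge
    e: one scalar edge weight per edge
    """
    row, col = edge_index
    if flow == "source_to_target":
        j_idx, i_idx = row, col   # messages travel row -> col
    elif flow == "target_to_source":
        j_idx, i_idx = col, row   # messages travel col -> row
    else:
        raise ValueError(f"unknown flow: {flow}")
    out = [0.0] * len(x)
    for j, i, w in zip(j_idx, i_idx, e):
        x_j = x[j]                # "lifted" source feature of this edge
        out[i] += x_j * w         # message(), then 'add' aggregation at node i
    return out


x = [1.0, 2.0, 3.0]
edge_index = ([0, 1, 1], [1, 2, 0])   # edges 0->1, 1->2, 1->0
e = [10.0, 100.0, 1000.0]
print(my_own_conv(x, edge_index, e))                      # [2000.0, 10.0, 200.0]
print(my_own_conv(x, edge_index, e, "target_to_source"))  # [20.0, 1300.0, 0.0]
```

The same edge list produces different results depending on flow, because flow only decides which end of each edge sends and which end receives.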
torch.nn.Module is the superclass:
torch.nn.Module
⇒ torch_geometric.nn.MessagePassing
⇒ OurCustomLayer
The layer implementations in torch_geometric.nn.conv follow this Message Passing Scheme.
MessagePassing - propagate()
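```python
# Simplified excerpt of MessagePassing.propagate(); argument collection is omitted
def propagate(self, edge_index, size=None, **kwargs):
    ...  # builds coll_dict, msg_aggr_kwargs, aggr_kwargs, update_kwargs
    # mp_type: whether edge_index is a sparse adjacency ('adj_t') or an 'edge_index' tensor
    # Fused path: compute message and aggregation in one call (if applicable).
    if mp_type == 'adj_t' and self.fuse and not self.__explain__:
        out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
    # Otherwise, run both functions in separation.
    elif mp_type == 'edge_index' or not self.fuse or self.__explain__:
        msg_kwargs = self.__distribute__(self.inspector.params['message'],
                                         coll_dict)
        out = self.message(**msg_kwargs)
        out = self.aggregate(out, **aggr_kwargs)
    # Both paths finish with update()
    out = self.update(out, **update_kwargs)
    return out
```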
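Stripped of the bookkeeping, the non-fused branch of the excerpt above is message → aggregate → update. This pure-Python sketch mimics that pipeline for the default "source_to_target" flow, including a non-square size=(N, M) as in a bipartite graph where N source nodes send to M target nodes. The name `toy_propagate` is hypothetical, the aggregation options mirror the "add" / "mean" / "max" choices of MessagePassing, and filling targets without incoming edges with 0.0 is an assumption of this sketch.

```python
def toy_propagate(edge_index, x_src, size, aggr="add",
                  message=lambda x_j: x_j, update=lambda aggr_out: aggr_out):
    """Mimic message -> aggregate -> update (hypothetical helper).

    size=(N, M): N source nodes (indexed by edge_index[0]) and
    M target nodes (indexed by edge_index[1]); the output has M entries.
    """
    N, M = size
    src, dst = edge_index
    if any(j >= N for j in src) or any(i >= M for i in dst):
        raise ValueError("edge_index does not fit size=(N, M)")
    # 1) message(): one message per edge, built from the lifted source feature
    msgs = [message(x_src[j]) for j in src]
    # 2) aggregate(): group messages by target node and reduce them
    buckets = [[] for _ in range(M)]
    for i, m in zip(dst, msgs):
        buckets[i].append(m)
    reducers = {"add": sum, "mean": lambda b: sum(b) / len(b), "max": max}
    aggr_out = [reducers[aggr](b) if b else 0.0 for b in buckets]
    # 3) update(): final per-node transformation of the aggregated output
    return [update(a) for a in aggr_out]


# Bipartite example: 3 source nodes send to 2 target nodes.
x_src = [1.0, 2.0, 3.0]
edge_index = ([0, 1, 2, 0], [0, 0, 0, 1])
print(toy_propagate(edge_index, x_src, size=(3, 2), aggr="add"))   # [6.0, 1.0]
print(toy_propagate(edge_index, x_src, size=(3, 2), aggr="mean"))  # [2.0, 1.0]
print(toy_propagate(edge_index, x_src, size=(3, 2), aggr="max"))   # [3.0, 1.0]
print(toy_propagate(edge_index, x_src, size=(3, 2), aggr="mean",
                    update=lambda a: 2 * a))                       # [4.0, 2.0]
```

Because the output length follows M rather than N, the same routine covers both square graphs (N == M) and bipartite ones.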
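propagate(edge_index, size=None, **kwargs)
- size: messages can also be propagated over a non-square (N, M) adjacency, e.g. a bipartite graph, by passing size=(N, M); with size=None the adjacency matrix is assumed to be square
- Calls message() and update() in turn (with aggregate() in between)
- message() and aggregate() are either run separately or fused into a single message_and_aggregate() call
- update() produces the final output
MessagePassing - message()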
```python
def message(self, x_j: torch.Tensor) -> torch.Tensor:
    # Default: each edge's message is its source node feature;
    # override this to construct custom messages
    return x_j
```
message(**kwargs)
MessagePassing - update()
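```python
def update(self, inputs: torch.Tensor) -> torch.Tensor:
    # Default: return the aggregated messages unchanged;
    # override this to transform each node's aggregated output
    return inputs
```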
update(aggr_out, **kwargs)
Besides inputs (the aggregated output), update() can also use any of the arguments initially passed to propagate().
Implementing the GCN Layer
Semi-Supervised Classification with Graph Convolutional Networks
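The layer from this paper can be written, in the notation of the PyTorch Geometric "Creating Message Passing Networks" tutorial, as the per-node update below: self-loops give the neighborhood N(i) ∪ {i}, Θ is the weight of the linear layer, the degree-based factor is the normalization coefficient, and b is the final bias.

$$
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b}
$$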
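Steps
1. Add self-loops to the adjacency matrix.
2. Linearly transform the node feature matrix.
3. Compute the normalization coefficients.
4. Normalize the node features in message().
5. Sum up the neighboring node features ("add" aggregation).
6. Apply a final bias vector.
Source Code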
```python
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index  # split into source (row) and target (col) nodes
        # count how often each node appears as a target node (its in-degree)
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```
GCNConv
Inherits from MessagePassing with "add" propagation.
forward()
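- Adds self-loops to edge_index with [torch_geometric.utils.add_self_loops()]: https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.add_self_loops
- Linearly transforms the node features with [torch.nn.Linear]: https://pytorch.org/docs/master/generated/torch.nn.Linear.html#torch.nn.Linear
- Computes the normalization coefficients from the node degrees (torch_geometric.utils.degree); the output norm is a tensor of shape [num_edges, ]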
- Calls propagate(), which internally calls message(), aggregate(), and update()
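Step 3 can be checked by hand on a toy graph. This pure-Python sketch mirrors the GCNConv code above: it appends one self-loop per node, counts how often each node appears as a target (col), and sets norm[e] = deg^(-1/2)[row] · deg^(-1/2)[col] for every edge. The helper name `gcn_norm_sketch` and the 3-node path graph are illustrative assumptions. After self-loops every degree is at least 1, so the zero-degree guard (the inf → 0 line) never fires here; it is kept only for parity.

```python
def gcn_norm_sketch(edge_index, num_nodes):
    """Return (edge_index with self-loops, per-edge GCN normalization)."""
    row, col = list(edge_index[0]), list(edge_index[1])
    # Step 1: append a self-loop (i, i) for every node
    row += list(range(num_nodes))
    col += list(range(num_nodes))
    # degree(col, N): how often each node appears as a target
    deg = [0] * num_nodes
    for i in col:
        deg[i] += 1
    # deg^(-1/2), with 0 where deg == 0 (mirrors the inf -> 0 masking)
    deg_inv_sqrt = [d ** -0.5 if d > 0 else 0.0 for d in deg]
    # norm[e] = 1 / (sqrt(deg(j)) * sqrt(deg(i))) for edge e = (j -> i)
    norm = [deg_inv_sqrt[j] * deg_inv_sqrt[i] for j, i in zip(row, col)]
    return (row, col), norm


# Undirected path 0 - 1 - 2, stored as directed edges in both directions.
(row, col), norm = gcn_norm_sketch(([0, 1, 1, 2], [1, 0, 2, 1]), num_nodes=3)
print([round(v, 3) for v in norm])  # [0.408, 0.408, 0.408, 0.408, 0.5, 0.333, 0.5]
```

Degrees here are [2, 3, 2], so a 0-1 edge gets 1/√(2·3) ≈ 0.408, node 1's self-loop gets 1/3, and the end nodes' self-loops get 1/2; norm has one entry per edge, matching the [num_edges, ] shape noted above.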
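propagate(): takes the edge indices plus any additional data needed to construct messages and update node embeddings; here the node features x and the normalization coefficients norm are passed in addition.
message(): normalizes the neighboring node features x_j by norm.
- x_j: a lifted tensor containing the source node feature of each edge
- x_i: contains the target node feature of each edge
Implementations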
https://github.com/sujinyun999/PytorchGeometricTutorial/blob/main/Tutorial3/Tutorial3.ipynb