GNN(PyG tutorial)

용권순·2021년 12월 24일
0

논문

목록 보기
1/12
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
import torch_geometric
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.nn.inits import uniform
from torch.nn import Parameter as Param
from torch import Tensor 
# Fix the global PyTorch RNG seed so repeated runs start from the same state.
torch.manual_seed(42)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torch_geometric.nn.conv import MessagePassing
# Node-level split: 500 validation nodes, 500 test nodes and, with
# 'train_rest', the remaining nodes for training.  RandomNodeSplit is used
# because the older AddTrainValTestMask name is no longer available.
transform = T.Compose([
    T.RandomNodeSplit('train_rest', num_val=500, num_test=500),
    T.TargetIndegree(),  # saves the globally normalized degree of target nodes
])
path = osp.join('data', 'Cora')
dataset = Planetoid(path, 'Cora', transform=transform)
# Indexing the dataset yields a graph object (printed below), not a node.
data = dataset[0]

# Dataset summary, one value per line.
print(
    dataset,                    # dataset repr
    len(dataset),               # number of entries in the dataset
    dataset.num_classes,        # number of node classes
    dataset.num_node_features,  # feature dimension of each node
    data,                       # the graph object itself
    sep='\n',
)

# Each value is preceded by an empty line, exactly as separate
# print() / print(value) pairs would produce.
print(
    '', data.is_undirected(),            # is the graph undirected?
    '', data.train_mask.sum().item(),    # number of training nodes
    '', data.val_mask.sum().item(),      # number of validation nodes
    '', data.test_mask.sum().item(),     # number of test nodes
    sep='\n',
)
class MLP(nn.Module):
    """Fully connected feed-forward network with Tanh activations.

    Linear layers are registered in ``self.mlp`` as ``lay_<i>``; every layer
    except the last one is followed by a ``Tanh`` registered as ``act_<i>``.

    Args:
        input_dim: Number of input features.
        hid_dims: Sizes of the hidden layers. Any sequence works (list,
            tuple, ...); an empty sequence gives a single linear layer.
        out_dim: Number of output features.
    """

    def __init__(self, input_dim, hid_dims, out_dim):
        super().__init__()

        self.mlp = nn.Sequential()
        # Unpacking (instead of list concatenation) lets callers pass a
        # tuple or any other sequence for ``hid_dims``.
        dims = [input_dim, *hid_dims, out_dim]
        last_layer = len(dims) - 2
        for i, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            self.mlp.add_module(
                'lay_{}'.format(i),
                nn.Linear(in_features=fan_in, out_features=fan_out))
            # No activation after the final (output) layer.
            if i < last_layer:
                self.mlp.add_module('act_{}'.format(i), nn.Tanh())

    def reset_parameters(self):
        """Re-initialise the weight of every linear layer (Xavier normal).

        Biases are left untouched.
        """
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.mlp(x)
class GatedGraphConv(MessagePassing):
    """Gated graph convolution: repeated message passing with a GRU update.

    Each of the ``num_layers`` steps multiplies the current node states by
    that step's weight matrix, aggregates the result over the edges
    (optionally scaled by an edge weight) and feeds the aggregate together
    with the previous state into one shared ``GRUCell``.

    Args:
        out_channels: Size of the node state. Inputs with fewer features are
            zero-padded up to this size; inputs with more features are
            rejected.
        num_layers: Number of propagation steps (one weight matrix each).
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        bias: Whether the GRU cell uses bias terms.
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, out_channels, num_layers, aggr='add',
                 bias=True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        self.out_channels = out_channels
        self.num_layers = num_layers

        # One (out_channels x out_channels) matrix per propagation step.
        self.weight = Param(Tensor(num_layers, out_channels, out_channels))
        # The recurrent part of the network is a GRU cell.
        self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the step weights and the GRU cell."""
        uniform(self.out_channels, self.weight)
        self.rnn.reset_parameters()

    def forward(self, data):
        """Run ``num_layers`` propagation steps and return the node states.

        Args:
            data: Graph object exposing ``x``, ``edge_index`` and
                ``edge_attr`` (used as edge weight; may be ``None``) and
                supporting ``.to(device)``.

        Returns:
            Node state tensor whose last dimension is ``out_channels``.

        Raises:
            ValueError: If ``data.x`` has more than ``out_channels`` features.
        """
        # Follow this module's own parameters instead of a module-level
        # global, so the layer works on whatever device it was moved to.
        data = data.to(self.weight.device)
        x = data.x
        edge_index = data.edge_index
        edge_weight = data.edge_attr
        if x.size(-1) > self.out_channels:
            raise ValueError('The number of input channels is not allowed to '
                             'be larger than the number of output channels')

        if x.size(-1) < self.out_channels:
            # Zero-pad the features up to the state size.
            zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1))
            x = torch.cat([x, zero], dim=1)

        for i in range(self.num_layers):
            m = torch.matmul(x, self.weight[i])
            m = self.propagate(edge_index, x=m, edge_weight=edge_weight,
                               size=None)
            # GRU update: aggregated message is the input, x the hidden state.
            x = self.rnn(m, x)

        return x

    def message(self, x_j, edge_weight):
        """Return the neighbour features, scaled per edge when weighted."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        """Fused message + aggregation as one sparse-dense product."""
        # The original body called a bare ``matmul`` that is never imported
        # in this file (NameError).  Imported locally so the module's
        # top-level imports stay unchanged.
        from torch_geometric.utils import spmm
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.out_channels}, '
                f'num_layers={self.num_layers})')

class GGNN(torch.nn.Module):
    """Node classifier: a gated graph convolution followed by an MLP head.

    Layer sizes are taken from the module-level ``dataset`` and the forward
    pass reads the module-level ``data`` graph, so both must exist before
    the model is built or called.
    """

    def __init__(self):
        super().__init__()

        # Read the feature size from the dataset instead of hard-coding
        # 1433, so the model still fits if the dataset is swapped.
        num_features = dataset.num_node_features
        self.conv = GatedGraphConv(num_features, 3).to(device)
        self.mlp = MLP(num_features, [32, 32, 32],
                       dataset.num_classes).to(device)

    def forward(self):
        """Return per-node log-probabilities for the global ``data`` graph."""
        x = self.conv(data)
        x = self.mlp(x)
        return F.log_softmax(x, dim=-1)
# Model on the selected device, trained with Adam on a cross-entropy loss.
model = GGNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Split the dataset by index: the first tenth for testing, the rest for
# training.  NOTE(review): train() and test() in this file work on the global
# `data` graph and do not read these loaders.
train_dataset = dataset[len(dataset) // 10:]
test_dataset = dataset[:len(dataset) // 10]
train_loader = DataLoader(train_dataset)
test_loader = DataLoader(test_dataset)

def train():
    """Run one optimisation step on the nodes selected by the train mask."""
    model.train()
    optimizer.zero_grad()  # clear gradients left over from the previous step

    # Keep the mask on the same device as the model output it will index.
    data.train_mask = data.train_mask.to(device)

    out = model()
    mask = data.train_mask
    loss = criterion(out[mask], data.y[mask])
    loss.backward()
    optimizer.step()


def test():
    """Evaluate the model on the train, validation and test masks.

    Returns:
        List of accuracies, one per mask, in the order the masks are yielded
        by ``data('train_mask', 'val_mask', 'test_mask')`` (the caller
        assumes train, val, test).
    """
    # eval() only switches layers to inference behaviour; it does not turn
    # off gradient tracking, so evaluation also runs under no_grad().
    model.eval()
    accs = []
    with torch.no_grad():
        logits = model()
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            pred = logits[mask].max(1)[1]  # index of the highest score
            acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
    return accs


# Train for 50 epochs and report the three accuracies after every step.
for epoch in range(1, 51):
    train()
    accs = test()
    train_acc, val_acc, test_acc = accs
    print('Epoch: {:03d}, Train Acc: {:.5f}, '
          'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, *accs))

결과

Epoch: 001, Train Acc: 0.27143, Val Acc: 0.15800, Test Acc: 0.15400
Epoch: 002, Train Acc: 0.35000, Val Acc: 0.22200, Test Acc: 0.22200
Epoch: 003, Train Acc: 0.18571, Val Acc: 0.22400, Test Acc: 0.21000
Epoch: 004, Train Acc: 0.32857, Val Acc: 0.29600, Test Acc: 0.28300
Epoch: 005, Train Acc: 0.36429, Val Acc: 0.25000, Test Acc: 0.26700
Epoch: 006, Train Acc: 0.45000, Val Acc: 0.35000, Test Acc: 0.37200
Epoch: 007, Train Acc: 0.54286, Val Acc: 0.45800, Test Acc: 0.47100
Epoch: 008, Train Acc: 0.53571, Val Acc: 0.40800, Test Acc: 0.43000
Epoch: 009, Train Acc: 0.60000, Val Acc: 0.51000, Test Acc: 0.50700
Epoch: 010, Train Acc: 0.64286, Val Acc: 0.58800, Test Acc: 0.57500
Epoch: 011, Train Acc: 0.62143, Val Acc: 0.57600, Test Acc: 0.57700
Epoch: 012, Train Acc: 0.63571, Val Acc: 0.54000, Test Acc: 0.55100
Epoch: 013, Train Acc: 0.65714, Val Acc: 0.55000, Test Acc: 0.55200
Epoch: 014, Train Acc: 0.67143, Val Acc: 0.59200, Test Acc: 0.57100
Epoch: 015, Train Acc: 0.70714, Val Acc: 0.60800, Test Acc: 0.58400
Epoch: 016, Train Acc: 0.73571, Val Acc: 0.62400, Test Acc: 0.60400
Epoch: 017, Train Acc: 0.74286, Val Acc: 0.61200, Test Acc: 0.58800
Epoch: 018, Train Acc: 0.69286, Val Acc: 0.57000, Test Acc: 0.55700
Epoch: 019, Train Acc: 0.70714, Val Acc: 0.59400, Test Acc: 0.58700
Epoch: 020, Train Acc: 0.74286, Val Acc: 0.59600, Test Acc: 0.60400
Epoch: 021, Train Acc: 0.72143, Val Acc: 0.58600, Test Acc: 0.59000
Epoch: 022, Train Acc: 0.71429, Val Acc: 0.57800, Test Acc: 0.56200
Epoch: 023, Train Acc: 0.72857, Val Acc: 0.56800, Test Acc: 0.55000
Epoch: 024, Train Acc: 0.74286, Val Acc: 0.58200, Test Acc: 0.57100
Epoch: 025, Train Acc: 0.82143, Val Acc: 0.59600, Test Acc: 0.61300
Epoch: 026, Train Acc: 0.83571, Val Acc: 0.58800, Test Acc: 0.59400
Epoch: 027, Train Acc: 0.83571, Val Acc: 0.57800, Test Acc: 0.58000
Epoch: 028, Train Acc: 0.83571, Val Acc: 0.58000, Test Acc: 0.57600
Epoch: 029, Train Acc: 0.82143, Val Acc: 0.57800, Test Acc: 0.58200
Epoch: 030, Train Acc: 0.82857, Val Acc: 0.58600, Test Acc: 0.58000
Epoch: 031, Train Acc: 0.83571, Val Acc: 0.59200, Test Acc: 0.58500
Epoch: 032, Train Acc: 0.86429, Val Acc: 0.61000, Test Acc: 0.59800
Epoch: 033, Train Acc: 0.94286, Val Acc: 0.62000, Test Acc: 0.62400
Epoch: 034, Train Acc: 0.95714, Val Acc: 0.62400, Test Acc: 0.62800
Epoch: 035, Train Acc: 0.97143, Val Acc: 0.62200, Test Acc: 0.63600
Epoch: 036, Train Acc: 0.97857, Val Acc: 0.61400, Test Acc: 0.63400
Epoch: 037, Train Acc: 0.98571, Val Acc: 0.61200, Test Acc: 0.63700
Epoch: 038, Train Acc: 0.98571, Val Acc: 0.61000, Test Acc: 0.64200
Epoch: 039, Train Acc: 0.98571, Val Acc: 0.61200, Test Acc: 0.63400
Epoch: 040, Train Acc: 0.98571, Val Acc: 0.61800, Test Acc: 0.63200
Epoch: 041, Train Acc: 0.98571, Val Acc: 0.61600, Test Acc: 0.63400
Epoch: 042, Train Acc: 0.98571, Val Acc: 0.61000, Test Acc: 0.63300
Epoch: 043, Train Acc: 0.98571, Val Acc: 0.60800, Test Acc: 0.63300
Epoch: 044, Train Acc: 0.98571, Val Acc: 0.60800, Test Acc: 0.63300
Epoch: 045, Train Acc: 0.98571, Val Acc: 0.60600, Test Acc: 0.62900
Epoch: 046, Train Acc: 0.98571, Val Acc: 0.60200, Test Acc: 0.63000
Epoch: 047, Train Acc: 0.98571, Val Acc: 0.60200, Test Acc: 0.63100
Epoch: 048, Train Acc: 0.98571, Val Acc: 0.60400, Test Acc: 0.63300
Epoch: 049, Train Acc: 0.98571, Val Acc: 0.60400, Test Acc: 0.63300
Epoch: 050, Train Acc: 0.99286, Val Acc: 0.60400, Test Acc: 0.62800
profile
수학계산학부 석사생입니다.

0개의 댓글