Variational Graph Auto-Encoder(VGAE) 논문 요약 및 코드 리뷰

정리용 블로그·2024년 2월 1일
0

GNN

목록 보기
3/8
post-thumbnail

VGAE paper

VGAE는 GCN에 VAE를 섞는 느낌으로 만들어진 모델이다.
GCN을 encoder로 삼고 inner product를 decoder로 삼는다.

모델의 요소는 아래와 같다.

  • 그래프 G=(V,E)
  • N = |V|
  • adjacency matrix를 A로 둔다.(A의 대각선 요소는 모두 1로 둔다. self-connection)
  • D를 degree matrix로 둔다.
  • stochastic latent variables ziz_iN×FN \times F matrix Z를 만든다.
  • Node feature은 N×DN \times D matrix X에 들어있다.

Inference model

Inference 모델은 two-layer GCN으로 파라미터화한다.
q(ZX,A)=i=1Nq(ziX,A) with q(ziX,A)=N(ziμi,diag(σ2))q(Z|X,A)=\prod^N_{i=1}q(z_i|X,A)\ with\ q(z_i|X,A) =\mathcal{N}(z_i|\mu_i, diag(\sigma^2))
여기서 μ=GCNμ(X,A)\mu = GCN_{\mu}(X, A) 이고 σ=GCNσ(X,A)\sigma = GCN_{\sigma}(X, A)이다.
GCN(X,A)=A~ReLU(A~XW0)W1GCN(X,A) = \tilde A ReLU(\tilde A X W_0)W_1 이고, GCNμGCN_\muGCNσGCN_\sigma는 첫번째 layer 파라미터 W0W_0를 공유한다.

Generative model

Generative model은 위에서 설명한바와 같이 latent variables의 inner product로 주어진다.
p(AZ)=i=1Nj=1Np(Aijzi,zj) with p(Aij=1zi,zj)=σ(zizj)p(A|Z) = \prod^N_{i=1}\prod^N_{j=1}p(A_{ij}|z_i, z_j)\ with\ p(A_{ij}=1|z_i, z_j) = \sigma(z_i^\intercal z_j)

Loss function

L=Eq(ZX,A)[log p(AZ)]KL[q(ZX,A)P(Z)]\mathcal{L} = \mathbb{E}_{q(Z|X,A)}[log\ p(A|Z)] - KL[q(Z|X,A) || P(Z)]
여기서 P(Z)=ip(zi)=iN(zi0,I)P(Z) = \prod_i p(z_i) = \prod_i \mathcal{N} (z_i | 0,I)

GAE model

non-probabilisitc 모델은 embeddings Z와 A^\hat A를 계산한다.
A^=σ(ZZ) with Z=GCN(X,A)\hat A = \sigma(ZZ^\intercal)\ with \ Z=GCN(X,A)

그저 VAE를 GCN과 접목시켰을 뿐인 간단한 논문인데 이게 잘 된다는게 신기하다.

VGAE pytorch
로 코드리뷰를 한다.

class VGAE(nn.Module):
	def __init__(self, adj):
		super(VGAE,self).__init__()
		self.base_gcn = GraphConvSparse(args.input_dim, args.hidden1_dim, adj)
		self.gcn_mean = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x:x)
		self.gcn_logstddev = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x:x)

모델을 먼저 보면 앞에서 나왔듯 gcn layer 하나를 mean과 stddev가 공유하고, 그 후 2개로 나뉘어져 있는 것을 볼 수 있다.

def encode(self, X):
	hidden = self.base_gcn(X)
	self.mean = self.gcn_mean(hidden)
	self.logstd = self.gcn_logstddev(hidden)
	gaussian_noise = torch.randn(X.size(0), args.hidden2_dim)
	sampled_z = gaussian_noise*torch.exp(self.logstd) + self.mean
	return sampled_z

그리고 이를 통해 나온 mean과 stddev를 이용해
q(ZX,A)=i=1Nq(ziX,A) with q(ziX,A)=N(ziμi,diag(σ2))q(Z|X,A)=\prod^N_{i=1}q(z_i|X,A)\ with\ q(z_i|X,A) =\mathcal{N}(z_i|\mu_i, diag(\sigma^2))
이 부분을 샘플링 하는 것을 확인 가능하다.

def dot_product_decode(Z):
	A_pred = torch.sigmoid(torch.matmul(Z,Z.t()))
	return A_pred

def forward(self, X):
	Z = self.encode(X)
	A_pred = dot_product_decode(Z)
	return A_pred

p(AZ)=i=1Nj=1Np(Aijzi,zj) with p(Aij=1zi,zj)=σ(zizj)p(A|Z) = \prod^N_{i=1}\prod^N_{j=1}p(A_{ij}|z_i, z_j)\ with\ p(A_{ij}=1|z_i, z_j) = \sigma(z_i^\intercal z_j)
위의 수식 대로 Z의 inner product를 이용해 A의 pred를 생성하는 것을 확인 가능하다.

이 모델을 이용해 train을 하는 과정을 살펴보면

def load_data(dataset):
    # load the data: x, tx, allx, graph
    names = ['x', 'tx', 'allx', 'graph']
    objects = []
    for i in range(len(names)):
        with open("data/ind.{}.{}".format(dataset, names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))
    x, tx, allx, graph = tuple(objects)
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset))
    test_idx_range = np.sort(test_idx_reorder)
    
	features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    return adj, features


github 코드에 내장되어있던 data 파일을 pickle을 이용해 불러와 adj matrix와 features로 만드는 것을 볼 수 있다.

adj_orig = adj
adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
adj_orig.eliminate_zeros()

adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
adj = adj_train
adj_norm = preprocess_graph(adj)

이 코드는 adj_orig를 만들어 adj에서 대각선을 제거한 matrix를 넣어놓는다.
이 후 test를 위해 mask_test_edges를 이용해 train, val, test 데이터로 나누고 모델에 들어가는 데이터와 gt 데이터로 나눠준다.
이후 preprocess_graph를 이용해서 A~=D12AD12\tilde A = D^{-\frac{1}{2}}AD^{-\frac{1}{2}}를 계산한다.

norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
A_pred = model(features)
optimizer.zero_grad()
loss = log_lik = norm*F.binary_cross_entropy(A_pred.view(-1), adj_label.to_dense().view(-1), weight = weight_tensor)
if args.model == 'VGAE':
    kl_divergence = 0.5/ A_pred.size(0) * (1 + 2*model.logstd - model.mean**2 - torch.exp(model.logstd)**2).sum(1).mean()
    loss -= kl_divergence

이후 epoch만큼 model에 features를 넣고 loss를 구한다.
L=Eq(ZX,A)[log p(AZ)]KL[q(ZX,A)P(Z)]\mathcal{L} = \mathbb{E}_{q(Z|X,A)}[log\ p(A|Z)] - KL[q(Z|X,A) || P(Z)]

0개의 댓글