본 논문은 추천 시스템 관련 논문은 아니고 GraphSAGE가 처음 제안된 논문이다. 본 논문을 정리한 이유는 본 논문 이후에 바로 다룰 모델인 PinSAGE(GraphSAGE를 Pinterest의 Task에 맞춰서 커스텀한 모델)를 조금 더 잘 이해하기 위해서 이다.

Introduction

  • 저차원의 Vector로 매우 큰 그래프 안에 노드를 Embedding 하는 것은 매우 유용함 (노드 예측, 링크 예측, 군집화 등 그래프 분석과 관련된 Task에 feature로 사용할 수 있음)
  • 본 논문이 나오기 이전에 논문들은 전체 그래프를 가지고 노드를 임베딩함(노드의 Embedding을 먼저 구하고 해당 Embedding을 학습 시키는 방식, Embedding Look Up 방식)
  • 그래서 새로운 노드가 생기면 해당 노드를 임베딩할 수 없다는 문제점이 존재함
  • 실제 현실에서는 새롭게 생기는 노드나 그래프를 빠르게 임베딩하는 것이 중요함(새롭게 생긴 노드를 빠르게 현재 Task에 활용하기 위해서)
  • 따라서 본 논문에서는 GraphSAGE(SAmple and aggreGatE)라는 모델을 사용하여, 임베딩을 학습하는 것이 아닌, node의 정보를(degree, text 등) input으로 하여 임베딩을 생성함 (새롭게 생긴 노드의 Embedding을 빠르게 생성할 수 있음)

  • GraphSAGE는 위의 그림과 같은 형식으로 이루어졌고, 비지도학습으로 모델이 학습됨
  • 우선 임베딩을 하고자 하는 Target 노드의 이웃을 샘플링 한 후, 이웃 노드의 정보를 Aggregate하여 Target 노드의 Embedding을 생성함
  • GraphSAGE의 핵심은 비지도학습을 위한 Loss function, Target 노드의 이웃을 샘플링, 이웃 노드의 정보를 Aggregate 하는 것임

Related work

  • Factorization-based embedding approaches / Supervised learning over graphs / Graph convolutional networks 등의 방식은 새로운 노드에 대한 임베딩을 구할 수 없거나, 존재하는 노드에 대한 임베딩을 표현하거나, full graph Laplacian을 알아야 모델을 학습시킬 수 있다는 단점이 존재함

Proposed method: GraphSAGE

  • GraphSAGE는 이웃 노드의 정보를 aggregate하는 방식으로 모델이 학습됨
  • 학습된 GraphSAGE는 새로운 노드에 대한 임베딩을 생성할 수 있음
  • GraphSAGE는 SGD와 backpropagation을 사용하여 파라미터를 업데이트 함

1) Embedding generation (i.e., forward propagation) algorithm

  • Embedding generation을 algorithm은 위와 같은 수도 코드로 이루어짐
  • 본 알고리즘의 직관은 각 반복 또는 검색 깊이에서 노드가 정보를 계속 합치면서, 노드의 임베딩을 구하는 것임
  • 우리는 노드를 Embedding 하기 위해서 노드 x feature로 이루어진 행렬이 필요함
  • 우리는 이웃 노드 정보를 aggregate하는 K개의 AGGREGAE가 필요함
  • 우리는 전파된 정보와 이전 레이어의 정보를 합치는 K개의 W가 필요함
  • 여기서 K는 Target 노드를 임베딩 하기 위한 탐색의 깊이를 의미함(k-hop의 개념, k가 1이면 Target 노드를 중심으로 이웃 노드만 사용하고, k가 2라면 Target 노드를 중심으로 이웃 노드에 이웃 노드 까지 사용함)
  • K는 0에 노드 임베딩은 노드의 기존 feature로 설정하고, 그 후 K-1의 노드 정보를 AGGREGATE 하여 노드의 feature를 구하고, 해당 레이어의 노드 정보를 K-1 노드 정보와 concat 후 비선형함수를 통과하여 현재 K에 해당하는 노드의 Embedding을 얻고, Embedding을 정규화하고,(값이 너무 커지는 것을 방지) 이와 같은 과정을 K번 반복 해줌
  • GraphSAGE 알고리즘은 Weisfeiler-Lehman Isomorphism Test에 영감을 받아서 만들어졌다고 함
  • GraphSAGE 알고리즘은 학습 시에 고정된 크기의 이웃 set을 샘플링하여 사용함(고정된 크기의 subgraph를 만들었다는 것)
  • K가 2라면 각 Target 노드를 중심으로 이웃 노드에 이웃 노드 까지 사용한 것이고, 각 k에 위치하는 노드의 수는 자기가 직접 설정하는 것임(각 k 마다의 이웃 노드의 수가 많아지면 그 만큼 그래프의 크기가 커지기 때문에 모델의 학습 속도는 느려질 것)

2) Learning the parameters of GraphSAGE

  • GraphSAGE는 비지도학습을 위하여 위와 같은 Loss function을 사용함
  • 우선 내가 이해한 바로는 Target 노드의 실제 이웃만으로 sampling된 pos set과 Target 노드와 이웃하지 않은 노드 만으로 sampling된 neg set이 필요함
  • 그 후 pos set의 노드들을 Embedding, neg set의 노드들을 Embedding 함
  • 그리고 pos set을 이용하여 Embedding 된 노드들(식에서는 v)은 Target 노드의 이웃이기 때문에 pos set으로 임베딩된 Target 노드(식에서는 u)와 서로 가깝게, neg set을 이용하여 Embedding 된 노드들은 Target 노드의 이웃이 아니기 때문에 pos set으로 임베딩된 Target 노드(식에서는 u)와 서로 멀어지게 학습됨
  • 즉, pos 노드와는 유사하게, neg 노드와는 유사하지 않게 학습 하는 것이 본 loss의 목표라고 할 수 있음

3) Aggregator Architectures

  • Aggregator Architectures는 이웃 노드의 정보를 집계하는 방식임

  • 위와 같이 각 노드의 정보를 취합 후에 각 노드의 degree로 나눠 평균 정보를 구한 후 비선형 변환을 하는 Mean aggregator가 있음

  • 또는 각 노드의 정보를 LSTM을 사용하여 합취는 LSTM aggregator도 있음, 그러나 본질적으로 이웃 노드의 순서는 정렬되지 않기 때문에 LSTM을 사용하는 것은 적합하지 않을 수 있음(아직 논문은 읽지 않았지만 이러한 이유 때문에 Attention을 사용하는 GAT가 나오지 않았을까 생각됨)

  • 또는 위와 같이 각 노드의 기능 마다 제일 큰 값을 사용하는 MaxPooling aggregator도 있음
  • 그런데 본 논문에서는 Mean aggregator와 MaxPooling aggregator 사이에 큰 차이가 없었다고 함 (나는 구현이 좀 더 쉬운 Mean aggregator를 사용하여 본 모델을 구현함)

  • aggregator를 구현하는 것은 본 논문의 수식만으로는 이해하기 어려워서 설명하자면, 위처럼 노드들의 feature 행렬과 subgraph를 인접 행렬로 나타낸 행렬을 서로 행렬곱 하면 이웃 노드의 정보를 취합하는 aggregator를 구현할 수 있음(GCN 논문을 보면 이해해하기 쉬울 것임)

Experiments

  • 본 모델을 citation dataset, Reddit dataset, PPI dataset을 가지고 평가를 함

  • 학습 시에는 batches size는 512로 설정하고, K=2, S1=25, S2=10으로 설정하여 모델을 학습 시킴
  • 학습 후 Embedding을 생성할 때는 전체 그래프를 사용하여 노드의 임베딩을 생성함(나는 MovieLens 데이터를 사용했기 때문에 노드 임베딩 생성시에도 subgraph를 생성하여 임베딩을 생성함. 왜냐하면 MovieLens 자체가 이분 그래프(유저 - 영화)라서 전체 그래프로 표현하기 애매했기 때문임. 추후에 MultiSAGE나 PinSAGE 같은 경우는 이분 그래프를 사용할 수 있지 않을까 생각됨)

Conclusion

  • 학습시에 존재하지 않았던 노드에 대해서 임베딩을 구할 수 있는 방법을 제안함
  • 이웃 노드를 샘플링하여 학습하는 GraphSAGE는 학습 속도도 빠르며 성능도 좋아 다른 베이스라인 모델을 뛰어넘는 SOTA 모델이다.

Code

Reference

profile
Machine Learning Engineer at Konan Technology

0개의 댓글

관련 채용 정보

Powered by GraphCDN, the GraphQL CDN