Graph Neural Network - (1)

Suho Cho·2021년 4월 5일
0

Deep Learning

목록 보기
1/7

Graph Neural Network는 뭘까?

이 글 번역을 하자
Graph Neural Network(GNN)는 그래프 구조에 직접적으로 동작하는 네트워크입니다. 가장 잘 알려진 Task로는 node classification이 있는데, 기본적으로 그래프의 모든 노드가 label을 갖고 있으며 우리는 이걸 예측하는 작업을 말한다. 이 섹션에서는 논문에서 소개된 알고리즘에 대해 설명한다.

Node classification 문제

현재 그래프: Partially labeled graph G\mathbf{G}
조건: 각 노드 vv는 feature xv\mathbf{x}_v와 ground-truth label tv\mathbf{t}_v로 이루어져 있다.
목표: Unlabeled node의 label prediction
학습 대상: 각 노드의 embedding hvRs\mathbf{h}_v\in\mathbb{R^s}가 각 노드의 이웃들의 정보를 포함하도록
Notation

  • hv=f(xv,xco[v],hne[v],xne[v])\mathbf{h}_v = f(\mathbf{x}_v, \mathbf{x}_{co[v]}, \mathbf{h}_{ne[v]}, \mathbf{x}_{ne[v]})
  • xco[v]\mathbf{x}_{co[v]}: v와 연결된 edges의 features
  • hne[v]\mathbf{h}_{ne[v]}: v의 이웃노드(ne[v]ne{[v]})의 embeddings (states)
  • xne[v]\mathbf{x}_{ne[v]}: v의 이웃노드(ne[v]ne{[v]})의 features
  • ff: Input을 d차원으로 projection 시키는 Transition function

우리가 알고 싶은건 hv\mathbf{h}_v이고 Banach Fixed Point Theorem을 적용해서 위의 식을 iteratively update process로 바꿔 쓸 수 있다 (이 정리는 몇가지 가정 하에 iterative 연산을 하다보면 f(x)=xf(x) = x가 되어 수렴하게 된다는 정리이다). 이 연산을 message passing 또는 neighborhood aggregation이라고 한다.

Ht+1=F(Ht,X)\mathbf{H}^{t+1} = F(\mathbf{H}^{t}, \mathbf{X}) (H\mathbf{H}X\mathbf{X}는 각각 모든 hv\mathbf{h}_vxv\mathbf{x}_v를 concatenation시킨 행렬이다)

GNN의 결과는 output function gghv\mathbf{h}_{v}(state)와 xv\mathbf{x}_{v}(feature)를 넣고 나오는 결과 ov\mathbf{o}_{v}다.
ffgg를 가만히 보면 feed-forward fully-connected Network라고 할 수 있고 loss는 목적에 따라 다르겠지만 L1L_1을 사용할 수 있다. (classification이지만 L1L_1을 사용하나?)
loss=i=1p(tioi)loss = \sum_{i=1}^{p}(\mathbf{t}_i - \mathbf{o}_i)

논문의 주요 단점

  1. Fixed point라는 가정을 조금 느슨하게 하면(relaxed), 그냥 Multi-layer Perceptron(MLP)이 더 낫지 않나? iterative process가 없어도 말이다. 왜냐하면 논문에서 iteration 과정에서 똑같은 parameter를 ff를 통해 업데이트 한다고 설명했는데, MLP에서는 다른 parameter를 각기 다른 layer에서 학습하니까 hierarchical learning을 할 수 있는 것 아닌가? 그러면 더 좋은거 아닐까?
  2. Edge의 정보는 업데이트 안한다. (Knowledge graph에서 서로 다른 edge는 각 노드 사이의 다양한 관계를 표현할텐데 말이다)
  3. Fixed point라는 가정이 node의 분포 다양성을 저해할 수 있다. 그리고 이 방법이 node representation을 학습하는데 적합하지도 않아 보인다.

GNN의 주요 장점

  1. Various size of input graph. CNNs과 RNNs은 그래프 형태의 입력을 해석하기 어려운 구조를 가지고 있다. 왜냐하면 CNNs은 인접한 데이터와의 연산을 지원하고 RNNs은 데이터를 순차적으로 해석하기 때문에 데이터의 구조적인 문제를 해결하기 어렵다. (특히 convolution은 온세상을 네모난 grid로 해석하니까 재미가 없다)
  2. Node relativity. GNNs은 edges를 통해 임의의 두 노드 사이의 관계를 직접적으로 추론할 수 있다.
  3. Reasoning. 뇌가 여러 대상 사이의 관계를 추론하는 방식을 직관적으로 표현했다.

나올 수 있는 질문들

  1. 임의의 두 노드가 서로 이웃인지 어떻게 알아요?
  2. Graph classification은 이해했는데, node classification을 할 때는 inductive하게 풀 수 없나요? (그래프 전체 말고 국지적 구조만을 가지고 학습할 수는 없나)
  3. Node feature는 어떻게 얻어요?
  4. Node의 성질이 서로 다른 경우는 어떻게 그래프를 구성하나요?
    (Molecule structure 그래프가 있다. 모든 node는 atom이라는 점이 같다. 반면, 영화-유저 리뷰 그래프가 있다고 하자. node는 영화와 유저 두 종류로 구분된다.)
  5. 또 뭐가 있을까

다음 편으로 이어진다

profile
당신을 한 줄로 소개해보세요

0개의 댓글