본문 바로가기
스터디/인공지능, 딥러닝, 머신러닝

[CS224w] Relational GCN (RGCN)

by 궁금한 준이 2023. 8. 9.
728x90
반응형

 

Extending GCN to handle heterogeneous grphas

Message Passing on Directed Graph with one relation

그림과 같이 relation은 하나인 directed graph에 대하여 GCN을 적용해보자. 그렇다면 edge의 방향에 따라 message passing이 이루어지도록 설계하면 될 것이다.

 

Message Passing on Heterogeneous Graph with Multiple Relation Types

그렇다면 다양한 relation type이 존재한다면 어떻게 message passing을 할 것인가?

어쩔 수 없이 relation 마다 학습하는 weight가 다르게 설계를 한다.

(cs224w 강의자료에서는 보기 쉽게 같은 색으로 구분하였다)

즉, 각각의 relation type마다 다른 신경망을 적용하여 convlution을 구현할 수 있다.

Introduce a set of neural networks for each relation type!

Relational RCN

기존 GCN에서 relation type마다 weight를 계산하는 항만 추가하면 된다.

node degree를 $N(v)$ 대신에 $c_{v, r} = |N_v^r|$를 이용하였다.

 

RGCN: Definition

RGCN은 scaling issue가 존재한다.

$L$개의 layer가 있다고 하자. 그러면 각 relation마다 weight matrix가 있으므로 $\mathbf{W}_r^{(1)}, \mathbf{W}_r^{(2)}, \dots, \mathbf{W}_r^{(L)}$ 가 있게된다.

그리고 각 $\mathbf{W}_r^{(l)}$의 shape은 $d^{(l+1)} \times d^{(l)}$가 된다. ($d$는 hidden layer의 dimension)

파라미터가 굉장히 많아지므로 Overfitting이 될 가능성이 매우 높다.

크게 block diagonal matrix를 이용하거나 basis/dictionary learning을 이용하여 오버피팅을 방지한다.

Regularization 1: Block Diagonal Matrix

insight: weight를 sparse matrix로 만들어보자

Block Diagonal Matrix

 

$B$차원 행렬을 이용하게 되면 파라미터 개수는 $B \times \cfrac{d^{(l+1)}}{B} \times \cfrac{d^{(l)}}{B}$ 로 줄어든다.

파라미터 수가 줄어들면 학습 속도도 빨라지고, 오버피팅 가능성 역시 줄어든다.

 

그러나 블록행렬의 단점도 있다. 같은 블록에 있는 뉴런끼리는 W를 통해 상호작용할 수 있지만, 다른 블록에 있는 뉴런끼리는 그러지 못한다.

Regularization 2: Basis Learning

insight: 다른 relation끼리도 weight를 공유해보자

basis transformation의 linear combination으로 relation matrix를 나타내는 방법이다.

 

basis matrix(혹은 dictionary matrix라고도 하여 dictionary learning이라고도 부른다)를 $\mathbf{V}_b$라 하고, 각 basis마다 가중치(importance weight)를 $a_{rb}$라 하면 weight는 basis matrix의 가중평균으로 구한다

\[ \mathbf{W}_r = \sum_{b=1}^{B} a_{rb} \cdot \mathbf{V}_b \]

따라서 각 relation은 $\{ a_{rb} \}_{b=1}^{B}$ 만을 학습하면 된다. ($B$는 스칼라)

 

 

Tasks of RGCN

 

Input Graph

Entity/Node Classification

node A의 클래스를 k개 중에 어느 클래스에 속하는지 예측하는 모델을 학습할 수 있다.

final layer가 다음을 만족하도록 학습한다. $\mathbf{h}_A^{(L)} \in \mathbb{R}^k$ 

곧 $\mathbf{h}_A^{(L)}$은 클래스에 속할 확률을 나타낸다.

 

Link Prediction

illustration of link prediction

위 그림에서 두 노드 A와 E 사이에 edge가 있는지 예측해보자.

training supervision edge $(E, r_3, A)$가 있다고 가정하고, 다른 edge들을 training message edge로 간주한다.

그리고 RGCN으로 $(E, r_3, A)$의 점수(score)를 구한다.

  • RGCN으로 학습한 final layer의 차원을 $d$라 하자. 즉 $\mathbf{h}_A^{(L)}, \mathbf{h}_E^{(L)} \in \mathbb{R}^d$
  • relation-specific score function 정의하자 $f_r: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$ 
    • 가장 단순한 예시로는 $f_{r_1}(\mathbf{h}_E, \mathbf{h}_A) = \mathbf{h}_E^T \mathbf{W}_{r_1} \mathbf{h}_A $ 와 같은 형태가 가능하다.

Training (내용 추가 필요)

  1. RGCN으로 supervision edge $(E, r_3, A)$ 학습
  2. negative edge 생성
  3. GNN으로 negative edge 점수 계산
  4. Cross Entropy Loss로 최적화
    1. maximize: training supervision edge
    2. minimize: negative edge

Evaluation (내용 보충 필요)

이제 $(E, r_3, D)$에 대해 검증을 해보자.

Evaluation on Link Prediction

  1. $(E, r_3, D)$의 score 계산
  2. negative edge에 대하여 score 계산
    1. negative edges: $\{ (E, r_3, v) \vert v \in \{ B, F\} \}$이다. 왜냐하면 $(E, r_3, A)$와 $(E, r_3, C)$는 실제로 training message edge 또는 training supervision edge이기 때문
  3. $(E, r_3, D)$의 ranking RK 계산
  4. metric 계산
    1. Hits@k: (높을수록 좋음)
    2. Reciprocal Rank: $\cfrac{1}{RK}$. (높을수록 좋음)

Benchmark for Heterogeneous Graphs

oggn-mag from Microsoft Academic Graph (MAG)

Illustrate of MAG Data

SOTA method: SeHGNN

ComplEx와 Simplified GCN을 이용하여 구현한 모델이 sota라고 한다. (이 두 방법론은 이후 강의에서 소개됨)

Summary of RGCN

Relation GCN은 heterogeneous graph에 사용할 수 있는 GNN이다.

entity classification과 link prediction 모두 성능이 좋다.

아이디어들이 확장가능하다. (RGNN, RGraphSage, RGAT 등)

 

728x90
반응형