본문 바로가기
스터디/인공지능, 딥러닝, 머신러닝

GAT, GraphSAGE

by 궁금한 준이 2023. 4. 19.
728x90
반응형

GAT, Graph Attention Networks

Idea

GCN의 경우, 모든 이웃 노드들로부터 동일한 가중치를 갖는다.

GCN has same weights
Same Weight in GCN

그렇지만 경우에 따라 이웃 노드를로부터 얻는 정보의 가중치가 다를 수 있다. (많은 경우가 그럴 것이다.)

이제 이웃 노드들로부터 얻는 임베딩 $\mathbf{h}_u^{l}$에 가중치 $\alpha{vu}$를 곱한다.

The concept of attention weight
Attention weight

이 수식은 GCN을 보다 일반화 한것으로 볼 수 있다.

GCN이라면 $\alpha{vu} \cfrac{1}{|N(v)|}$ 으로 간주할 수 있다.

Not all neighbors are equally important !

Computing the attention weight

노드 $u$가 노드 $v$에 메시징을 할 때, 그 중요도를 $e_{vu}$라 하고 attention weight를 $\alpha_{vu}$라 하자.

아래 그림을 통해, 어떻게 두 노드 $A, B$의 importance $e_{AB}$를 계산하는지 이해할 수 있다.

How to compute the importance
Computing importance

 

이 때 attention weight는 importance를 정규화한 것으로 해석하여 다음과 같은 공식을 만들 수 있다.

The definition of attention weight
Computing the attention weight

Defining the function $a(\cdot)$

간단한 형태의 $a(\cdot)$는 Concat과 Linear를 차례로 합성한 함수일 것이다.

A simple implementation of function 'a'
Simple implementation of $a$

또한, Transformer의 multi-head attention의 아이디어를 차용하여 GAT에도 적용할 수 있다.

Multi-head attention in GAT
Multi-head attention in GAT

 

GraphSAGE

Generalization of GCN, GAT

다시 GCN와 GAT를 review하면 다음과 같다.

GCN networkGAT network
The aggregating messages in GCN and GAT

 

이웃 노드로부터 받는 message를 가중치를 주면서 aggregate를 하는 일반적인 방법을 구현해보자.

Aggregating messages in GraphSAGE
Aggregating messages in GraphSAGE

$\text{AGG}$ 함수로는 다양한 방법이 가능하다. 단, 이웃노드로부터 메시지를 받을 때 순서에 영향이 있으면 안되므로 symmetric한 함수를 이용한다. GraphSAGE가 제안된 논문에서는 아래 3가지 함수를 AGG로 사용하기를 제안한다.

  • Mean
    • GCN과 동일하다.
    • $\text{AGG}=\sum_{u \in N(v)}\cfrac{\mathbf{h}_u^{(l)}}{|N(v)|}$
  • Pool
    • 이웃 벡터를 변환하여 symmetric vector function을 으로 적용
    • $\text{AGG} = \gamma \left( \bigl\{ \text{MLP}(\mathbf{h}_u^{(l)}) | u \in N(v) \bigr\} \right)$.
    • $\gamma$는 element-wise 연산을 하는 mean, max, min 함수가 가능하다.
  • LSTM
    • shuffled neighbors 에 사용
    • $\text{AGG} = \text{LSTM}\left( \bigl[ \mathbf{h}_u^{(l)} | u \in \pi(N(v)) \bigr] \right)$

 

GNN Layer Overview

GNN Layer는 message와 aggregation 2가지를 갖는다.

  • Message
    • $\mathbf{m}_u^{(l)} = \text{MSG}^{(l)}(\mathbf{h}_u^{(l)})$
    • $\text{MSG}$로는 간단하게 linear layer($\mathbf{W_l}$)가 가능하다. 
  • Aggregation
    • $\mathbf{h}_v^{(l)} = \text{AGG}\left( \bigl\{ \mathbf{m}_u^{(l)} | u \in N(v) \bigr\} \right)$
    • $\text{AGG}$로는 $\text{SUM}$, $\text{MEAN}$, $\text{MAX}$ 등이 가능하다.

여러개의 벡터 집합을 하나의 벡터로 압축한다. (GCN의 경우, 이웃 노드로부터 여러 임베딩 벡터를 받아 하나의 임베딩 벡터를 반환)

다양한 GNN layer가 존재하며, GCN, GAT, GraphSAGE 등이 있다.

Different types of GNN layers
Different types of GNN layers

 

728x90
반응형