GAT, Graph Attention Networks
Idea
GCN의 경우, 모든 이웃 노드들로부터 동일한 가중치를 갖는다.
그렇지만 경우에 따라 이웃 노드를로부터 얻는 정보의 가중치가 다를 수 있다. (많은 경우가 그럴 것이다.)
이제 이웃 노드들로부터 얻는 임베딩 $\mathbf{h}_u^{l}$에 가중치 $\alpha{vu}$를 곱한다.
이 수식은 GCN을 보다 일반화 한것으로 볼 수 있다.
GCN이라면 $\alpha{vu} \cfrac{1}{|N(v)|}$ 으로 간주할 수 있다.
Not all neighbors are equally important !
Computing the attention weight
노드 $u$가 노드 $v$에 메시징을 할 때, 그 중요도를 $e_{vu}$라 하고 attention weight를 $\alpha_{vu}$라 하자.
아래 그림을 통해, 어떻게 두 노드 $A, B$의 importance $e_{AB}$를 계산하는지 이해할 수 있다.
이 때 attention weight는 importance를 정규화한 것으로 해석하여 다음과 같은 공식을 만들 수 있다.
Defining the function $a(\cdot)$
간단한 형태의 $a(\cdot)$는 Concat과 Linear를 차례로 합성한 함수일 것이다.
또한, Transformer의 multi-head attention의 아이디어를 차용하여 GAT에도 적용할 수 있다.
GraphSAGE
Generalization of GCN, GAT
다시 GCN와 GAT를 review하면 다음과 같다.
이웃 노드로부터 받는 message를 가중치를 주면서 aggregate를 하는 일반적인 방법을 구현해보자.
$\text{AGG}$ 함수로는 다양한 방법이 가능하다. 단, 이웃노드로부터 메시지를 받을 때 순서에 영향이 있으면 안되므로 symmetric한 함수를 이용한다. GraphSAGE가 제안된 논문에서는 아래 3가지 함수를 AGG로 사용하기를 제안한다.
- Mean
- GCN과 동일하다.
- $\text{AGG}=\sum_{u \in N(v)}\cfrac{\mathbf{h}_u^{(l)}}{|N(v)|}$
- Pool
- 이웃 벡터를 변환하여 symmetric vector function을 으로 적용
- $\text{AGG} = \gamma \left( \bigl\{ \text{MLP}(\mathbf{h}_u^{(l)}) | u \in N(v) \bigr\} \right)$.
- $\gamma$는 element-wise 연산을 하는 mean, max, min 함수가 가능하다.
- LSTM
- shuffled neighbors 에 사용
- $\text{AGG} = \text{LSTM}\left( \bigl[ \mathbf{h}_u^{(l)} | u \in \pi(N(v)) \bigr] \right)$
GNN Layer Overview
GNN Layer는 message와 aggregation 2가지를 갖는다.
- Message
- $\mathbf{m}_u^{(l)} = \text{MSG}^{(l)}(\mathbf{h}_u^{(l)})$
- $\text{MSG}$로는 간단하게 linear layer($\mathbf{W_l}$)가 가능하다.
- Aggregation
- $\mathbf{h}_v^{(l)} = \text{AGG}\left( \bigl\{ \mathbf{m}_u^{(l)} | u \in N(v) \bigr\} \right)$
- $\text{AGG}$로는 $\text{SUM}$, $\text{MEAN}$, $\text{MAX}$ 등이 가능하다.
여러개의 벡터 집합을 하나의 벡터로 압축한다. (GCN의 경우, 이웃 노드로부터 여러 임베딩 벡터를 받아 하나의 임베딩 벡터를 반환)
다양한 GNN layer가 존재하며, GCN, GAT, GraphSAGE 등이 있다.
'스터디 > 인공지능, 딥러닝, 머신러닝' 카테고리의 다른 글
[CS224w] Prediction with GNNs (0) | 2023.04.26 |
---|---|
[CS224w] Graph Augmentation (0) | 2023.04.24 |
[GCN] Graph Convolutional Network (0) | 2023.04.18 |
[논문리뷰] TimesNet, Temporal 2D-Variation Modeling for General Time Series Analysis (0) | 2023.04.08 |
[CS224w] Colab 2 - PyG, OGB, GNN (0) | 2023.03.21 |