Deep Graph Infomax (DGI) [ICLR 2019]
Abstract
DGI는 node representation을 unsupervised manner로 얻는 일반적인 방법론이다.
DGI는 patch representation과 graph summary의 mutual information을 최대화하는 방법으로 학습한다.
patch representation은 관심이 되는 node의 subgraph의 summary를 얻기 때문에 node-wise task로 downstream하여 적용할 수 있다.
과거의 GCN기반 unsupervised 방법과 달리, DGI는 random walk에만 의존하지 않기 때문에 transductive와 inductive learning setup 모두 적용할 수 있다.
1. Introduction
과거 graph에서 unsupervised learning은 random walk 기반 방법론이었다. (node2vec 등)
GCN, GAT 등의 graph convolutional 방법론은 성공적이지만 supervised 방법론이기 때문에 실제 unlabeled한 graph data에는 직접 적용하기 어려웠다.
graph-structured data에서 unsupervised learning에서 가장 주된 방법론은 random walk 기반 방법이다. (node2vec, Deepwalk, LINE 등) 그러나 random walk의 경우 proximity information을 over-emphasize하고, hyperparameter가 성능에 매우 영향을 주는 두가지 문제가 있다.
이 논문은 (random walk 대신에) mutual information(상호의존정보) 기반한 unsupervised graph learning을 제안한다. 최근에 MINE, DeepInfomax(DIM) 등에서 mutual information을 이용한 방법이 성공적이었다. DIM의 아이디어를 graph로 적용하여 Deep Graph Infomax (DGI)라는 방법을 제안한다.
2. Related Work
Contrastive methods unsupervised learning에서 가장 중요한 접근은 contrasitve이다. 주로 positive example과 negative sample의 scoring function을 학습하도록 한다. DGI에서는 local-global pair와 negative-sample을 이용한다.
Sampling strategies contrastive learning에서 중요한 부분은 positive/negative sample을 어떻게 얻을 것인가이다. positive sampling의 경우, short random walk 또는 node-anchored sample이 사용되었다. negative sampling의 경우, curriculum-based 또는 adversary 방법이 있다.
Predictive coding contrastive predictive coding (CPC)는 mutual information maximization에서 사용되는 방법이다. 그러나 CPC는 predictor가 input을 예측하도록 학습한다. 우리는 CPC와는 다르고, global/local part을 동시에 contrast할 것이다.
3. DGI Methodology
$\mathbf{X} = \{x_1, \dots, x_N \} $: node feature, $N$은 노드의 개수, $x_i \in \mathbb{R}^F$는 F차원의 i번째 node의 feature.
$\mathbf{A} \in \mathbb{R}^{N \times N}$: adjacency matrix. $A$는 임의의 실수로 이루어질 수 있지만, 우리 실험에서는 unweighted graph를 사용하였다. (연결되면 1, 연결되지 않으면 0)
$\mathcal{E}: \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times F'} \to \mathbb{R}^{N \times F'}$
우리가 학습시킬 encoder이다. 수식으로 표현하면 $\mathcal{E}(\mathbf{X, A})=\mathbf{H} = \{h_1, \dots, h_N \}$이다.
$h_i \in \mathbb{R}^{F'}$는 i번째 노드의 high-level representation이다. $H$는 downstream task에 사용될 수 있다.
우리는 encoder로 GCN을 사용할 것이다.
우리는 encoder가 local mutual information을 maximize하도록 학습되길 원한다. 즉, 우리는 encoder가 global information 벡터인 $s$를 capture하도록 할 것이다. (summary vector). $s$를 얻기 위해서, 우리는 readout function을 leverage할 것이다. $\mathcal{R}: \mathbb{R}^{N \times F} \to \mathbb{R}^F$. patch representation으로 graph-level representaion을 얻을 것이다. 즉, $s = \mathcal{R}(\mathcal{E}(\mathbf{X, A}))$ 이다.
local mutual information을 최대화하는 방법으로 discriminator를 도입한다. $\mathcal{D}: \mathbb{R}^F \times \mathbb{R}^F \to \mathbb{R}$. $D(h_i, s)$는 probability score를 표현한다. 높은 값일 수록 patch가 summary에 포함된다는 뜻이다.
negative sample은 tilde 표시를 한다. $(\tilde{X}, \tilde{A})$. single graph의 경우, (stochastic) corruption function을 이용하여 original graph로부터 negative sample을 만든다. $\mathcal{C}: \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \to \mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M}$.
즉 $(\tilde{\mathbf{X}}, \tilde{\mathbf{A}}) \sim \mathcal{C}(\mathbf{X, A})$
objective function으로 다음의 loss를 이용한다.
DGI가 학습되는 전반적인 순서는 다음과 같다.
input: single graph setup. $(\mathbf{X, A})$
- corroption function으로 negative sample을 구한다. $(\tilde{\mathbf{X}}, \tilde{\mathbf{A}}) \sim \mathcal{C}(\mathbf{X, A})$
- patch representation을 encoder를 통해 구한다. $\mathbf{H} = \mathcal{E}(\mathbf{X,A}) = \{h_1, \dots, h_N \}$
- negative sample의 patch representation을 구한다. $\tilde{\mathbf{H}} = \mathcal{E}(\tilde{\mathbf{X}}, \tilde{\mathbf{A}}) = \{ \tilde{h}_1, \dots, \tilde{h}_M \}$
- patch representation을 readout을 통해 summary를 구한다. $s = \mathcal{R}(\mathbf{H})$
- objective function을 이용하여 $\mathcal{E, R, D}$의 parameter를 update한다.
4. Classification Performance
Transductive learning
encoder로 1-layer GCN을 사용했다. propagation rule은 다음과 같다.
$\hat{\mathbf{A}} = \mathbf{A+I_N}$ (self-loop가 있는 adjacency matrix)이고 $\hat{D}$는 $\hat{A}$의 degree matrix 이다. nonlinearity로는 $\sigma=\text{PReLU}$를 사용했다. learnable linear transormation은 $\Theta \in \mathbb{R}^{F \times F'}$이고 $F'=512$또는 $F'=256$ (Pumbed 데이터셋, memory issue)이다.
corruption function의 경우, $X$의 row-wise shuffle을 이용해 $\tilde{X}$를 얻었고, 인접행렬은 원래 그래프의 것을 그대로 이용했다. 따라서 corrupted graph data는 원본 그래프와 동일한 그래프 구조를 갖고 feature만 다르기 때문에 patch representation만 다르다.
Inductive learning
encoder로 mean-pooling을 이용하는 GraphSAGE-GCN을 이용하였다.
Reddit 데이터셋의 경우, 3-layer에 skip connection을 이용하였다.
nonlinear function으로 동일하게 PReLU를 이용하였다.
Inductive learning on multiple graphs
PPI 데이터셋의 encoder는 3-layer mean-pooling model + dense skip connection을 이용하였다.
Readout, Discriminator
readout function은 단순히 전체 node embedding의 평균값에 nonlinearity(logistic sigmoid)를 통과한 형태이다.
각 실험마다 최적의 readout은 있었으나, 이와 같은 방법이 그래프의 크기에 따라 성능이 감소할 여지가 있다. 이 경우, set2vec에서 제안한 readout 방법이나, DiffPool이 대안이 될 수 있다.
discriminator score는 summary와 patch representation의 pair를 input으로 한다. $\mathbf{W}$는 learnable scoring matrix이고, $\sigma$는 logistic sigmoid으로 score를 확률(positive example일 확률)로 변환한다.
(실험 결과 table은 생략; DGI가 가장 좋았음)
5. Qualitative Analysis
Cora dataset이 node개수가 적고 깨끗하기에 이것으로 t-SNE 임베딩 분포를 살펴본다.
cluster를 평가하는 silhouette score가 0.234로 높았다.
'스터디 > 인공지능, 딥러닝, 머신러닝' 카테고리의 다른 글
[CS224W] GNN for RecSys (2) - Embedding-Based Models (0) | 2024.11.07 |
---|---|
[CS224W] GNN for RecSys (1) - Task and Evaluation (2) | 2024.11.06 |
[Bayesian] Bayesian Linear Regression (베이지안 선형 회귀) (0) | 2024.05.08 |
Double Descent: new approach of bias-variance trade-off (0) | 2024.03.03 |
Overfitting을 막는 방법들 (regularization, cross-validation, early stopping) (0) | 2024.03.02 |