본문 바로가기
스터디/인공지능, 딥러닝, 머신러닝

[PyG] GCN, GraphSAGE, GAT 예제 코드

by 궁금한 준이 2023. 7. 12.
728x90
반응형

Import

import gc
import os
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid​

Load Dataset (Cora dataset)

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="1"

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device) # cuda:1

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)

 

reproductivity를 위해 랜덤 시드값을 고정하자.

def fix_seed(seed=777):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)​

GNN Network

GCN

class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim, cached=True)
        self.conv2 = GCNConv(hidden_dim, output_dim, cached=True)
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim)])
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        self.convs.append(GCNConv(hidden_dim, output_dim))
        
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv, bn in zip(self.convs[:-1], self. bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
            
        x = self.convs[-1](x, edge_index)
                
        return x

GraphSAGE

aggregate 함수로 max를 사용하였다.

class GraphSAGE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList([SAGEConv(input_dim, hidden_dim, 'max')])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim)])
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim, 'max'))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        self.convs.append(SAGEConv(hidden_dim, output_dim, 'max'))
        
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv, bn in zip(self.convs[:-1], self. bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
            
        x = self.convs[-1](x, edge_index)
                
        return x

GAT

class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, input_heads=8):
        super(GAT, self).__init__()        
        self.convs = nn.ModuleList([GATConv(input_dim, hidden_dim, heads=input_heads, dropout=0.5)])
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_dim * input_heads, hidden_dim, heads=input_heads, dropout=0.5))
        self.convs.append(GATConv(hidden_dim * input_heads, output_dim, concat=False, heads=1, dropout=0.5))
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=0.6, training=self.training)
                
        return x

Train and Test

def get_model(model_type='gcn'):
    if model_type == 'gcn':
        model = GCN(dataset.num_node_features, 256, dataset.num_classes, 3)
    elif model_type == 'sage':
        model = GraphSAGE(dataset.num_node_features, 256, dataset.num_classes, 3)
    elif model_type == 'gat':
        model = GAT(dataset.num_node_features, 8, dataset.num_classes, 8)
    
    return model
def train(model, device='cpu', n_epochs=100, learning_rate=0.01):
    model = model.to(device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    
    model.train()
    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        
        if epoch % 10 == 0:
            print(f'epoch: {epoch + 1}, loss: {loss.item()}')
        
        loss.backward()
        optimizer.step()
def test(model):
    model.eval()
    pred = model(data).argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    print(f'Test Accuracy: {acc:.6f}')​

Experiments

print('===== GCN =====')
fix_seed()
model = get_model(model_type='gcn')
train(model, device=device, learning_rate=0.05)
test(model)

===== GCN =====
epoch: 11, loss: 0.023094383999705315
epoch: 21, loss: 0.0105555085465312
epoch: 31, loss: 0.003940434195101261
epoch: 41, loss: 0.00966612622141838
epoch: 51, loss: 0.004165519494563341
epoch: 61, loss: 0.002340392442420125
epoch: 71, loss: 0.03174349293112755
epoch: 81, loss: 0.011421283707022667
epoch: 91, loss: 0.0033531629014760256
epoch: 101, loss: 0.0034052524715662003
Test Accuracy: 0.768000​
print('===== GraphSAGE =====')
fix_seed()
graph_sage = get_model(model_type='sage')
train(graph_sage, device=device, learning_rate=0.01)
test(graph_sage)

===== GraphSAGE =====
epoch: 11, loss: 1.146528959274292
epoch: 21, loss: 1.1325287818908691
epoch: 31, loss: 1.1231029033660889
epoch: 41, loss: 1.121543049812317
epoch: 51, loss: 1.1198463439941406
epoch: 61, loss: 1.1174838542938232
epoch: 71, loss: 1.1175700426101685
epoch: 81, loss: 1.1164284944534302
epoch: 91, loss: 1.1153371334075928
epoch: 101, loss: 1.1151783466339111
Test Accuracy: 0.795000
print('===== GAT =====')
fix_seed()
gat = get_model(model_type='gat')
train(gat, device=device, n_epochs=300, learning_rate=0.005)
test(gat)

===== GAT =====
epoch: 11, loss: 2.1975746154785156
epoch: 21, loss: 2.132791042327881
epoch: 31, loss: 1.9391498565673828
epoch: 41, loss: 1.8031173944473267
epoch: 51, loss: 1.980901837348938
epoch: 61, loss: 1.8231042623519897
epoch: 71, loss: 1.8520184755325317
epoch: 81, loss: 1.681148886680603
epoch: 91, loss: 1.78931725025177
epoch: 101, loss: 1.9704439640045166
epoch: 111, loss: 1.743062973022461
epoch: 121, loss: 1.5524016618728638
epoch: 131, loss: 1.5907924175262451
epoch: 141, loss: 1.6539884805679321
epoch: 151, loss: 1.5278500318527222
epoch: 161, loss: 1.5078015327453613
epoch: 171, loss: 1.4780718088150024
epoch: 181, loss: 1.5131406784057617
epoch: 191, loss: 1.746212124824524
epoch: 201, loss: 1.5512117147445679
epoch: 211, loss: 1.4683003425598145
epoch: 221, loss: 1.5917233228683472
epoch: 231, loss: 1.5137397050857544
epoch: 241, loss: 1.3165178298950195
epoch: 251, loss: 1.3876081705093384
epoch: 261, loss: 1.5510534048080444
epoch: 271, loss: 1.5977932214736938
epoch: 281, loss: 1.1832685470581055
epoch: 291, loss: 1.4997584819793701
epoch: 301, loss: 1.2910630702972412
Test Accuracy: 0.745000

 

728x90
반응형