
[PyG] GIN Example Code

by 궁금한 준이 2023. 7. 12.

Introduction

GIN (Graph Isomorphism Network) is a GNN model well suited to graph-level tasks.

Let's use the PROTEINS dataset to perform graph classification.

 

Setup

import os
import copy   # used to snapshot the best model during training
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# With CUDA_VISIBLE_DEVICES="1", the single visible GPU is remapped to cuda:0,
# so the first visible device ('cuda') is the correct handle, not 'cuda:1'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device) # cuda


def fix_seed(seed=777):
    # Fix every RNG source (PyTorch CPU/GPU, cuDNN, NumPy, Python) for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

Load Dataset

dataset = TUDataset(root='/tmp/Proteins', name='PROTEINS')

print(len(dataset)) # 1113
print(dataset.num_classes) # 2
print(dataset.num_node_features) # 3
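
Each element of the dataset is a single protein graph; a quick way to peek at one (a sketch, and the exact tensor sizes vary from graph to graph):

print(dataset[0]) # e.g. Data(edge_index=[2, ...], x=[..., 3], y=[1])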

# train-valid-test split
TRAIN_SIZE = int(len(dataset) * 0.8)
TEST_SIZE = int(len(dataset) * 0.1)

dataset = dataset.shuffle()
train_dataset = dataset[:TRAIN_SIZE]
valid_dataset = dataset[TRAIN_SIZE:-TEST_SIZE]
test_dataset = dataset[-TEST_SIZE:]

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)  # no need to shuffle for evaluation
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for graph in train_loader:
    print(graph) # DataBatch(edge_index=[2, 8412], x=[2299, 3], y=[64], batch=[2299], ptr=[65])
    break
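
In the DataBatch above, the batch vector assigns each of the 2299 nodes to one of the 64 graphs in the batch; the pooling operators use it to reduce node features graph by graph. A minimal sketch with toy tensors (reusing the imports from Setup):

x = torch.ones(5, 3)                   # 5 nodes, 3 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # the first three nodes form graph 0, the last two graph 1
print(global_add_pool(x, batch))       # tensor([[3., 3., 3.], [2., 2., 2.]])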

GIN Model

GIN implements its aggregate and combine functions with an MLP.
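
For reference, this is the layer update from the GIN paper (Xu et al., 2019), where $\epsilon^{(k)}$ is a fixed or learnable scalar:

$$h_v^{(k)} = \mathrm{MLP}^{(k)}\left( \left(1 + \epsilon^{(k)}\right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} \right)$$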

Also, sum pooling is reported to work better than mean pooling for the graph-level READOUT (swapping the two experimentally confirms this; see the sketch after the model definition below).

class SimpleGIN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3):
        super(SimpleGIN, self).__init__()
        self.convs = nn.ModuleList([self.build_conv_layer(input_dim, hidden_dim)])
        self.bns = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(self.build_conv_layer(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        
        self.lin1 = nn.Linear(hidden_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)
        
        self.pool = global_add_pool
        self.act = nn.ReLU()
        
    def build_conv_layer(self, input_dim, hidden_dim):
        # GINConv wraps the 2-layer MLP that implements the combine step
        return GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                     nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # message passing: every layer except the last is followed by BN / ReLU / dropout
        for bn, conv in zip(self.bns, self.convs[:-1]):
            x = conv(x, edge_index)
            x = bn(x)
            x = self.act(x)
            x = F.dropout(x, training=self.training)
        x = self.convs[-1](x, edge_index)
        
        # pooling
        x = self.pool(x, batch)
        
        # classification
        x = F.dropout(F.relu(self.lin1(x)), training=self.training)  # without training=..., F.dropout stays active at eval time
        x = self.lin2(x)
        
        return x
        
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
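
To actually run the sum-vs-mean readout comparison mentioned above, a simple sketch is to swap the pooling attribute on a fresh model (global_mean_pool was imported in Setup):

mean_model = SimpleGIN(input_dim=dataset.num_node_features,
                       hidden_dim=64,
                       output_dim=dataset.num_classes)
mean_model.pool = global_mean_pool  # same architecture, mean readout instead of sum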

Train & Test

@torch.no_grad()
def test(model, data_loader, device='cpu'):
    loss, correct = 0, 0
    criterion = nn.CrossEntropyLoss()
    model.eval()
    for data in data_loader:
        data = data.to(device)
        out = model(data)
        loss += criterion(out, data.y).item() / len(data_loader)
        correct += torch.sum(out.argmax(dim=1) == data.y).item()
        
    acc = correct / len(data_loader.dataset)
    return loss, acc


def train(model, train_loader, valid_loader, device='cpu', n_epochs=200, learning_rate=0.001):
    model = model.to(device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    
    best_valid_loss = float('inf')
    best_model = None
    
    for epoch in range(n_epochs + 1):
        model.train()  # test() switches the model to eval mode, so re-enable train mode every epoch
        train_loss, train_acc = 0, 0
        
        for data in train_loader:
            data = data.to(device)
            
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            train_loss += loss.item() / len(train_loader)
            train_acc += torch.sum(out.argmax(dim=1) == data.y).item() / len(train_loader.dataset)
            loss.backward()
            optimizer.step()
            
        if epoch % 10 == 0:
            valid_loss, valid_acc = test(model, valid_loader, device=device)
            print(f'epoch: {epoch:>3}, train loss: {train_loss:.6f}, train_accuracy: {train_acc:.4f}, valid loss: {valid_loss:.6f}, valid_accuracy: {valid_acc:.4f}')
            
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model = copy.deepcopy(model)  # snapshot the weights; `best_model = model` would only alias the live model
                
    return best_model

Experiment

fix_seed()
model = SimpleGIN(input_dim=dataset.num_node_features, 
                  hidden_dim=64, 
                  output_dim=dataset.num_classes, 
                  num_layers=3).to(device)

model.reset_parameters()
best_model = train(model, train_loader, valid_loader, device=device)
test_loss, test_acc = test(best_model, data_loader=test_loader, device=device)
print(f'Test Accuracy: {test_acc}')

epoch:   0, train loss: 1.724246, train_accuracy: 0.5596, valid loss: 0.595034, valid_accuracy: 0.6875
epoch:  10, train loss: 0.585427, train_accuracy: 0.6955, valid loss: 0.639831, valid_accuracy: 0.6696
epoch:  20, train loss: 0.574599, train_accuracy: 0.7371, valid loss: 0.519535, valid_accuracy: 0.7321
epoch:  30, train loss: 0.563444, train_accuracy: 0.7315, valid loss: 0.508432, valid_accuracy: 0.7232
epoch:  40, train loss: 0.540419, train_accuracy: 0.7618, valid loss: 0.521829, valid_accuracy: 0.7679
epoch:  50, train loss: 0.540308, train_accuracy: 0.7404, valid loss: 0.541889, valid_accuracy: 0.7946
epoch:  60, train loss: 0.559025, train_accuracy: 0.7416, valid loss: 0.601877, valid_accuracy: 0.7232
epoch:  70, train loss: 0.517523, train_accuracy: 0.7551, valid loss: 0.673514, valid_accuracy: 0.7500
epoch:  80, train loss: 0.526982, train_accuracy: 0.7551, valid loss: 0.619448, valid_accuracy: 0.7143
epoch:  90, train loss: 0.516150, train_accuracy: 0.7596, valid loss: 0.771738, valid_accuracy: 0.7232
epoch: 100, train loss: 0.510959, train_accuracy: 0.7708, valid loss: 0.492880, valid_accuracy: 0.7500
epoch: 110, train loss: 0.518415, train_accuracy: 0.7697, valid loss: 0.653470, valid_accuracy: 0.7321
epoch: 120, train loss: 0.501291, train_accuracy: 0.7775, valid loss: 0.841137, valid_accuracy: 0.7321
epoch: 130, train loss: 0.517517, train_accuracy: 0.7596, valid loss: 0.706645, valid_accuracy: 0.7411
epoch: 140, train loss: 0.490670, train_accuracy: 0.7798, valid loss: 0.655255, valid_accuracy: 0.7679
epoch: 150, train loss: 0.493445, train_accuracy: 0.7753, valid loss: 0.658685, valid_accuracy: 0.7679
epoch: 160, train loss: 0.509193, train_accuracy: 0.7730, valid loss: 0.484392, valid_accuracy: 0.7679
epoch: 170, train loss: 0.488481, train_accuracy: 0.7787, valid loss: 0.574501, valid_accuracy: 0.7500
epoch: 180, train loss: 0.482161, train_accuracy: 0.7978, valid loss: 0.656573, valid_accuracy: 0.7143
epoch: 190, train loss: 0.480034, train_accuracy: 0.7876, valid loss: 0.633464, valid_accuracy: 0.7232
epoch: 200, train loss: 0.449777, train_accuracy: 0.8079, valid loss: 0.600172, valid_accuracy: 0.6964
Test Accuracy: 0.6936936974525452
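
If you want to keep the selected model for later use, a minimal sketch (the filename is arbitrary):

torch.save(best_model.state_dict(), 'gin_proteins.pt')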

 
