728x90
반응형
Import
import gc
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid
Load Dataset (Cora dataset)
# Pin device enumeration to the PCI bus order and expose only physical GPU 1.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# BUG FIX: with CUDA_VISIBLE_DEVICES="1" the single visible GPU is re-indexed
# as cuda:0, so torch.device('cuda:1') would raise an "invalid device ordinal"
# error. Plain 'cuda' maps to the only visible device (physical GPU 1).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Download (if needed) and load the Cora citation graph; move the single
# graph object onto the selected device once, up front.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)
reproducibility를 위해 랜덤 시드값을 고정하자.
def fix_seed(seed=777):
    """Seed every RNG in use (torch CPU/CUDA, numpy, stdlib random) for reproducibility.

    Also forces cuDNN into deterministic mode, trading some speed for
    run-to-run stability.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA seeding is a no-op on CPU-only builds, so these are always safe.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
GNN Network
GCN
class GCN(nn.Module):
    """Multi-layer GCN: [GCNConv -> BatchNorm -> ReLU -> Dropout] x (L-1), then a final GCNConv.

    Args:
        input_dim: number of input node features.
        hidden_dim: hidden layer width.
        output_dim: number of output classes (raw logits are returned).
        num_layers: total number of GCNConv layers (must be >= 2).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GCN, self).__init__()
        # BUG FIX: removed self.conv1 / self.conv2 — two cached GCNConv layers
        # that were constructed but never used in forward(), wasting parameters
        # and leaving them untouched by reset_parameters().
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_dim)])
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        self.convs.append(GCNConv(hidden_dim, output_dim))

    def reset_parameters(self):
        """Re-initialize every conv and batch-norm layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, data):
        """Run the full-batch forward pass over a PyG Data object; returns logits."""
        x, edge_index = data.x, data.edge_index
        # All layers except the last get norm + ReLU + dropout (p defaults to 0.5).
        for conv, bn in zip(self.convs[:-1], self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)
        # Final layer outputs raw class logits (no activation).
        x = self.convs[-1](x, edge_index)
        return x
GraphSAGE
aggregate 함수로 max를 사용하였다.
class GraphSAGE(nn.Module):
    """Multi-layer GraphSAGE using 'max' neighbor aggregation.

    Structure mirrors the GCN above: every layer except the last is followed
    by BatchNorm -> ReLU -> Dropout; the last layer emits raw logits.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GraphSAGE, self).__init__()
        conv_layers = [SAGEConv(input_dim, hidden_dim, 'max')]
        norm_layers = [nn.BatchNorm1d(hidden_dim)]
        for _ in range(num_layers - 2):
            conv_layers.append(SAGEConv(hidden_dim, hidden_dim, 'max'))
            norm_layers.append(nn.BatchNorm1d(hidden_dim))
        conv_layers.append(SAGEConv(hidden_dim, output_dim, 'max'))
        self.convs = nn.ModuleList(conv_layers)
        self.bns = nn.ModuleList(norm_layers)

    def reset_parameters(self):
        """Re-initialize every conv and batch-norm layer."""
        for layer in self.convs:
            layer.reset_parameters()
        for layer in self.bns:
            layer.reset_parameters()

    def forward(self, data):
        """Full-batch forward over a PyG Data object; returns class logits."""
        h, edge_index = data.x, data.edge_index
        for layer, norm in zip(self.convs[:-1], self.bns):
            h = layer(h, edge_index)
            h = norm(h)
            h = F.relu(h)
            h = F.dropout(h, training=self.training)
        return self.convs[-1](h, edge_index)
GAT
class GAT(nn.Module):
    """Multi-layer Graph Attention Network.

    Hidden layers use `input_heads` attention heads with concatenated
    outputs (width hidden_dim * input_heads); the output layer uses a
    single, non-concatenated head producing `output_dim` logits.

    Args:
        input_dim: number of input node features.
        hidden_dim: per-head hidden width.
        output_dim: number of output classes.
        num_layers: total number of GATConv layers (must be >= 2).
        input_heads: attention heads in the hidden layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, input_heads=8):
        super(GAT, self).__init__()
        self.convs = nn.ModuleList([GATConv(input_dim, hidden_dim, heads=input_heads, dropout=0.5)])
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_dim * input_heads, hidden_dim, heads=input_heads, dropout=0.5))
        self.convs.append(GATConv(hidden_dim * input_heads, output_dim, concat=False, heads=1, dropout=0.5))

    def forward(self, data):
        """Full-batch forward over a PyG Data object; returns class logits."""
        x, edge_index = data.x, data.edge_index
        # BUG FIX: ELU + dropout are applied only BETWEEN layers. The original
        # looped over all convs including the last, passing the output layer's
        # logits through ELU and dropout before cross_entropy — the final layer
        # must return raw logits.
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=0.6, training=self.training)
        return self.convs[-1](x, edge_index)
Train and Test
def get_model(model_type='gcn'):
    """Build an untrained GNN sized for the loaded Planetoid dataset.

    Args:
        model_type: one of 'gcn', 'sage', or 'gat'.

    Returns:
        The constructed nn.Module (GCN/GraphSAGE: 3 layers, 256 hidden units;
        GAT: 8 layers, 8 units per head).

    Raises:
        ValueError: if `model_type` is not recognized.
    """
    if model_type == 'gcn':
        model = GCN(dataset.num_node_features, 256, dataset.num_classes, 3)
    elif model_type == 'sage':
        model = GraphSAGE(dataset.num_node_features, 256, dataset.num_classes, 3)
    elif model_type == 'gat':
        model = GAT(dataset.num_node_features, 8, dataset.num_classes, 8)
    else:
        # BUG FIX: an unknown model_type previously fell through to
        # `return model` and crashed with UnboundLocalError; fail explicitly.
        raise ValueError(f"unknown model_type: {model_type!r}")
    return model
def train(model, device='cpu', n_epochs=100, learning_rate=0.01):
    """Full-batch training of `model` on the module-level `data` graph.

    Uses Adam with weight decay 5e-4 and cross-entropy over the nodes
    selected by `data.train_mask`. Logs the loss every 10 epochs.

    Args:
        model: GNN whose forward takes a PyG Data object and returns logits.
        device: device to move the model to (`data` is assumed already there).
        n_epochs: number of full-batch epochs.
        learning_rate: Adam learning rate.
    """
    model = model.to(device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    model.train()
    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        if epoch % 10 == 0:
            # BUG FIX: the log printed `epoch + 1` although `epoch` already
            # starts at 1, so it showed "epoch: 11" at epoch 10.
            print(f'epoch: {epoch}, loss: {loss.item()}')
        loss.backward()
        optimizer.step()
def test(model):
    """Print `model`'s accuracy on the test-mask nodes of the module-level `data`."""
    model.eval()
    predictions = model(data).argmax(dim=1)
    mask = data.test_mask
    n_correct = int((predictions[mask] == data.y[mask]).sum())
    acc = n_correct / int(mask.sum())
    print(f'Test Accuracy: {acc:.6f}')
Experiments
# Experiment 1: 3-layer GCN, lr 0.05, default 100 epochs.
print('===== GCN =====')
fix_seed()
gcn_model = get_model(model_type='gcn')
train(gcn_model, device=device, learning_rate=0.05)
test(gcn_model)
===== GCN =====
epoch: 11, loss: 0.023094383999705315
epoch: 21, loss: 0.0105555085465312
epoch: 31, loss: 0.003940434195101261
epoch: 41, loss: 0.00966612622141838
epoch: 51, loss: 0.004165519494563341
epoch: 61, loss: 0.002340392442420125
epoch: 71, loss: 0.03174349293112755
epoch: 81, loss: 0.011421283707022667
epoch: 91, loss: 0.0033531629014760256
epoch: 101, loss: 0.0034052524715662003
Test Accuracy: 0.768000
# Experiment 2: 3-layer GraphSAGE (max aggregation), lr 0.01, 100 epochs.
print('===== GraphSAGE =====')
fix_seed()
sage_model = get_model(model_type='sage')
train(sage_model, device=device, learning_rate=0.01)
test(sage_model)
===== GraphSAGE =====
epoch: 11, loss: 1.146528959274292
epoch: 21, loss: 1.1325287818908691
epoch: 31, loss: 1.1231029033660889
epoch: 41, loss: 1.121543049812317
epoch: 51, loss: 1.1198463439941406
epoch: 61, loss: 1.1174838542938232
epoch: 71, loss: 1.1175700426101685
epoch: 81, loss: 1.1164284944534302
epoch: 91, loss: 1.1153371334075928
epoch: 101, loss: 1.1151783466339111
Test Accuracy: 0.795000
# Experiment 3: 8-layer GAT, lr 0.005, 300 epochs (attention trains slower).
print('===== GAT =====')
fix_seed()
gat_model = get_model(model_type='gat')
train(gat_model, device=device, n_epochs=300, learning_rate=0.005)
test(gat_model)
===== GAT =====
epoch: 11, loss: 2.1975746154785156
epoch: 21, loss: 2.132791042327881
epoch: 31, loss: 1.9391498565673828
epoch: 41, loss: 1.8031173944473267
epoch: 51, loss: 1.980901837348938
epoch: 61, loss: 1.8231042623519897
epoch: 71, loss: 1.8520184755325317
epoch: 81, loss: 1.681148886680603
epoch: 91, loss: 1.78931725025177
epoch: 101, loss: 1.9704439640045166
epoch: 111, loss: 1.743062973022461
epoch: 121, loss: 1.5524016618728638
epoch: 131, loss: 1.5907924175262451
epoch: 141, loss: 1.6539884805679321
epoch: 151, loss: 1.5278500318527222
epoch: 161, loss: 1.5078015327453613
epoch: 171, loss: 1.4780718088150024
epoch: 181, loss: 1.5131406784057617
epoch: 191, loss: 1.746212124824524
epoch: 201, loss: 1.5512117147445679
epoch: 211, loss: 1.4683003425598145
epoch: 221, loss: 1.5917233228683472
epoch: 231, loss: 1.5137397050857544
epoch: 241, loss: 1.3165178298950195
epoch: 251, loss: 1.3876081705093384
epoch: 261, loss: 1.5510534048080444
epoch: 271, loss: 1.5977932214736938
epoch: 281, loss: 1.1832685470581055
epoch: 291, loss: 1.4997584819793701
epoch: 301, loss: 1.2910630702972412
Test Accuracy: 0.745000
728x90
반응형
'스터디 > 인공지능, 딥러닝, 머신러닝' 카테고리의 다른 글
[PyG] GIN 예제 코드 (0) | 2023.07.12 |
---|---|
[CS224w] Label Propagation on Graphs (1) - Outline (0) | 2023.07.12 |
[CS224w] General Tips for Debugging GNN (0) | 2023.06.18 |
[CS224w] Theory of GNNs (2) - Neighbor Aggregation, GIN (Graph Isomorphism Network) (0) | 2023.06.17 |
[CS224w] Theory of GNNs (1) - Local Neighborhood Structures, Computational Graph (0) | 2023.06.14 |