# -*- coding: utf-8 -*-
import random
from typing import Tuple, Any
from sklearn.cluster import Birch
from torch import Tensor
from torch.optim import lr_scheduler
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv
from . import DGCModel
from yacs.config import CfgNode as CN
import numpy as np
from ..metrics import DGCMetric
from ..utils import Logger
import torch
import torch.nn as nn
import torch.nn.functional as F
def convert_scipy_torch_sp(sp_adj):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sp_adj (scipy.sparse.csr_matrix): Input sparse matrix. Only its
            ``tocoo()`` method is used here.

    Returns:
        torch.sparse_coo_tensor: Sparse tensor with the same shape and the
            same stored entries as ``sp_adj``.
    """
    coo = sp_adj.tocoo()
    # Row indices on the first line, column indices on the second: the
    # (2, nnz) layout that torch.sparse_coo_tensor takes for a 2-D tensor.
    index_pairs = np.vstack((coo.row, coo.col))
    return torch.sparse_coo_tensor(
        torch.tensor(index_pairs),
        torch.tensor(coo.data),
        size=coo.shape,
    )
def aux_objective(output, s, oh_labels):
    """Auxiliary objective comparing sampled outputs with one-hot labels.

    With ``O = output[s]`` and ``C = oh_labels[s]`` (both cast to float32),
    the returned value is::

        (tr((C^T C)^2) + tr((O^T O)^2) - 2 * tr(O^T C C^T O)) / len(s)^2

    which is algebraically the squared Frobenius norm of
    ``C C^T - O O^T`` divided by ``len(s)^2``.

    Args:
        output (torch.Tensor): Output tensor, indexed by row with ``s``.
        s: Sample (row) indices; anything supporting ``len()`` and usable
            as a row index for ``output`` and ``oh_labels``.
        oh_labels (torch.Tensor): One-hot labels, indexed by row with ``s``.

    Returns:
        torch.Tensor: Scalar auxiliary objective loss.

    Raises:
        ZeroDivisionError: If ``s`` is empty.
    """
    sample_size = len(s)
    out = output[s, :].float()
    C = oh_labels[s, :].float()

    # tr((C^T C)^2): label/label term.
    label_gram = torch.matmul(torch.t(C), C)
    t1 = torch.trace(torch.matmul(label_gram, label_gram))

    # tr((O^T O)^2): output/output term.
    out_gram = torch.matmul(torch.t(out), out)
    t2 = torch.trace(torch.matmul(out_gram, out_gram))

    # tr(O^T C C^T O): cross term.
    cross = torch.matmul(torch.t(out), C)
    t3 = torch.trace(torch.matmul(cross, torch.t(cross)))

    aux_objective_loss = 1 / (sample_size ** 2) * (t1 + t2 - 2 * t3)
    return aux_objective_loss
def regularization(output, s):
    """Regularization term computed on the sampled rows of ``output``.

    The sampled rows are summed column-wise, the squared column totals are
    added up and divided by ``len(s) ** 2``; the square of that quantity is
    returned.

    Args:
        output (torch.Tensor): Output tensor, indexed by row with ``s``.
        s: Sample (row) indices; anything supporting ``len()`` and usable
            as a row index for ``output``.

    Returns:
        torch.Tensor: Scalar regularization loss.
    """
    n_samples = len(s)
    column_totals = output[s, :].sum(dim=0)
    squared_total = (column_totals ** 2).sum()
    mean_similarity = 1 / (n_samples ** 2) * squared_total
    return mean_similarity ** 2
class DGCLUSTER(DGCModel):
    """DGCLUSTER: A Neural Framework for Attributed Graph Clustering via Modularity Maximization.

    Reference: https://ojs.aaai.org/index.php/AAAI/article/view/28983

    Args:
        logger (Logger): Logger object.
        cfg (CN): Configuration object.
    """

    def __init__(self, logger: Logger, cfg: CN):
        # NOTE(review): self.device, self.cfg and self.logger are used below
        # but never assigned in this class; presumably the DGCModel
        # constructor sets them -- confirm.
        super(DGCLUSTER, self).__init__(logger, cfg)
        # Layer widths: input feature size followed by the configured dims.
        dims = cfg.model.dims.copy()
        dims.insert(0, cfg.dataset.num_features)
        if cfg.model.gnn_type == 'gcn':
            self.conv1 = GCNConv(dims[0], dims[1])
            self.conv2 = GCNConv(dims[1], dims[2])
            self.conv3 = GCNConv(dims[2], dims[-1])
        elif cfg.model.gnn_type == 'gat':
            self.conv1 = GATConv(dims[0], dims[1])
            self.conv2 = GATConv(dims[1], dims[2])
            self.conv3 = GATConv(dims[2], dims[-1])
        elif cfg.model.gnn_type == 'gin':
            self.conv1 = GINConv(nn.Linear(dims[0], dims[1]))
            self.conv2 = GINConv(nn.Linear(dims[1], dims[2]))
            self.conv3 = GINConv(nn.Linear(dims[2], dims[-1]))
        else:
            # Any other gnn_type value falls back to GraphSAGE layers.
            self.conv1 = SAGEConv(dims[0], dims[1])
            self.conv2 = SAGEConv(dims[1], dims[2])
            self.conv3 = SAGEConv(dims[2], dims[-1])
        self.conv1.to(self.device)
        self.conv2.to(self.device)
        self.conv3.to(self.device)
        # Per-epoch training loss and (when evaluating each epoch) NMI.
        self.loss_curve = []
        self.nmi_curve = []
        # Best evaluation seen so far, selected by ACC in train_model().
        self.best_embedding = None
        self.best_predicted_labels = None
        self.best_results = {'ACC': -1}

    def reset_parameters(self):
        """No-op: this model does not re-initialise its layers here."""
        pass

    def forward(self, data: Data) -> Any:
        """Compute node embeddings with the three convolution layers.

        Args:
            data (Data): Graph providing ``x`` and ``edge_index``.

        Returns:
            torch.Tensor: Output of the third layer after global sum
                scaling, squared tanh and ``F.normalize``.
        """
        x, edge_index = data.x.to(self.device), data.edge_index.to(self.device)
        x = self.conv1(x, edge_index)
        x = F.selu(x)
        # Dropout is active only while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.selu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        # Scale by the sum over all entries, then squash into [0, 1).
        x = x / (x.sum())
        x = torch.tanh(x) ** 2
        x = F.normalize(x)
        return x

    def loss(self, output, data, lam, alp) -> Tensor:
        """Modularity loss plus auxiliary and regularization terms.

        Args:
            output (torch.Tensor): Embeddings returned by ``forward``.
            data (Data): Must provide ``sparse_adj`` (sliced with
                ``[s, :][:, s]`` and passed to ``convert_scipy_torch_sp``),
                ``degree`` and ``oh_labels``.
            lam: Weight of the auxiliary objective.
            alp: Weight of the regularization term.

        Returns:
            torch.Tensor: ``m_loss + lam * aux + alp * reg``.
        """
        num_nodes = self.cfg.dataset.num_nodes
        num_edges = self.cfg.dataset.num_edges
        sparse_adj = data.sparse_adj
        degree = data.degree
        # Every node is sampled, so ``s`` is a random ordering of all nodes.
        sample_size = int(1 * num_nodes)
        s = random.sample(range(0, num_nodes), sample_size)
        s_output = output[s, :]
        s_adj = sparse_adj[s, :][:, s]
        s_adj = convert_scipy_torch_sp(s_adj)
        s_degree = degree[s]
        # x = tr(S^T A S) on the sampled sub-graph.
        x = torch.matmul(torch.t(s_output).double(), s_adj.double().to(self.device))
        x = torch.matmul(x, s_output.double())
        x = torch.trace(x)
        # y = ||S^T d||^2 / (2 * num_edges).
        y = torch.matmul(torch.t(s_output).double(), s_degree.double().to(self.device))
        y = (y ** 2).sum()
        y = y / (2 * num_edges)
        # Evaluates to 1 here because sample_size == num_nodes.
        scaling = num_nodes ** 2 / (sample_size ** 2)
        m_loss = -((x - y) / (2 * num_edges)) * scaling
        aux_loss = lam * aux_objective(output, s, data.oh_labels)
        reg_loss = alp * regularization(output, s)
        loss = m_loss + aux_loss + reg_loss
        return loss

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN DGCLUSTER"):
        """Train the model and return curves plus clustering results.

        Args:
            data (Data): Graph to train on.
            cfg (CN): Training config providing ``lr``, ``max_epoch``,
                ``lam`` and ``alp``. Defaults to ``self.cfg.train``.
            flag (str): Banner passed to ``self.logger.flag``.

        Returns:
            tuple: ``(loss_curve, nmi_curve, embedding, predicted_labels,
                results)``. When ``self.cfg.evaluate.each`` is true these
                are the values from the epoch with the highest ACC;
                otherwise they come from one evaluation after training.
        """
        if cfg is None:
            cfg = self.cfg.train
        self.logger.flag(flag)
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=float(cfg.lr),
            betas=(0.9, 0.999),
            weight_decay=0.001,
            amsgrad=True,
        )
        scheduler = lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=0.1, total_iters=cfg.max_epoch
        )
        self.train()
        for epoch in range(cfg.max_epoch):
            optimizer.zero_grad()
            out = self.forward(data)
            loss = self.loss(out, data, cfg.lam, cfg.alp)
            loss.backward()
            self.loss_curve.append(loss.item())
            self.logger.loss(epoch, loss)
            torch.nn.utils.clip_grad_norm_(self.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            if self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results
        if not self.cfg.evaluate.each:
            embedding, predicted_labels, results = self.evaluate(data)
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return (
            self.loss_curve,
            self.nmi_curve,
            self.best_embedding,
            self.best_predicted_labels,
            self.best_results,
        )

    def get_embedding(self, data: Data) -> Tensor:
        """Return detached embeddings computed without dropout.

        The forward pass runs in evaluation mode so that ``F.dropout`` in
        ``forward`` is disabled; the module's previous training/eval mode
        is restored afterwards, so calling this mid-training does not
        disturb the training loop.

        Args:
            data (Data): Graph to embed.

        Returns:
            torch.Tensor: Detached output of ``forward``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                embedding = self.forward(data)
        finally:
            self.train(was_training)
        return embedding.detach()

    def clustering(self, data: Data) -> Tuple[Tensor, Tensor, np.ndarray]:
        """Cluster the embeddings with BIRCH.

        Args:
            data (Data): Graph to embed and cluster.

        Returns:
            tuple: ``(embedding, labels, centers)`` where ``labels`` is a
                tensor built from ``Birch.fit_predict`` and ``centers`` is
                ``birch.subcluster_centers_``.
        """
        embedding = self.get_embedding(data)
        birch = Birch(n_clusters=self.cfg.dataset.n_clusters, threshold=0.5)
        labels = torch.from_numpy(birch.fit_predict(embedding.cpu().numpy()))
        return embedding, labels, birch.subcluster_centers_

    def evaluate(self, data: Data):
        """Cluster the graph and score the result against ``data.y``.

        Args:
            data (Data): Graph providing ``y`` and ``edge_index``.

        Returns:
            tuple: ``(embedding, predicted_labels, results)`` where
                ``results`` is whatever ``DGCMetric.evaluate_one_epoch``
                returns (indexed by ``'ACC'`` and ``'NMI'`` in
                ``train_model``).
        """
        embedding, predicted_labels, _ = self.clustering(data)
        ground_truth = data.y.cpu().numpy()
        metric = DGCMetric(ground_truth, predicted_labels.numpy(), embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results