# -*- coding: utf-8 -*-
import os
import torch
import torch.nn.functional as F
from . import AE
from ..metrics import DGCMetric
from torch import Tensor
from .dgc_model import DGCModel
from pydgc.modules import SSCLayer
from typing import Tuple, Any
from yacs.config import CfgNode as CN
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from ..utils import Logger, target_distribution, validate_and_create_path
class SDCN(DGCModel):
    """Structural Deep Clustering Network.

    Reference: https://doi.org/10.1145/3366423.3380214

    The model combines an auto-encoder (``AE``), a stack of five ``GCNConv``
    layers and an ``SSCLayer``. In the forward pass each GCN layer consumes a
    ``sigma``-weighted mix of the previous GCN output and the matching
    auto-encoder encoding.

    Args:
        logger (Logger): Logger object.
        cfg (CN): Configuration object.
    """

    def __init__(self, logger: Logger, cfg: CN):
        super().__init__(logger, cfg)
        self.ae = AE(logger, cfg)

        # Layer widths: the input feature size followed by the configured dims.
        # NOTE(review): the indexing below assumes cfg.model.dims has four entries.
        dims = self.cfg.model.dims.copy()
        dims.insert(0, self.cfg.dataset.num_features)

        add_self_loops = cfg.dataset.augmentation.add_self_loops
        self.gcn_1 = GCNConv(dims[0], dims[1], add_self_loops=add_self_loops).to(self.device)
        self.gcn_2 = GCNConv(dims[1], dims[2], add_self_loops=add_self_loops).to(self.device)
        self.gcn_3 = GCNConv(dims[2], dims[3], add_self_loops=add_self_loops).to(self.device)
        self.gcn_4 = GCNConv(dims[3], dims[-1], add_self_loops=add_self_loops).to(self.device)
        # The last GCN layer maps the latent width to one output per cluster.
        self.gcn_5 = GCNConv(dims[-1], cfg.dataset.n_clusters, add_self_loops=add_self_loops).to(self.device)
        self.ssc = SSCLayer(in_channels=dims[-1], out_channels=self.cfg.dataset.n_clusters, method='kl_div').to(self.device)

        # Training history and the best evaluation snapshot seen so far.
        self.loss_curve = []
        self.nmi_curve = []
        self.pretrain_loss_curve = []
        self.best_embedding = None
        self.best_predicted_labels = None
        self.best_results = {'ACC': -1}

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of the auto-encoder, all GCN layers and the SSC layer."""
        self.ae.reset_parameters()
        self.gcn_1.reset_parameters()
        self.gcn_2.reset_parameters()
        self.gcn_3.reset_parameters()
        self.gcn_4.reset_parameters()
        self.gcn_5.reset_parameters()
        self.ssc.reset_parameters()

    def forward(self, data: Data, sigma: float = 0.5) -> Any:
        """Run the auto-encoder and the GCN stack on ``data``.

        Args:
            data (Data): Graph whose ``x`` and ``edge_index`` are used.
            sigma (float): Weight of the auto-encoder encoding when it is mixed
                with the previous GCN output; the GCN output gets ``1 - sigma``.

        Returns:
            Tuple ``(predict, embedding, hat_x, q)``: the row-wise softmax of
            ``embedding``, the raw output of the last GCN layer, the last
            decoder output of the auto-encoder, and the SSC layer output
            computed from the last encoding.
        """
        x = data.x.to(self.device)
        edge_index = data.edge_index.to(self.device)

        encodes, decodes = self.ae.forward(x)
        hat_x = decodes[-1]

        h = F.relu(self.gcn_1(x, edge_index))
        h = F.relu(self.gcn_2((1 - sigma) * h + sigma * encodes[0], edge_index))
        h = F.relu(self.gcn_3((1 - sigma) * h + sigma * encodes[1], edge_index))
        h = F.relu(self.gcn_4((1 - sigma) * h + sigma * encodes[2], edge_index))
        # No ReLU on the last layer: its output is turned into probabilities below.
        embedding = self.gcn_5((1 - sigma) * h + sigma * encodes[-1], edge_index)
        predict = F.softmax(embedding, dim=1)

        q = self.ssc(encodes[-1])
        return predict, embedding, hat_x, q

    def loss(self, x, hat_x, q, pred) -> Tensor:
        """Compute the total training loss.

        The result is ``reconstruct_loss + alpha * ssc_loss + beta * ce_loss``
        with ``alpha`` and ``beta`` read from ``cfg.train``.

        Args:
            x: Input features; moved to the device of ``hat_x`` before use.
            hat_x: Reconstruction produced by the auto-encoder.
            q: Output of the SSC layer.
            pred: Row-wise softmax output of the GCN stack.

        Returns:
            Tensor: The combined loss.
        """
        # forward() moves its own copy of the features to the model device, but
        # callers hand in the raw data.x, so align devices before comparing.
        x = x.to(hat_x.device)
        reconstruct_loss = self.ae.loss(x, hat_x)
        ssc_loss = self.ssc.loss(q, method='kl_div')

        # The target distribution is treated as a constant (no gradient through q).
        p = target_distribution(q.detach())
        # Softmax can underflow to exactly 0; log(0) = -inf would make the KL
        # term non-finite. Clamping at the smallest normal float leaves every
        # regular probability untouched.
        log_pred = pred.clamp(min=torch.finfo(pred.dtype).tiny).log()
        ce_loss = F.kl_div(log_pred, p, reduction='batchmean')

        alpha = float(self.cfg.train.alpha)
        beta = float(self.cfg.train.beta)
        return reconstruct_loss + alpha * ssc_loss + beta * ce_loss

    def pretrain(self, data: Data, cfg: CN = None, flag: str = "PRETRAIN AE"):
        """Pretrain the auto-encoder and save its weights to ``<cfg.dir>/ae.pth``.

        Args:
            data (Data): Graph data handed to the auto-encoder's ``train_model``.
            cfg (CN): Pretraining config; defaults to ``self.cfg.train.pretrain``.
            flag (str): Label forwarded to the auto-encoder's ``train_model``.
        """
        if cfg is None:
            cfg = self.cfg.train.pretrain
        self.pretrain_loss_curve = self.ae.train_model(data, cfg, flag)

        validate_and_create_path(cfg.dir)
        pretrain_file_name = os.path.join(cfg.dir, 'ae.pth')
        torch.save(self.ae.state_dict(), pretrain_file_name)

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN SDCN"):
        """Train the full model, pretraining the auto-encoder first if needed.

        Args:
            data (Data): Graph data used for training and evaluation.
            cfg (CN): Training config; defaults to ``self.cfg.train``.
            flag (str): Label passed to ``self.logger.flag`` before training.

        Returns:
            Tuple ``(loss_curve, nmi_curve, embedding, predicted_labels, results)``.
            When ``self.cfg.evaluate.each`` is true the last three entries are
            the snapshot with the highest ``'ACC'`` seen during training;
            otherwise they come from a single evaluation after the last epoch.
        """
        if cfg is None:
            cfg = self.cfg.train

        # Load the pretrained auto-encoder, creating the checkpoint if it is missing.
        pretrain_file_name = os.path.join(cfg.pretrain.dir, 'ae.pth')
        if not os.path.exists(pretrain_file_name):
            self.pretrain(data, cfg.pretrain, flag='PRETRAIN AE')
        # NOTE(review): torch.load unpickles the file; only point
        # cfg.pretrain.dir at checkpoints you trust.
        self.ae.load_state_dict(torch.load(pretrain_file_name, map_location='cpu'))

        optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr))

        # Initialize the SSC layer with the cluster centers found by the auto-encoder.
        _, _, cluster_centers = self.ae.clustering(data)
        self.ssc.cluster_centers.data = cluster_centers.to(self.device)
        self.ae.evaluate(data)

        self.logger.flag(flag)
        for epoch in range(1, cfg.max_epoch + 1):
            self.train()
            optimizer.zero_grad()
            predict, embedding, hat_x, q = self.forward(data, cfg.sigma)
            loss = self.loss(data.x, hat_x, q, predict)
            loss.backward()
            optimizer.step()

            self.loss_curve.append(loss.item())
            self.logger.loss(epoch, loss)

            if self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                # Keep the snapshot with the best accuracy so far.
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results

        if not self.cfg.evaluate.each:
            # No per-epoch evaluation was requested: evaluate the final model once.
            embedding, predicted_labels, results = self.evaluate(data)
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return self.loss_curve, self.nmi_curve, self.best_embedding, self.best_predicted_labels, self.best_results

    def get_embedding(self, data) -> Tuple[Tensor, Tensor]:
        """Run a gradient-free forward pass in eval mode.

        Args:
            data: Graph data passed to ``forward``.

        Returns:
            Tuple ``(predict, embedding)`` as produced by ``forward`` with
            ``sigma = self.cfg.train.sigma``.
        """
        with torch.no_grad():
            self.eval()
            predict, embedding, _, _ = self.forward(data, self.cfg.train.sigma)
            return predict, embedding

    def clustering(self, data) -> Tuple[Tensor, Tensor, Tensor]:
        """Assign every node to the cluster with the highest predicted probability.

        Args:
            data: Graph data passed to ``get_embedding``.

        Returns:
            Tuple ``(embedding, labels_, clustering_centers)``; ``labels_`` is a
            CPU tensor of row-wise argmax indices of the prediction, and
            ``clustering_centers`` is the SSC layer's ``cluster_centers`` data.
        """
        predict, embedding = self.get_embedding(data)
        labels_ = predict.detach().argmax(dim=1).cpu()
        clustering_centers = self.ssc.cluster_centers.data
        return embedding, labels_, clustering_centers

    def evaluate(self, data: Data):
        """Cluster ``data`` and score the result against ``data.y``.

        Args:
            data (Data): Graph data with ground-truth labels in ``y``.

        Returns:
            Tuple ``(embedding, predicted_labels, results)`` where ``results`` is
            whatever ``DGCMetric.evaluate_one_epoch`` returns.
        """
        embedding, predicted_labels, clustering_centers = self.clustering(data)
        # .cpu() keeps this working when the labels live on an accelerator.
        ground_truth = data.y.detach().cpu().numpy()
        metric = DGCMetric(ground_truth, predicted_labels.numpy(), embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results