# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch_geometric.data import Data
from ..clusterings import KMeansGPU
from ..metrics import DGCMetric
from ..utils import Logger
from typing import Tuple, Any
from torch import Tensor
from yacs.config import CfgNode as CN
from . import DGCModel
def comprehensive_similarity(Z1, Z2, E1, E2, alpha):
    """Blend two cross-view similarity matrices into one.

    The two views of each embedding type are stacked row-wise and multiplied
    with their own transpose, which yields the ``(2N, 2N)`` block matrix
    ``[[X1 X1^T, X1 X2^T], [X2 X1^T, X2 X2^T]]`` in a single product.

    Args:
        Z1 (torch.Tensor): First-view embedding of the first embedding type,
            shape ``(N, d)``.
        Z2 (torch.Tensor): Second-view embedding of the first embedding type,
            shape ``(N, d)``.
        E1 (torch.Tensor): First-view embedding of the second embedding type,
            shape ``(N, d')``.
        E2 (torch.Tensor): Second-view embedding of the second embedding
            type, shape ``(N, d')``.
        alpha (float | torch.Tensor): Weight given to the ``Z`` similarity;
            the ``E`` similarity is weighted by ``1 - alpha``.

    Returns:
        torch.Tensor: Comprehensive similarity matrix of shape ``(2N, 2N)``.
    """
    # Stacking both views first turns four block products into one.
    Z = torch.cat([Z1, Z2], dim=0)
    E = torch.cat([E1, E2], dim=0)
    return alpha * (Z @ Z.T) + (1 - alpha) * (E @ E.T)
def hard_sample_aware_infoNCE(S, M, pos_neg_weight, pos_weight, node_num):
    """Hard sample aware InfoNCE loss function.

    Args:
        S (torch.Tensor): Comprehensive similarity matrix, shape
            ``(2 * node_num, 2 * node_num)``.
        M (torch.Tensor): Mask matrix multiplied element-wise with the
            weighted, exponentiated similarities.
        pos_neg_weight (float | torch.Tensor): Weight applied to every entry
            of ``S`` before exponentiation; a scalar or anything
            broadcastable to ``S``.
        pos_weight (float | torch.Tensor): Weight applied to the positive
            pair similarities; a scalar or a vector of length
            ``2 * node_num``.
        node_num (int): Number of nodes ``N``. The positive pairs are read
            from the ``+N`` and ``-N`` diagonals of ``S``.

    Returns:
        torch.Tensor: Scalar InfoNCE loss averaged over the ``2 * node_num``
        anchors.
    """
    pos_neg = M * torch.exp(S * pos_neg_weight)
    # Entry (i, i + N) / (i + N, i) pairs node i in one view with itself in
    # the other view.
    pos = torch.cat([torch.diag(S, node_num), torch.diag(S, -node_num)], dim=0)
    pos = torch.exp(pos * pos_weight)
    # Everything kept by the mask on the anchor's row, minus its positive.
    neg = torch.sum(pos_neg, dim=1) - pos
    info_nce = (-torch.log(pos / (pos + neg))).sum() / (2 * node_num)
    return info_nce
def square_euclid_distance(Z, center):
    """Squared Euclidean distance between every row of ``Z`` and every center.

    Uses the expansion ``||z - c||^2 = ||z||^2 + ||c||^2 - 2 * z . c``.

    Args:
        Z (torch.Tensor): Latent representation, shape ``(N, d)``.
        center (torch.Tensor): Clustering centers, shape ``(K, d)``.

    Returns:
        torch.Tensor: Squared Euclidean distance matrix of shape ``(N, K)``.
    """
    # Column / row vectors broadcast to (N, K); no need to materialise the
    # repeated copies.
    ZZ = (Z * Z).sum(-1).reshape(-1, 1)
    CC = (center * center).sum(-1).reshape(1, -1)
    ZC = Z @ center.T
    return (ZZ + CC) - 2 * ZC
def phi(embedding, cluster_num):
    """Cluster the embedding with ``KMeansGPU``.

    Args:
        embedding (torch.Tensor): Latent representation.
        cluster_num (int): Number of clusters.

    Returns:
        tuple: ``(labels, centers)`` exactly as returned by
        ``KMeansGPU(cluster_num).fit(embedding)`` -- the clustering labels
        and the clustering centers.
    """
    labels_, clustering_centers_ = KMeansGPU(cluster_num).fit(embedding)
    return labels_, clustering_centers_
class HSAN(DGCModel):
    """Hard Sample Aware Network for Contrastive Deep Graph Clustering.

    Reference: https://ojs.aaai.org/index.php/AAAI/article/view/26071

    Args:
        logger (Logger): Logger object.
        cfg (CN): Configuration object.
    """

    def __init__(self, logger: Logger, cfg: CN):
        super(HSAN, self).__init__(logger, cfg)
        self.device = torch.device(cfg.device)
        # input_dim == -1 means "use the dataset's feature dimension".
        input_dim = cfg.dataset.num_features if cfg.model.dims.input_dim == -1 else cfg.model.dims.input_dim
        hidden_dim = cfg.model.dims.hidden_dim
        n_num = cfg.dataset.num_nodes

        # Two un-shared linear encoders applied to ``data.x`` ...
        self.AE1 = nn.Linear(input_dim, hidden_dim).to(self.device)
        self.AE2 = nn.Linear(input_dim, hidden_dim).to(self.device)
        # ... and two applied to ``data.adj`` (input width = number of nodes).
        self.SE1 = nn.Linear(n_num, hidden_dim).to(self.device)
        self.SE2 = nn.Linear(n_num, hidden_dim).to(self.device)

        # Learnable scalar mixing the two similarity matrices; created
        # directly with its initial value instead of patching ``.data`` on an
        # uninitialised tensor.
        self.alpha = nn.Parameter(torch.tensor(0.99999, device=self.device))

        # Per-pair weights over the stacked (2N) views; all ones at start and
        # overwritten in place during training.
        self.pos_weight = torch.ones(n_num * 2).to(self.device)
        self.pos_neg_weight = torch.ones([n_num * 2, n_num * 2]).to(self.device)

        act = self.cfg.model.act
        if act == "ident":
            # nn.Identity behaves like ``lambda x: x`` but keeps the model
            # picklable.
            self.activate = nn.Identity()
        elif act == "sigmoid":
            self.activate = nn.Sigmoid()
        elif not hasattr(self, "activate"):
            # Fail at construction time rather than with an AttributeError
            # on the first forward pass.
            raise ValueError(
                "Unsupported cfg.model.act: {!r} (expected 'ident' or 'sigmoid')".format(act)
            )

        self.loss_curve = []
        self.nmi_curve = []
        self.best_embedding = None
        self.best_predicted_labels = None
        self.best_results = {'ACC': -1}

    def reset_parameters(self):
        """No-op; the layers keep their constructor initialisation."""
        pass

    def forward(self, data) -> Any:
        """Encode attributes and adjacency into two views each.

        Args:
            data: Object exposing ``x`` (node features) and ``adj``
                (adjacency with one row per node).

        Returns:
            tuple: ``(Z1, Z2, E1, E2)``, every one L2-normalised along
            ``dim=1``. ``Z*`` come from ``data.x`` (through ``activate``),
            ``E*`` from ``data.adj``.
        """
        x = data.x.to(self.device)
        A = data.adj.to(self.device)
        Z1 = F.normalize(self.activate(self.AE1(x)), dim=1, p=2)
        Z2 = F.normalize(self.activate(self.AE2(x)), dim=1, p=2)
        E1 = F.normalize(self.SE1(A), dim=1, p=2)
        E2 = F.normalize(self.SE2(A), dim=1, p=2)
        return Z1, Z2, E1, E2

    def high_confidence(self, Z, center):
        """Select the high-confidence samples used for the weight update.

        Args:
            Z (torch.Tensor): Fused embedding, one row per node.
            center (torch.Tensor): Clustering centers.

        Returns:
            tuple: ``(H, H_mat)``. ``H`` holds the selected row indices
            followed by the same indices shifted by ``Z.shape[0]`` (the rows
            of the second view in the stacked ``2N`` layout). ``H_mat`` is
            the ``np.ix_`` open mesh of ``H`` with itself, for indexing the
            matching sub-matrix.
        """
        # Per node: smallest softmax-normalised squared distance to a center.
        distance_norm = torch.min(F.softmax(square_euclid_distance(Z, center), dim=1), dim=1).values
        # ``value[-1]`` is the k-th largest score, k = int(N * (1 - tau));
        # everything at or below that threshold is kept.
        value, _ = torch.topk(distance_norm, int(Z.shape[0] * (1 - self.cfg.train.tau)))
        index = torch.where(distance_norm <= value[-1],
                            torch.ones_like(distance_norm), torch.zeros_like(distance_norm))
        high_conf_index_v1 = torch.nonzero(index).reshape(-1, )
        high_conf_index_v2 = high_conf_index_v1 + Z.shape[0]
        H = torch.cat([high_conf_index_v1, high_conf_index_v2], dim=0)
        H_mat = np.ix_(H.cpu(), H.cpu())
        return H, H_mat

    def pseudo_matrix(self, P, S, node_num):
        """Build the sample-pair weights from pseudo labels and similarity.

        Args:
            P (torch.Tensor): Cluster label per node.
            S (torch.Tensor): Comprehensive similarity matrix over the
                stacked views.
            node_num (int): Number of nodes ``N``.

        Returns:
            tuple: ``(M, M_mat)``. ``M_mat`` is ``|Q - S_norm| ** beta`` where
            ``Q[i, j]`` is 1 when the two stacked samples share a label and
            ``S_norm`` is ``S`` min-max scaled. ``M`` is the concatenation of
            the ``+N`` and ``-N`` diagonals of ``M_mat``.
        """
        P = P.detach().clone()
        # Duplicate the labels so they line up with the stacked two views.
        P = torch.cat([P, P], dim=0)
        Q = (P == P.unsqueeze(1)).float().to(self.device)
        S_norm = (S - S.min()) / (S.max() - S.min())
        M_mat = torch.abs(Q - S_norm) ** self.cfg.train.beta
        M = torch.cat([torch.diag(M_mat, node_num), torch.diag(M_mat, -node_num)], dim=0)
        return M, M_mat

    def loss(self, *args, **kwargs) -> Tensor:
        """Unused placeholder; ``train_model`` computes the loss itself."""
        pass

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN HSAN"):
        """Train the model and return curves, embedding, labels and metrics.

        Args:
            data (Data): Graph data consumed by ``forward`` / ``evaluate``.
            cfg (CN, optional): Training config providing ``lr`` and
                ``max_epoch``. Defaults to ``self.cfg.train``.
            flag (str): Not used by this method.

        Returns:
            tuple: ``(loss_curve, nmi_curve, embedding, predicted_labels,
            results)``. When ``cfg.evaluate.each`` is set these are the
            best-ACC values seen during training; otherwise they come from a
            single evaluation after the last epoch.
        """
        if cfg is None:
            cfg = self.cfg.train
        node_num = self.cfg.dataset.num_nodes
        # adam optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr))
        # positive and negative sample pair index matrix
        mask = torch.ones([node_num * 2, node_num * 2]) - torch.eye(node_num * 2)
        mask = mask.to(self.device)

        # training
        for epoch in range(0, cfg.max_epoch):
            # train mode
            self.train()
            # encoding with Eq. (3)-(5)
            Z1, Z2, E1, E2 = self.forward(data)
            # calculate comprehensive similarity by Eq. (6)
            S = comprehensive_similarity(Z1, Z2, E1, E2, self.alpha)
            # calculate hard sample aware contrastive loss by Eq. (10)-(11)
            loss = hard_sample_aware_infoNCE(S, mask, self.pos_neg_weight, self.pos_weight, node_num)
            # optimization; clear stale gradients first, otherwise every
            # step would use the sum of all previous epochs' gradients
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            self.loss_curve.append(loss.item())
            self.logger.loss(epoch, loss)

            # testing and update weights of sample pairs
            if epoch % 10 == 0:
                self.eval()
                # Only values are needed here, so skip building the graph.
                with torch.no_grad():
                    # encoding
                    Z1, Z2, E1, E2 = self.forward(data)
                    # calculate comprehensive similarity by Eq. (6)
                    S = comprehensive_similarity(Z1, Z2, E1, E2, self.alpha)
                    # fusion and testing
                    embedding = (Z1 + Z2) / 2
                    P, center = phi(embedding, self.cfg.dataset.n_clusters)
                    # select high confidence samples
                    H, H_mat = self.high_confidence(embedding, center)
                    # calculate new weight of sample pair by Eq. (9)
                    M, M_mat = self.pseudo_matrix(P, S, node_num)
                    # update weight
                    self.pos_weight[H] = M[H].data
                    self.pos_neg_weight[H_mat] = M_mat[H_mat].data

            if self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results

        if not self.cfg.evaluate.each:
            embedding, predicted_labels, results = self.evaluate(data)
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return self.loss_curve, self.nmi_curve, self.best_embedding, self.best_predicted_labels, self.best_results

    def get_embedding(self, data) -> Tuple[Tensor, Tensor]:
        """Return the fused embedding and the similarity matrix, detached.

        Args:
            data: Graph data accepted by ``forward``.

        Returns:
            Tuple[Tensor, Tensor]: ``(embedding, S)`` where ``embedding`` is
            the mean of ``Z1`` and ``Z2`` and ``S`` is the comprehensive
            similarity matrix.
        """
        with torch.no_grad():
            self.eval()
            # encoding
            Z1, Z2, E1, E2 = self.forward(data)
            # calculate comprehensive similarity by Eq. (6)
            S = comprehensive_similarity(Z1, Z2, E1, E2, self.alpha)
            # fusion and testing
            embedding = (Z1 + Z2) / 2
            return embedding.detach(), S.detach()

    def clustering(self, data) -> Tuple[Tensor, Tensor, Tensor]:
        """Cluster the fused embedding with ``KMeansGPU``.

        Args:
            data: Graph data accepted by ``forward``.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: ``(embedding, labels, centers)``.
        """
        embedding, _ = self.get_embedding(data)
        labels_, clustering_centers_ = phi(embedding, self.cfg.dataset.n_clusters)
        return embedding, labels_, clustering_centers_

    def evaluate(self, data):
        """Cluster the current embedding and score it with ``DGCMetric``.

        Args:
            data: Graph data exposing ``y`` (ground-truth labels) and
                ``edge_index`` in addition to what ``forward`` needs.

        Returns:
            tuple: ``(embedding, predicted_labels, results)`` where
            ``results`` is whatever ``DGCMetric.evaluate_one_epoch`` returns
            (indexed with ``'ACC'`` and ``'NMI'`` by ``train_model``).
        """
        embedding, predicted_labels, clustering_centers = self.clustering(data)
        # ``Tensor.numpy()`` only works on CPU tensors; move first so this
        # also works when labels live on the GPU.
        ground_truth = data.y.detach().cpu().numpy()
        metric = DGCMetric(ground_truth, predicted_labels.detach().cpu().numpy(), embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results