Source code for pydgc.models.ccgc

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from typing import Tuple, Any

from torch import Tensor
from torch_geometric.data import Data

from . import DGCModel
from ..clusterings import KMeansGPU
from ..metrics import DGCMetric
from ..utils import Logger

from yacs.config import CfgNode as CN


def init_clustering(feature, cluster_num, device="cuda"):
    """Initialize clustering with kmeans.

    Args:
        feature (Tensor): Input feature.
        cluster_num (int): Number of clusters.
        device (str, optional): Device forwarded unchanged to ``KMeansGPU``.
            Defaults to ``"cuda"``, the value that used to be hard-coded, so
            existing two-argument calls behave exactly as before.

    Returns:
        predict_labels (Tensor): Predicted labels.
        dis (Tensor): Pairwise distance between ``feature`` and the fitted
            cluster centers.
    """
    kmeans = KMeansGPU(n_clusters=cluster_num, distance="euclidean", device=device)
    # The second value returned by ``fit`` is not needed here: the centers are
    # read back from ``kmeans.cluster_centers_`` below.
    predict_labels, _ = kmeans.fit(feature)
    dis = kmeans.pairwise_distance(feature, kmeans.cluster_centers_)
    return predict_labels, dis
class CCGC(DGCModel):
    """Cluster-Guided Contrastive Graph Clustering Network.

    Reference: https://ojs.aaai.org/index.php/AAAI/article/view/26285

    The model consists of two independent linear encoders applied to the same
    input features; each produces a row-wise L2-normalised view. Training first
    aligns the cross-view similarity matrix with the identity, then switches to
    a loss built from high-confidence k-means assignments.

    Args:
        logger (Logger): Logger.
        cfg (CN): Config.
    """

    # Epochs with ``epoch <= WARMUP_EPOCHS`` use the identity-alignment loss.
    WARMUP_EPOCHS = 50
    # Fraction of a cluster's high-confidence nodes sampled as positives.
    POSITIVE_SAMPLE_RATIO = 0.8

    def __init__(self, logger: Logger, cfg: CN):
        super().__init__(logger, cfg)
        self.device = torch.device(cfg.device)

        # Copy before inserting so the list stored in the config is not mutated.
        dims = cfg.model.dims.copy()
        dims.insert(0, cfg.dataset.num_features)

        # Two encoders with identical shapes but separate parameters.
        self.layers1 = nn.Linear(dims[0], dims[1]).to(self.device)
        self.layers2 = nn.Linear(dims[0], dims[1]).to(self.device)

        self.loss_curve = []
        self.nmi_curve = []
        self.best_embedding = None
        self.best_predicted_labels = None
        # Sentinel below any result that replaces it in ``train_model``.
        self.best_results = {'ACC': -1}

    def reset_parameters(self):
        """Intentionally a no-op: the layers keep their current weights."""
        pass

    def forward(self, x) -> Any:
        """Encode ``x`` with both encoders.

        Args:
            x (Tensor): Input features; moved to ``self.device`` before use.

        Returns:
            Tuple[Tensor, Tensor]: The two views, each L2-normalised along
            ``dim=1``.
        """
        x = x.to(self.device)
        view_1 = F.normalize(self.layers1(x), dim=1, p=2)
        view_2 = F.normalize(self.layers2(x), dim=1, p=2)
        return view_1, view_2

    def loss(self, *args, **kwargs) -> Tensor:
        """Unused placeholder; the losses are computed inside ``train_model``.

        Returns ``None`` despite the annotation.
        """
        pass

    def _high_confidence_loss(self, z1, z2, predict_labels, dis, cfg):
        """Compute the cluster-guided loss used after the warm-up epochs.

        Args:
            z1 (Tensor): First view returned by ``forward``.
            z2 (Tensor): Second view returned by ``forward``.
            predict_labels (Tensor): Cluster label per node from the most
                recent k-means run.
            dis (Tensor): Node-to-center distances from the same k-means run.
            cfg (CN): Training config; ``threshold`` and ``alpha`` are read.

        Returns:
            Tensor or None: The loss, or ``None`` when the epoch must be
            skipped (fewer than two clusters among the high-confidence nodes,
            or a positive term that is exactly zero).
        """
        # A node's confidence is its distance to the nearest center.
        high_confidence = torch.min(dis.cpu(), dim=1).values
        num_nodes = high_confidence.shape[0]
        # Clamp the cut so ``cfg.threshold >= 1.0`` selects the largest
        # distance instead of indexing one past the end.
        cut = min(int(num_nodes * cfg.threshold), num_nodes - 1)
        threshold = torch.sort(high_confidence).values[cut]
        # All node indices strictly below the threshold. This is the torch
        # equivalent of what ``np.argwhere(tensor)[0]`` returned for a tensor
        # argument, without depending on NumPy's handling of non-ndarray input.
        confident_idx = torch.nonzero(high_confidence < threshold, as_tuple=True)[0]

        # Order the selected node indices by cluster label so that each
        # cluster occupies one contiguous slice of ``index``.
        index = confident_idx.to(self.device)
        y_sam = predict_labels.detach()[confident_idx].to(self.device)
        index = index[torch.argsort(y_sam)]

        # ``torch.unique`` returns the labels in ascending order by default,
        # matching the sorted label keys the slicing below relies on.
        labels, counts = torch.unique(y_sam, return_counts=True)
        if labels.shape[0] < 2:
            return None
        bounds = torch.cumsum(counts, dim=0).tolist()

        pos_contrastive = 0
        centers_1, centers_2 = [], []
        # NOTE(review): the slices start at the end of the first cluster, so
        # the lowest-labelled cluster contributes neither positives nor a
        # center. Kept as-is to preserve training behaviour; confirm intent.
        for start, stop in zip(bounds[:-1], bounds[1:]):
            now = index[start:stop]
            candidates = now.cpu()
            num_pos = int(now.shape[0] * self.POSITIVE_SAMPLE_RATIO)
            # The two views draw independent samples, so row i of one view is
            # paired with a possibly different node of the same cluster.
            pos_embed_1 = z1[np.random.choice(candidates, size=num_pos, replace=False)]
            pos_embed_2 = z2[np.random.choice(candidates, size=num_pos, replace=False)]
            # For unit vectors, 2 - 2 * <a, b> is the squared Euclidean distance.
            pos_contrastive += (2 - 2 * torch.sum(pos_embed_1 * pos_embed_2, dim=1)).sum()
            centers_1.append(torch.mean(z1[now], dim=0))
            centers_2.append(torch.mean(z2[now], dim=0))

        pos_contrastive = pos_contrastive / self.cfg.dataset.n_clusters
        if pos_contrastive == 0:
            return None

        # Push the similarity between centers of different clusters towards
        # zero; the diagonal (same cluster across views) is excluded.
        centers_1 = F.normalize(torch.stack(centers_1), dim=1, p=2)
        centers_2 = F.normalize(torch.stack(centers_2), dim=1, p=2)
        S = centers_1 @ centers_2.T
        S = S - torch.diag_embed(torch.diag(S))
        neg_contrastive = F.mse_loss(S, torch.zeros_like(S))
        return pos_contrastive + cfg.alpha * neg_contrastive

    def train_model(self, data: Data, cfg: CN = None, flag: str = "TRAIN CCGC"):
        """Train the model and track the best clustering result.

        Args:
            data (Data): Graph data; ``data.x`` holds the input features.
            cfg (CN, optional): Training config providing ``lr``,
                ``max_epoch``, ``threshold`` and ``alpha``. Defaults to
                ``self.cfg.train``.
            flag (str, optional): Not used by this method.

        Returns:
            tuple: ``(loss_curve, nmi_curve, embedding, predicted_labels,
            results)``. When ``self.cfg.evaluate.each`` is true the last three
            entries are those of the epoch with the highest ``ACC``; otherwise
            they come from one evaluation after training.
        """
        if cfg is None:
            cfg = self.cfg.train
        n_clusters = self.cfg.dataset.n_clusters

        smooth_fea = data.x.to(self.device)
        predict_labels, dis = init_clustering(smooth_fea, n_clusters)
        optimizer = torch.optim.Adam(self.parameters(), lr=float(cfg.lr))
        # Warm-up target: cross-view similarity should equal the identity.
        target = torch.eye(self.cfg.dataset.num_nodes).to(self.device)

        for epoch in range(cfg.max_epoch):
            self.train()
            z1, z2 = self.forward(smooth_fea)

            if epoch > self.WARMUP_EPOCHS:
                loss = self._high_confidence_loss(z1, z2, predict_labels, dis, cfg)
                if loss is None:
                    # Skipping also skips re-clustering, so the next epoch
                    # reuses the current labels and distances.
                    continue
            else:
                loss = F.mse_loss(z1 @ z2.T, target)

            # NOTE(review): ``optimizer.zero_grad()`` is never called, so
            # gradients accumulate across epochs. Left unchanged because
            # adding it would alter training results; confirm it is intended.
            loss.backward(retain_graph=True)
            optimizer.step()
            self.loss_curve.append(loss.item())
            self.logger.loss(epoch, loss)

            # Refresh the pseudo-labels and distances from the averaged views.
            # No gradients are needed for this pass.
            self.eval()
            with torch.no_grad():
                z1, z2 = self.forward(smooth_fea)
                embedding = (z1 + z2) / 2
            predict_labels, dis = init_clustering(embedding, n_clusters)

            if self.cfg.evaluate.each:
                embedding, predicted_labels, results = self.evaluate(data)
                self.nmi_curve.append(results['NMI'])
                if results['ACC'] > self.best_results['ACC']:
                    self.best_embedding = embedding
                    self.best_predicted_labels = predicted_labels
                    self.best_results = results

        if not self.cfg.evaluate.each:
            embedding, predicted_labels, results = self.evaluate(data)
            return self.loss_curve, self.nmi_curve, embedding, predicted_labels, results
        return (self.loss_curve, self.nmi_curve, self.best_embedding,
                self.best_predicted_labels, self.best_results)

    def get_embedding(self, data) -> Tensor:
        """Return the mean of the two views of ``data.x``, without gradients."""
        x = data.x.to(self.device)
        with torch.no_grad():
            z1, z2 = self.forward(x)
            embedding = (z1 + z2) / 2
        return embedding.detach()

    def clustering(self, data) -> Tuple[Tensor, Tensor, Tensor]:
        """Run k-means on the embedding of ``data``.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: ``(embedding, labels, centers)``,
            where the last two are what ``KMeansGPU.fit`` returns.
        """
        embedding = self.get_embedding(data)
        labels_, clustering_centers_ = KMeansGPU(self.cfg.dataset.n_clusters).fit(embedding)
        return embedding, labels_, clustering_centers_

    def evaluate(self, data):
        """Cluster the current embedding and score it against ``data.y``.

        Returns:
            tuple: ``(embedding, predicted_labels, results)``, where
            ``results`` is the value returned by
            ``DGCMetric.evaluate_one_epoch``.
        """
        embedding, predicted_labels, _ = self.clustering(data)
        # ``Tensor.numpy()`` only works on CPU tensors that do not require
        # grad; detach and move to host first so this works on any device.
        ground_truth = data.y.detach().cpu().numpy()
        predicted = predicted_labels.detach().cpu().numpy()
        metric = DGCMetric(ground_truth, predicted, embedding, data.edge_index)
        results = metric.evaluate_one_epoch(self.logger, self.cfg.evaluate)
        return embedding, predicted_labels, results